support inf2 neuronx transformer continuous batching (#2803)
* fmt

* fmt

* fmt

* add space

* fmt

* fmt

* fmt

* fmt

* fix regression test

* check key result

* fmt

* update folder

* fmt

* update key name

* add orjson

* update streamer

* add key text for streamer iterator

* update test_hf_batch_streamer output

* integrate split checkpoint in handler

* fmt

* fmt

* fmt

* fmt

* fmt

* fmt

* update notebook

* fmt

* add handler utils

* fix typo

* fmt

* fmt

* fmt

* fmt

* fmt

* Fix lint

* fix typo in notebook example

* enable authentication

* fmt

* fmt

* fmt

* update readme

* fix lint

* fmt

* update test data

* update test

* update test

* replace os.path with pathlib

* update test

* fmt
lxning authored Feb 27, 2024
1 parent eaacf9d commit 2818784
Showing 17 changed files with 937 additions and 116 deletions.
113 changes: 3 additions & 110 deletions examples/large_models/inferentia2/llama2/Readme.md
@@ -1,113 +1,6 @@
# Large model inference on Inferentia2

This document covers serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support.
This folder covers serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with the following TorchServe features:

Inferentia2 uses the [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/), which is built on top of the PyTorch XLA stack. For large model inference, the [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used; it takes care of model partitioning and running inference.

**Note**: To run the model on an Inf2 instance, the model is compiled as a preprocessing step. The compilation process generates the model graph for a specific batch size, so at inference time the input must match the batch size used during compilation. Model compilation and input padding to the compiled batch size are taken care of by the [custom handler](inf2_handler.py) in this example.
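
For illustration, a minimal sketch of this padding idea (the helper below is hypothetical; the actual logic lives in [inf2_handler.py](inf2_handler.py)):

```python
def pad_to_compiled_batch_size(inputs, compiled_batch_size):
    """Pad a request batch so it matches the batch size the model was compiled with."""
    padded = list(inputs)
    while len(padded) < compiled_batch_size:
        padded.append(padded[-1])  # duplicate an existing input as padding
    return padded, len(inputs)  # keep the real request count to trim outputs later
```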

The batch size and micro batch size configurations are present in [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests TorchServe will aggregate and send to the custom handler within the batch delay.
The batch size is chosen to be a relatively large value, say 16, since micro batching enables running the preprocess (tokenization) and inference steps in parallel on the micro batches. The micro batch size is the batch size used for the Inf2 model compilation.
Since the compilation batch size can influence compile time and is also constrained by the Inf2 instance type, it is chosen to be a relatively small value, say 4.
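
As a rough illustration of that relationship (a hypothetical helper, not the handler's actual code), a batch of 16 requests maps onto micro batches of 4:

```python
def split_into_micro_batches(batch, micro_batch_size=4):
    """Split an aggregated TorchServe batch into micro batches for the compiled model."""
    return [batch[i:i + micro_batch_size] for i in range(0, len(batch), micro_batch_size)]

# A batch of 16 requests yields 4 micro batches of 4; tokenization of one micro batch
# can overlap with inference on another.
assert len(split_into_micro_batches(list(range(16)))) == 4
```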

This example also demonstrates how to use the neuronx cache to store Inf2 model compilation artifacts, via the `NEURONX_CACHE` and `NEURONX_DUMP_TO` environment variables set in the custom handler.
When the model is loaded for the first time, it is compiled for the configured micro batch size and the compilation artifacts are saved to the neuronx cache.
On subsequent model loads, the compilation artifacts in the neuronx cache serve as `Ahead of Time (AOT)` compilation artifacts and significantly reduce the model load time.
For convenience, the compiled model artifacts for this example are made available in the TorchServe model zoo: `s3://torchserve/mar_files/llama-2-13b-neuronx-b4`\
Instructions on how to use the AOT compiled model artifacts are shown below.
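
As a side note on the neuronx cache mentioned above, a minimal sketch of how a handler can enable it through these environment variables (the cache path is a placeholder, not the path this example uses):

```python
import os

# Turn on the neuronx compilation cache and choose where compilation artifacts are dumped.
os.environ["NEURONX_CACHE"] = "on"
os.environ["NEURONX_DUMP_TO"] = "/path/to/neuron_cache"  # placeholder path
```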

### Step 1: Inf2 instance

Get an Inf2 instance (note: this example was tested on instance type `inf2.24xlarge`) and ssh to it. Make sure to use the following DLAMI, as it comes with PyTorch and the necessary packages for the AWS Neuron SDK pre-installed.
DLAMI Name: `Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720 Amazon Machine Image (AMI)` or higher.

**Note**: The `inf2.24xlarge` instance consists of 6 neuron chips with 2 neuron cores each. The total accelerator memory is 192GB.
Based on the configuration used in [model-config.yaml](model-config.yaml), with `tp_degree` set to 6, 3 of the 6 neuron chips (i.e. 6 neuron cores) are used.
On loading the model, the accelerator memory consumed is 38.1GB (12.7GB per chip).

### Step 2: Package Installations

Follow the steps below to complete the package installations.

```bash
sudo apt-get update
sudo apt-get upgrade

# Activate Python venv
source /opt/aws_neuron_venv_pytorch/bin/activate

# Clone Torchserve git repository
git clone https://github.com/pytorch/serve.git
cd serve

# Install dependencies
python ts_scripts/install_dependencies.py --neuronx --environment=dev

# Install torchserve and torch-model-archiver
python ts_scripts/install_from_src.py

# Navigate to `examples/large_models/inferentia2/llama2` directory
cd examples/large_models/inferentia2/llama2/

# Install additional necessary packages
python -m pip install -r requirements.txt
```

### Step 3: Save the model artifacts compatible with `transformers-neuronx`
To use the pre-compiled model artifacts, copy them from the model zoo using the command shown below and skip to **Step 5**.
```bash
aws s3 cp s3://torchserve/mar_files/llama-2-13b-neuronx-b4/ llama-2-13b --recursive
```

To download and compile the Llama 2 model from scratch for use on Inf2, first request access to the model:\
https://huggingface.co/meta-llama/Llama-2-13b-hf

Log in to Hugging Face
```bash
huggingface-cli login
```

Run the `inf2_save_split_checkpoints.py` script
```bash
python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split'
```


### Step 4: Package model artifacts

```bash
torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive
mv llama-2-13b-split llama-2-13b
```

### Step 5: Add the model artifacts to model store

```bash
mkdir model_store
mv llama-2-13b model_store
```

### Step 6: Start torchserve

```bash
torchserve --ncs --start --model-store model_store --ts-config config.properties
```

### Step 7: Register model

```bash
curl -X POST "http://localhost:8081/models?url=llama-2-13b"
```

### Step 8: Run inference

```bash
python test_stream_response.py
```
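
`test_stream_response.py` posts a prompt to the TorchServe inference API and prints the streamed chunks as they arrive. A rough equivalent using `requests` (the prompt here is just an example) is:

```python
import requests

# Stream a text-completion response from the registered llama-2-13b model.
response = requests.post(
    "http://localhost:8080/predictions/llama-2-13b",
    data="Today the weather is really nice and I am planning on ",
    stream=True,
)
for chunk in response.iter_content(chunk_size=None):
    if chunk:
        print(chunk.decode("utf-8"), end="", flush=True)
```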

### Step 9: Stop torchserve

```bash
torchserve --stop
```
* demo1: [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support in folder [streamer](streamer).
* demo2: continuous batching support in folder [continuous_batching](continuous_batching)
@@ -0,0 +1,21 @@
# Demo 2: Llama-2 using TorchServe continuous batching on Inf2

This document covers serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) using [transformers-neuronx](https://github.com/aws-neuron/transformers-neuronx) continuous batching.

This example can also be extended to support Mistral without code changes. Users only need to set the following items in [model-config.yaml](model-config.yaml), for example:
* model_path: "model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939"
* model_checkpoint_dir: "llama-2-70b-split"
* model_module_prefix: "transformers_neuronx"
* model_class_name: "llama.model.LlamaForSampling"
* tokenizer_class_name: "transformers.LlamaTokenizer"

| Model | Model Class |
| :--- | :----: |
| llama | llama.model.LlamaForSampling |
| mistral | mistral.model.MistralForSampling |
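
These settings let the handler resolve the Neuron model class dynamically from the configured module prefix and class name; a simplified sketch of that idea (not the handler's exact code) is:

```python
import importlib

def resolve_model_class(model_module_prefix, model_class_name):
    """Resolve e.g. 'transformers_neuronx' + 'llama.model.LlamaForSampling' to a class."""
    module_path, class_name = f"{model_module_prefix}.{model_class_name}".rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

# model_class = resolve_model_class("transformers_neuronx", "llama.model.LlamaForSampling")
```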


The batch size in [model-config.yaml](model-config.yaml) indicates the maximum number of requests TorchServe will aggregate and send to the custom handler within the batch delay. It is also the batch size used for the Inf2 model compilation.
Since the compilation batch size can influence compile time and is also constrained by the Inf2 instance type, it is chosen to be a relatively small value, say 4.

`inf2-llama-2-continuous-batching.ipynb` is the notebook example.
@@ -0,0 +1,148 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"## TorchServe Continuous Batching Serve Llama-2-70B on Inferentia-2\n",
"This notebook demonstrates TorchServe continuous batching serving Llama-2-70b on Inferentia-2 `inf2.48xlarge` with DLAMI: Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20231226"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Installation\n",
"Note: This section can be skipped once Neuron DLC 2.16 with TorchServe latest version is released."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Install Python venv\n",
"!sudo apt-get install -y python3.9-venv g++\n",
"\n",
"# Create Python venv\n",
"!python3.9 -m venv aws_neuron_venv_pytorch\n",
"\n",
"# Activate Python venv\n",
"!source aws_neuron_venv_pytorch/bin/activate\n",
"!python -m pip install -U pip\n",
"\n",
"# Clone Torchserve git repository\n",
"!git clone https://github.com/pytorch/serve.git\n",
"\n",
"# Install dependencies, now all commands run under serve dir\n",
"!cd serve\n",
"!git checkout feat/inf2_cb\n",
"!python ts_scripts/install_dependencies.py --neuronx --environment=dev\n",
"\n",
"# Install torchserve and torch-model-archiver\n",
"python ts_scripts/install_from_src.py"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Create model artifacts\n",
"\n",
"Note: run `mv model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939/model.safetensors.index.json model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939/model.safetensors.index.json.bkp`\n",
" if neuron sdk does not support safetensors"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# login in Hugginface hub\n",
"!huggingface-cli login --token $HUGGINGFACE_TOKEN\n",
"!python examples/large_models/utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-13b-hf --use_auth_token True\n",
"\n",
"# Create TorchServe model artifacts\n",
"!torch-model-archiver --model-name llama-2-70b --version 1.0 --handler ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py -r examples/large_models/inferentia2/llama2/requirements.txt --config-file examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml --archive-format no-archive\n",
"\n",
"!mkdir -p model_store\n",
"!mv llama-2-70b model_store\n",
"!mv model model_store/llama-2-70b"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Start TorchServe"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"torchserve --ncs --start --model-store model_store --models llama-2-70b --ts-config examples/large_models/inferentia2/llama2/config.properties"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Run inference"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Run single inference request\n",
"!python examples/large_models/utils/test_llm_streaming_response.py -m llama-2-70b -o 50 -t 2 -n 4 --prompt-text \"Today the weather is really nice and I am planning on \" --prompt-randomize"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@@ -0,0 +1,19 @@
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 0
batchSize: 8
responseTimeout: 10800
jobQueueSize: 500
continuousBatching: true

handler:
model_path: "model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939"
model_checkpoint_dir: "llama-2-70b-split"
model_module_prefix: "transformers_neuronx"
model_class_name: "llama.model.LlamaForSampling"
tokenizer_class_name: "transformers.LlamaTokenizer"
amp: "bf16"
tp_degree: 24
max_length: 256
max_new_tokens: 50
batch_size: 8
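
For reference, the `handler` section above can be read with PyYAML when inspecting or validating the configuration (a sketch, assuming the file is saved as `model-config.yaml`):

```python
import yaml

# Load the handler settings from model-config.yaml.
with open("model-config.yaml") as f:
    config = yaml.safe_load(f)

handler_cfg = config["handler"]
print(handler_cfg["model_class_name"], handler_cfg["tp_degree"], handler_cfg["batch_size"])
```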
@@ -0,0 +1 @@
sentencepiece