From 28187841c4e8dc71ffbf74b386976816dc7d18cd Mon Sep 17 00:00:00 2001 From: lxning <23464292+lxning@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:06:39 -0800 Subject: [PATCH] support inf2 neuronx transformer continuous batching (#2803) * fmt * fmt * fmt * add space * fmt * fmt * fmt * fmt * fix regression test * check key result * fmt * update folder * fmt * update key name * add orjson * update streamer * add key text for streamer iterator * update test_hf_batch_streamer output * integrate split checkpoint in handler * fmt * fmt * fmt * fmt * fmt * fmt * update notebook * fmt * add handler utils * fix typo * fmt * fmt * fmt * fmt * fmt * Fix lint * fix typo in notebook example * enable authentication * fmt * fmt * fmt * update readme * fix lint * fmt * update test data * update test * update test * replace os.path with pathlib * update test * fmt --- .../large_models/inferentia2/llama2/Readme.md | 113 +---- .../llama2/continuous_batching/Readme.md | 21 + .../inf2-llama-2-continuous-batching.ipynb | 148 +++++++ .../continuous_batching/model-config.yaml | 19 + .../continuous_batching/requirements.txt | 1 + .../inferentia2/llama2/streamer/Readme.md | 113 +++++ .../llama2/{ => streamer}/inf2_handler.py | 0 .../llama2/{ => streamer}/model-config.yaml | 0 .../llama2/test_stream_response.py | 4 +- examples/large_models/utils/Download_model.py | 9 +- .../utils/test_llm_streaming_response.py | 140 ++++++ ts/handler_utils/hf_batch_streamer.py | 2 +- ts/handler_utils/utils.py | 29 ++ ts/tests/unit_tests/test_handler_utils.py | 32 ++ ts/tests/unit_tests/test_hf_batch_streamer.py | 4 +- ...ase_neuronx_continuous_batching_handler.py | 403 ++++++++++++++++++ ts_scripts/spellcheck_conf/wordlist.txt | 15 +- 17 files changed, 937 insertions(+), 116 deletions(-) create mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/Readme.md create mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb create mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml create mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/requirements.txt create mode 100644 examples/large_models/inferentia2/llama2/streamer/Readme.md rename examples/large_models/inferentia2/llama2/{ => streamer}/inf2_handler.py (100%) rename examples/large_models/inferentia2/llama2/{ => streamer}/model-config.yaml (100%) create mode 100644 examples/large_models/utils/test_llm_streaming_response.py create mode 100644 ts/tests/unit_tests/test_handler_utils.py create mode 100644 ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py diff --git a/examples/large_models/inferentia2/llama2/Readme.md b/examples/large_models/inferentia2/llama2/Readme.md index e01aaced62..d737117acb 100644 --- a/examples/large_models/inferentia2/llama2/Readme.md +++ b/examples/large_models/inferentia2/llama2/Readme.md @@ -1,113 +1,6 @@ # Large model inference on Inferentia2 -This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support. 
+This folder briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with TorchServe's features: -Inferentia2 uses [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/) which is built on top of PyTorch XLA stack. For large model inference [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used that takes care of model partitioning and running inference. - -**Note**: To run the model on an Inf2 instance, the model gets compiled as a preprocessing step. As part of the compilation process, to generate the model graph, a specific batch size is used. Following this, when running inference, we need to pass input which matches the batch size that was used during compilation. Model compilation and input padding to match compiled model batch size is taken care of by the [custom handler](inf2_handler.py) in this example. - -The batch size and micro batch size configurations are present in [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. -The batch size is chosen to be a relatively large value, say 16 since micro batching enables running the preprocess(tokenization) and inference steps in parallel on the micro batches. The micro batch size is the batch size used for the Inf2 model compilation. -Since compilation batch size can influence compile time and also constrained by the Inf2 instance type, this is chosen to be a relatively smaller value, say 4. - -This example also demonstrates the utilization of neuronx cache to store inf2 model compilation artifacts using the `NEURONX_CACHE` and `NEURONX_DUMP_TO` environment variables in the custom handler. -When the model is loaded for the first time, the model is compiled for the configured micro batch size and the compilation artifacts are saved to the neuronx cache. -On subsequent model load, the compilation artifacts in the neuronx cache serves as `Ahead of Time(AOT)` compilation artifacts and significantly reduces the model load time. -For convenience, the compiled model artifacts for this example are made available on the Torchserve model zoo: `s3://torchserve/mar_files/llama-2-13b-neuronx-b4`\ -Instructions on how to use the AOT compiled model artifacts is shown below. - -### Step 1: Inf2 instance - -Get an Inf2 instance(Note: This example was tested on instance type:`inf2.24xlarge`), ssh to it, make sure to use the following DLAMI as it comes with PyTorch and necessary packages for AWS Neuron SDK pre-installed. -DLAMI Name: ` Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720 Amazon Machine Image (AMI)` or higher. - -**Note**: The `inf2.24xlarge` instance consists of 6 neuron chips with 2 neuron cores each. The total accelerator memory is 192GB. -Based on the configuration used in [model-config.yaml](model-config.yaml), with `tp_degree` set to 6, 3 of the 6 neuron chips are used, i.e 6 neuron cores. -On loading the model, the accelerator memory consumed is 38.1GB (12.7GB per chip). 
- -### Step 2: Package Installations - -Follow the steps below to complete package installations - -```bash -sudo apt-get update -sudo apt-get upgrade - -# Activate Python venv -source /opt/aws_neuron_venv_pytorch/bin/activate - -# Clone Torchserve git repository -git clone https://github.com/pytorch/serve.git -cd serve - -# Install dependencies -python ts_scripts/install_dependencies.py --neuronx --environment=dev - -# Install torchserve and torch-model-archiver -python ts_scripts/install_from_src.py - -# Navigate to `examples/large_models/inferentia2/llama2` directory -cd examples/large_models/inferentia2/llama2/ - -# Install additional necessary packages -python -m pip install -r requirements.txt -``` - -### Step 3: Save the model artifacts compatible with `transformers-neuronx` -In order to use the pre-compiled model artifacts, copy them from the model zoo using the command shown below and skip to **Step 5** -```bash -aws s3 cp s3://torchserve/mar_files/llama-2-13b-neuronx-b4/ llama-2-13b --recursive -``` - -In order to download and compile the Llama2 model from scratch for support on Inf2:\ -Request access to the Llama2 model\ -https://huggingface.co/meta-llama/Llama-2-13b-hf - -Login to Huggingface -```bash -huggingface-cli login -``` - -Run the `inf2_save_split_checkpoints.py` script -```bash -python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split' -``` - - -### Step 4: Package model artifacts - -```bash -torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive -mv llama-2-13b-split llama-2-13b -``` - -### Step 5: Add the model artifacts to model store - -```bash -mkdir model_store -mv llama-2-13b model_store -``` - -### Step 6: Start torchserve - -```bash -torchserve --ncs --start --model-store model_store --ts-config config.properties -``` - -### Step 7: Register model - -```bash -curl -X POST "http://localhost:8081/models?url=llama-2-13b" -``` - -### Step 8: Run inference - -```bash -python test_stream_response.py -``` - -### Step 9: Stop torchserve - -```bash -torchserve --stop -``` +* demo1: [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support in folder [streamer](streamer). +* demo2: continuous batching support in folder [continuous_batching](continuous_batching) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md new file mode 100644 index 0000000000..8b935b2288 --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md @@ -0,0 +1,21 @@ +# Demo2: Llama-2 Using TorchServe continuous batching on inf2 + +This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS transformers-neuronx continuous batching](https://aws.amazon.com/ec2/instance-types/inf2/). + +This example can also be extended to support Mistral without code changes. Customers only set the following items in model-config.yaml. 
For example: +* model_path: "model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939" +* model_checkpoint_dir: "llama-2-70b-split" +* model_module_prefix: "transformers_neuronx" +* model_class_name: "llama.model.LlamaForSampling" +* tokenizer_class_name: "transformers.LlamaTokenizer" + +| Model | Model Class | +| :--- | :----: | +| llama | llama.model.LlamaForSampling | +| mistral | mistral.model.MistralForSampling | + + +The batch size in [model-config.yaml](model-config.yaml) indicates the maximum number of requests TorchServe will aggregate and send to the custom handler within the batch delay. It is also the batch size used for the Inf2 model compilation. +Since the compilation batch size influences compile time and is also constrained by the Inf2 instance type, it is chosen to be a relatively small value (8 in this example, matching `batch_size` in [model-config.yaml](model-config.yaml)). + +`inf2-llama-2-continuous-batching.ipynb` is the notebook for this example. diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb new file mode 100644 index 0000000000..e6897cca85 --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -0,0 +1,148 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## TorchServe Continuous Batching: Llama-2-70B on Inferentia-2\n", + "This notebook demonstrates TorchServe serving Llama-2-70b with continuous batching on an Inferentia-2 `inf2.48xlarge` instance, using DLAMI: Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20231226" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Installation\n", + "Note: This section can be skipped once Neuron DLC 2.16 with the latest TorchServe version is released."
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Install Python venv\n", + "!sudo apt-get install -y python3.9-venv g++\n", + "\n", + "# Create Python venv\n", + "!python3.9 -m venv aws_neuron_venv_pytorch\n", + "\n", + "# Activate Python venv\n", + "!source aws_neuron_venv_pytorch/bin/activate\n", + "!python -m pip install -U pip\n", + "\n", + "# Clone Torchserve git repository\n", + "!git clone https://github.com/pytorch/serve.git\n", + "\n", + "# Install dependencies, now all commands run under serve dir\n", + "!cd serve\n", + "!git checkout feat/inf2_cb\n", + "!python ts_scripts/install_dependencies.py --neuronx --environment=dev\n", + "\n", + "# Install torchserve and torch-model-archiver\n", + "python ts_scripts/install_from_src.py" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Create model artifacts\n", + "\n", + "Note: run `mv model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939/model.safetensors.index.json model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939/model.safetensors.index.json.bkp`\n", + " if neuron sdk does not support safetensors" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# login in Hugginface hub\n", + "!huggingface-cli login --token $HUGGINGFACE_TOKEN\n", + "!python examples/large_models/utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-13b-hf --use_auth_token True\n", + "\n", + "# Create TorchServe model artifacts\n", + "!torch-model-archiver --model-name llama-2-70b --version 1.0 --handler ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py -r examples/large_models/inferentia2/llama2/requirements.txt --config-file examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml --archive-format no-archive\n", + "\n", + "!mkdir -p model_store\n", + "!mv llama-2-70b model_store\n", + "!mv model model_store/llama-2-70b" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Start TorchServe" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "torchserve --ncs --start --model-store model_store --models llama-2-70b --ts-config examples/large_models/inferentia2/llama2/config.properties" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Run inference" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Run single inference request\n", + "!python examples/large_models/utils/test_llm_streaming_response.py -m llama-2-70b -o 50 -t 2 -n 4 --prompt-text \"Today the weather is really nice and I am planning on \" --prompt-randomize" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git 
a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml new file mode 100644 index 0000000000..2a69d4de52 --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml @@ -0,0 +1,19 @@ +minWorkers: 1 +maxWorkers: 1 +maxBatchDelay: 0 +batchSize: 8 +responseTimeout: 10800 +jobQueueSize: 500 +continuousBatching: true + +handler: + model_path: "model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939" + model_checkpoint_dir: "llama-2-70b-split" + model_module_prefix: "transformers_neuronx" + model_class_name: "llama.model.LlamaForSampling" + tokenizer_class_name: "transformers.LlamaTokenizer" + amp: "bf16" + tp_degree: 24 + max_length: 256 + max_new_tokens: 50 + batch_size: 8 diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/requirements.txt b/examples/large_models/inferentia2/llama2/continuous_batching/requirements.txt new file mode 100644 index 0000000000..080a0d3acc --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/requirements.txt @@ -0,0 +1 @@ +sentencepiece diff --git a/examples/large_models/inferentia2/llama2/streamer/Readme.md b/examples/large_models/inferentia2/llama2/streamer/Readme.md new file mode 100644 index 0000000000..6490420c0f --- /dev/null +++ b/examples/large_models/inferentia2/llama2/streamer/Readme.md @@ -0,0 +1,113 @@ +# Demo1: Llama-2 Using TorchServe micro-batching and Streamer on inf2 + +This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support. + +Inferentia2 uses the [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/), which is built on top of the PyTorch XLA stack. For large model inference, the [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used; it takes care of model partitioning and running inference. + +**Note**: To run the model on an Inf2 instance, the model gets compiled as a preprocessing step. As part of the compilation process, a specific batch size is used to generate the model graph. When running inference, we therefore need to pass input that matches the batch size used during compilation. Model compilation and padding the input to the compiled batch size are taken care of by the [custom handler](inf2_handler.py) in this example. + +The batch size and micro batch size configurations are present in [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests TorchServe will aggregate and send to the custom handler within the batch delay. +The batch size is chosen to be a relatively large value, say 16, since micro batching enables running the preprocess (tokenization) and inference steps in parallel on the micro batches. The micro batch size is the batch size used for the Inf2 model compilation. +Since the compilation batch size influences compile time and is also constrained by the Inf2 instance type, it is chosen to be a relatively small value, say 4.
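To make the relationship between the batch size and the micro batch size concrete, here is a schematic, framework-free sketch (not the actual `inf2_handler.py` implementation): the aggregated TorchServe batch is split into micro batches of the compiled size, and a partial micro batch is padded so the Neuron model always receives the batch size it was compiled with.

```python
from typing import List

MICRO_BATCH_SIZE = 4  # must match the batch size used for Inf2 model compilation


def to_micro_batches(prompts: List[str], pad_value: str = "") -> List[List[str]]:
    """Split an aggregated batch (e.g. up to 16 requests) into compile-sized micro batches."""
    micro_batches = []
    for start in range(0, len(prompts), MICRO_BATCH_SIZE):
        micro_batch = prompts[start : start + MICRO_BATCH_SIZE]
        # Pad a partial micro batch so it matches the compiled batch size;
        # outputs for the padded slots are discarded before responding.
        micro_batch += [pad_value] * (MICRO_BATCH_SIZE - len(micro_batch))
        micro_batches.append(micro_batch)
    return micro_batches


print(to_micro_batches([f"prompt {i}" for i in range(10)]))  # 3 micro batches, the last one padded
```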
+ +This example also demonstrates the use of the neuronx cache to store Inf2 model compilation artifacts, using the `NEURONX_CACHE` and `NEURONX_DUMP_TO` environment variables in the custom handler. +When the model is loaded for the first time, the model is compiled for the configured micro batch size and the compilation artifacts are saved to the neuronx cache. +On subsequent model loads, the compilation artifacts in the neuronx cache serve as `Ahead of Time (AOT)` compilation artifacts and significantly reduce the model load time. +For convenience, the compiled model artifacts for this example are made available on the Torchserve model zoo: `s3://torchserve/mar_files/llama-2-13b-neuronx-b4`\ +Instructions on how to use the AOT compiled model artifacts are shown below. + +### Step 1: Inf2 instance + +Get an Inf2 instance (Note: this example was tested on instance type `inf2.24xlarge`), ssh to it, and make sure to use the following DLAMI, as it comes with PyTorch and the necessary packages for the AWS Neuron SDK pre-installed. +DLAMI Name: ` Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720 Amazon Machine Image (AMI)` or higher. + +**Note**: The `inf2.24xlarge` instance consists of 6 neuron chips with 2 neuron cores each. The total accelerator memory is 192GB. +Based on the configuration used in [model-config.yaml](model-config.yaml), with `tp_degree` set to 6, 3 of the 6 neuron chips are used, i.e. 6 neuron cores. +On loading the model, the accelerator memory consumed is 38.1GB (12.7GB per chip). + +### Step 2: Package Installations + +Follow the steps below to complete the package installations. + +```bash +sudo apt-get update +sudo apt-get upgrade + +# Activate Python venv +source /opt/aws_neuron_venv_pytorch/bin/activate + +# Clone Torchserve git repository +git clone https://github.com/pytorch/serve.git +cd serve + +# Install dependencies +python ts_scripts/install_dependencies.py --neuronx --environment=dev + +# Install torchserve and torch-model-archiver +python ts_scripts/install_from_src.py + +# Navigate to `examples/large_models/inferentia2/llama2` directory +cd examples/large_models/inferentia2/llama2/ + +# Install additional necessary packages +python -m pip install -r requirements.txt +``` + +### Step 3: Save the model artifacts compatible with `transformers-neuronx` +In order to use the pre-compiled model artifacts, copy them from the model zoo using the command shown below and skip to **Step 5**. +```bash +aws s3 cp s3://torchserve/mar_files/llama-2-13b-neuronx-b4/ llama-2-13b --recursive +``` + +In order to download and compile the Llama2 model from scratch for support on Inf2:\ +Request access to the Llama2 model\ +https://huggingface.co/meta-llama/Llama-2-13b-hf + +Log in to Huggingface +```bash +huggingface-cli login +``` + +Run the `inf2_save_split_checkpoints.py` script +```bash +python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split' +``` + + +### Step 4: Package model artifacts + +```bash +torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive +mv llama-2-13b-split llama-2-13b +``` + +### Step 5: Add the model artifacts to model store + +```bash +mkdir model_store +mv llama-2-13b model_store +``` + +### Step 6: Start torchserve + +```bash +torchserve --ncs --start --model-store model_store --ts-config config.properties +``` + +### Step 7: Register model + +```bash +curl -X POST
"http://localhost:8081/models?url=llama-2-13b" +``` + +### Step 8: Run inference + +```bash +python test_stream_response.py +``` + +### Step 9: Stop torchserve + +```bash +torchserve --stop +``` diff --git a/examples/large_models/inferentia2/llama2/inf2_handler.py b/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py similarity index 100% rename from examples/large_models/inferentia2/llama2/inf2_handler.py rename to examples/large_models/inferentia2/llama2/streamer/inf2_handler.py diff --git a/examples/large_models/inferentia2/llama2/model-config.yaml b/examples/large_models/inferentia2/llama2/streamer/model-config.yaml similarity index 100% rename from examples/large_models/inferentia2/llama2/model-config.yaml rename to examples/large_models/inferentia2/llama2/streamer/model-config.yaml diff --git a/examples/large_models/inferentia2/llama2/test_stream_response.py b/examples/large_models/inferentia2/llama2/test_stream_response.py index 2c205dd3de..018841320c 100644 --- a/examples/large_models/inferentia2/llama2/test_stream_response.py +++ b/examples/large_models/inferentia2/llama2/test_stream_response.py @@ -1,3 +1,4 @@ +import orjson import requests response = requests.post( @@ -9,6 +10,7 @@ for chunk in response.iter_content(chunk_size=None): if chunk: data = chunk.decode("utf-8") - print(data, end="", flush=True) + data = orjson.loads(data) + print(data["text"], end=" ", flush=True) print("") diff --git a/examples/large_models/utils/Download_model.py b/examples/large_models/utils/Download_model.py index 5302a01ba6..2d41a6ddcc 100644 --- a/examples/large_models/utils/Download_model.py +++ b/examples/large_models/utils/Download_model.py @@ -39,6 +39,13 @@ def hf_model(model_str): parser.add_argument( "--model_name", "-m", type=hf_model, required=True, help="HuggingFace model name" ) +parser.add_argument( + "--use_auth_token", + "-t", + type=bool, + default=False, + help="Use HF authentication token", +) parser.add_argument("--revision", "-r", type=str, default="main", help="Revision") args = parser.parse_args() # Only download pytorch checkpoint files @@ -49,6 +56,6 @@ def hf_model(model_str): revision=args.revision, allow_patterns=allow_patterns, cache_dir=args.model_path, - use_auth_token=False, + use_auth_token=args.use_auth_token, ) print(f"Files for '{args.model_name}' is downloaded to '{snapshot_path}'") diff --git a/examples/large_models/utils/test_llm_streaming_response.py b/examples/large_models/utils/test_llm_streaming_response.py new file mode 100644 index 0000000000..b8b7f10ec1 --- /dev/null +++ b/examples/large_models/utils/test_llm_streaming_response.py @@ -0,0 +1,140 @@ +import argparse +import random +import threading +from queue import Queue + +import orjson +import requests + +max_prompt_random_tokens = 20 + + +class Predictor(threading.Thread): + def __init__(self, args, queue): + super().__init__() + self.args = args + self.queue = queue + + def run(self): + for _ in range(self.args.num_requests_per_thread): + self._predict() + + def _predict(self): + payload = self._format_payload() + with requests.post(self._get_url(), json=payload, stream=True) as response: + combined_text = "" + for chunk in response.iter_content(chunk_size=None): + if chunk: + data = orjson.loads(chunk) + combined_text += data["text"] + self.queue.put_nowait(f"payload={payload}\n, output={combined_text}\n") + + def _get_url(self): + return f"http://localhost:8080/predictions/{self.args.model}" + + def _format_payload(self): + prompt = _load_curl_like_data(self.args.prompt_text) + 
prompt_list = prompt.split(" ") + rp = len(prompt_list) + rt = self.args.max_tokens + if self.args.prompt_randomize: + rp = random.randint(0, max_prompt_random_tokens) + rt = rp + self.args.max_tokens + for _ in range(rp): + prompt_list.insert(0, chr(ord("a") + random.randint(0, 25))) + cur_prompt = " ".join(prompt_list) + return { + "prompt": cur_prompt, + "max_new_tokens": rt, + } + + +def _load_curl_like_data(text): + """ + Either use the passed string or load from a file if the string is `@filename` + """ + if text.startswith("@"): + try: + with open(text[1:], "r") as f: + return f.read() + except Exception as e: + raise ValueError(f"Failed to read file {text[1:]}") from e + else: + return text + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model", + required=True, + type=str, + help="The model to use for generating text.", + ) + parser.add_argument( + "--prompt-text", + required=True, + type=str, + help="Prompt text to use instead of generating one. It can be a file reference starting with an at sign, e.g. `@prompt.txt`", + ) + parser.add_argument( + "--prompt-randomize", + action=argparse.BooleanOptionalAction, + default=False, + help="Prepend a few random characters to the prompt to avoid caching", + ) + parser.add_argument( + "-o", + "--max-tokens", + type=int, + default=64, + help="Max number of tokens to generate.", + ) + parser.add_argument( + "-t", + "--num-threads", + type=int, + default=1, + help="Number of threads used to send prediction requests in parallel", + ) + parser.add_argument( + "-n", + "--num-requests-per-thread", + type=int, + default=1, + help="Number of prediction requests to send from each thread", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + if len(args.model) == 0: + print("model argument can not be empty.") + exit(1) + + if len(args.prompt_text) == 0: + print("prompt argument can not be empty.") + exit(1) + + queue = Queue() + predictors = [] + for i in range(args.num_threads): + predictor = Predictor(args, queue) + predictors.append(predictor) + predictor.start() + + for predictor in predictors: + predictor.join() + + print("Tasks are completed") + + while not queue.empty(): + print(queue.get()) + + +if __name__ == "__main__": + main() diff --git a/ts/handler_utils/hf_batch_streamer.py b/ts/handler_utils/hf_batch_streamer.py index 5f89dbcef0..aa76ebf3c9 100644 --- a/ts/handler_utils/hf_batch_streamer.py +++ b/ts/handler_utils/hf_batch_streamer.py @@ -40,7 +40,7 @@ def __next__(self): values = [] for iterator in self.streamer_iterators: try: - values.append(next(iterator)) + values.append({"text": next(iterator)}) except StopIteration: values.append(None) diff --git a/ts/handler_utils/utils.py b/ts/handler_utils/utils.py index ece71821d1..3361513c40 100644 --- a/ts/handler_utils/utils.py +++ b/ts/handler_utils/utils.py @@ -1,9 +1,38 @@ +import importlib import os from ts.context import Context from ts.protocol.otf_message_handler import create_predict_response +def import_class(class_name: str, module_prefix=None): + if not class_name: + raise ImportError("class name is not defined") + + module_name = "" + arr = class_name.rsplit(".", maxsplit=1) + if len(arr) == 2: + module_name, class_name = arr + else: + class_name = arr[0] + + if module_prefix: + module = ( + importlib.import_module(f"{module_prefix}.{module_name}") + if len(module_name) > 0 + else importlib.import_module(module_prefix) + ) + elif len(module_name) > 0: + module = importlib.import_module(module_name) + else: + raise
ImportError(f"module name is not defined.") + + model_class = getattr(module, class_name, None) + if model_class is None: + raise ImportError(f"class:{class_name} not found in module:{module_name}.") + return model_class + + def send_intermediate_predict_response( ret, req_id_map, message, code, context: Context ): diff --git a/ts/tests/unit_tests/test_handler_utils.py b/ts/tests/unit_tests/test_handler_utils.py new file mode 100644 index 0000000000..96a92e7708 --- /dev/null +++ b/ts/tests/unit_tests/test_handler_utils.py @@ -0,0 +1,32 @@ +import pytest + +from ts.handler_utils.utils import import_class + + +def test_import_class_no_module_prefix(): + model_class = import_class( + class_name="transformers.LlamaTokenizer", + ) + assert "LlamaTokenizer" == model_class.__name__ + + +def test_import_class_module_prefix(): + model_class = import_class( + class_name="LlamaTokenizer", + module_prefix="transformers", + ) + assert "LlamaTokenizer" == model_class.__name__ + + +def test_import_class_no_module(): + with pytest.raises(ImportError): + model_class = import_class( + class_name="LlamaTokenizer", + ) + + +def test_import_class_no_class(): + with pytest.raises(ImportError): + model_class = import_class( + class_name="", + ) diff --git a/ts/tests/unit_tests/test_hf_batch_streamer.py b/ts/tests/unit_tests/test_hf_batch_streamer.py index 199bc3e0ed..e0d3b69885 100644 --- a/ts/tests/unit_tests/test_hf_batch_streamer.py +++ b/ts/tests/unit_tests/test_hf_batch_streamer.py @@ -23,8 +23,8 @@ def test_hf_batch_streamer(): for data in streamer: assert len(data) == 2 - output1 += data[0] - output2 += data[1] + output1 += data[0]["text"] + output2 += data[1]["text"] assert output1 == input1 assert output2 == input2 diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py new file mode 100644 index 0000000000..5f6c82fe0d --- /dev/null +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -0,0 +1,403 @@ +import logging +import os +import pathlib + +import torch +import torch_neuronx +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig +from transformers_neuronx.module import save_pretrained_split +from transformers_neuronx.sampling import select_tokens + +from ts.context import Context +from ts.handler_utils.utils import import_class +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__name__) + + +class BaseNeuronXContinuousBatchingHandler(BaseHandler): + def __init__(self): + super().__init__() + + self.batch_size = 2 + self.max_new_tokens = 25 + self.max_length = 100 + self.tokenizer = None + self.decode_next_tokens = None + self.decode_cache_ids = None + self.decode_seq_ids = None + # the queue of seq_ids which are available for a new request + self.empty_seq_ids = [] + # map seq_id to req_id + self.seq_id_to_req_id = {} + self.model_class = None + self.tokenizer_class = None + + def initialize(self, ctx: Context): + ctx.cache = {} + model_dir = ctx.system_properties.get("model_dir") + handler_config = ctx.model_yaml_config.get("handler", {}) + model_checkpoint_dir = handler_config.get("model_checkpoint_dir", "") + + model_checkpoint_path = pathlib.Path(model_dir).joinpath(model_checkpoint_dir) + model_path = pathlib.Path(model_dir).joinpath( + handler_config.get("model_path", "") + ) + + if not model_checkpoint_path.exists(): + # Load and save the 
CPU model + model_cpu = AutoModelForCausalLM.from_pretrained( + str(model_path), low_cpu_mem_usage=True + ) + save_pretrained_split(model_cpu, model_checkpoint_path) + # Load and save tokenizer for the model + tokenizer = AutoTokenizer.from_pretrained( + str(model_path), return_tensors="pt", padding_side="left" + ) + tokenizer.save_pretrained(model_checkpoint_path) + + os.environ["NEURONX_CACHE"] = "on" + os.environ["NEURON_COMPILE_CACHE_URL"] = f"{model_dir}/neuron_cache" + os.environ[ + "NEURON_CC_FLAGS" + ] = "-O1 --model-type=transformer --enable-mixed-precision-accumulation" + + self.max_length = int(handler_config.get("max_length", self.max_length)) + self.max_new_tokens = int( + handler_config.get("max_new_tokens", self.max_new_tokens) + ) + self.batch_size = int(handler_config.get("batch_size", self.batch_size)) + + # settings for model compilation and loading + amp = handler_config.get("amp", "fp32") + tp_degree = handler_config.get("tp_degree", 6) + + # allocate "tp_degree" number of neuron cores to the worker process + os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) + try: + num_neuron_cores_available = ( + torch_neuronx.xla_impl.data_parallel.device_count() + ) + assert num_neuron_cores_available >= int(tp_degree) + except (RuntimeError, AssertionError) as error: + logger.error( + "Required number of neuron cores for tp_degree " + + str(tp_degree) + + " are not available: " + + str(error) + ) + + raise error + self._set_class(ctx) + self.tokenizer = self.tokenizer_class.from_pretrained( + model_checkpoint_path, return_tensors="pt", padding_side="left" + ) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=self.batch_size + ) + neuron_config = NeuronConfig(continuous_batching=continuous_batching_config) + kwargs = dict( + tp_degree=tp_degree, + amp=amp, + batch_size=self.batch_size, + n_positions=[self.max_length], + context_length_estimate=handler_config.get( + "context_length_estimate", [self.max_length] + ), + neuron_config=neuron_config, + ) + self.model = self.model_class.from_pretrained(model_checkpoint_path, **kwargs) + logger.info("Starting to compile the model") + self.model.to_neuron() + logger.info("Model has been successfully compiled") + + # 1D: [seq_id] + # an empty slot if seq_id is -1, otherwise 0 + self.decode_seq_ids = torch.full([self.batch_size], -1) + # 2D:[batch_size, next_cache_id] + self.decode_cache_ids = torch.zeros(self.batch_size, 1, dtype=torch.int64) + # 2D: [batch_size, next_token] + self.decode_next_tokens = torch.zeros(self.batch_size, 1, dtype=torch.int64) + + for seq_id in reversed(range(self.batch_size)): + self.empty_seq_ids.append(seq_id) + + logger.info("Model %s loaded successfully", ctx.model_name) + self.initialized = True + + def preprocess(self, requests): + prefill_req_ids, prefill_seq_ids, prefill_input_text, req_decode_seq_ids = ( + [], + [], + [], + [], + ) + for req_id, req_data in zip(self.context.request_ids.values(), requests): + if req_id not in self.context.cache: + prefill_req_ids.append(req_id) + seq_id = self._get_empty_seq_id() + self.seq_id_to_req_id[seq_id] = req_id + prefill_seq_ids.append(seq_id) + + data = req_data.get("data") or req_data.get("body") + if isinstance(data, (bytes, bytearray)): + data = data.decode("utf-8") + + prompt = data.get("prompt") + max_new_tokens = int(data.get("max_new_tokens", self.max_new_tokens)) + prefill_input_text.append(prompt) + + 
self.context.cache[req_id] = { + "seq_id": seq_id, + "stopping_criteria": self._create_stopping_criteria( + req_id=req_id, seq_id=seq_id, max_new_tokens=max_new_tokens + ), + } + else: + req_decode_seq_ids.append(self.context.cache[req_id]["seq_id"]) + + prefill_tokens = None + if len(prefill_req_ids) > 0: + prefill_tokens = self.tokenizer( + prefill_input_text, return_tensors="pt", padding=True + ) + return prefill_input_text, prefill_tokens, prefill_seq_ids, req_decode_seq_ids + + def inference(self, inputs): + prefill_input_text, prefill_tokens, prefill_seq_ids, req_decode_seq_ids = inputs + results = {} + + if len(prefill_seq_ids) > 0: + prefill_next_tokens, prefill_cache_ids = self._run_prefill( + prefill_tokens, prefill_seq_ids + ) + for i, prefill_seq_id in enumerate(prefill_seq_ids): + self._update_results( + results, + prefill_seq_id, + i, + prefill_cache_ids, + prefill_next_tokens, + prefill_tokens=prefill_tokens, + prefill_input_text=prefill_input_text, + ) + + if len(req_decode_seq_ids) > 0: + local_decode_seq_ids = torch.cat(torch.where(self.decode_seq_ids > -1)) + local_decode_cache_ids = self.decode_cache_ids[local_decode_seq_ids] + local_decode_next_tokens = self.decode_next_tokens[local_decode_seq_ids] + + local_next_tokens = self._run_decode( + local_decode_next_tokens, local_decode_cache_ids, local_decode_seq_ids + ) + + filter_prefill_seq_ids = ( + torch.isin(local_decode_seq_ids, torch.as_tensor(prefill_seq_ids)) + if len(prefill_seq_ids) > 0 + else torch.full(local_decode_seq_ids.shape, False) + ) + + local_decode_cache_ids = local_decode_cache_ids + 1 + for i, is_prefill_seq_id in enumerate(filter_prefill_seq_ids): + if not is_prefill_seq_id: + seq_id = local_decode_seq_ids[i].item() + + if seq_id in req_decode_seq_ids: + self._update_results( + results, + seq_id, + i, + local_decode_cache_ids, + local_next_tokens, + ) + else: + req_id = self._get_req_id(seq_id) + logger.warning( + f"Found request id:{req_id} with seq_id:{seq_id} in local_decode_seq_ids, but not in batch requests. 
Delete it" + ) + self._clean_up(seq_id, req_id) + + return [results[i] for i in self.context.request_ids.values()] + + def postprocess(self, inference_output): + self.context.stopping_criteria = [ + self.context.cache[req_id]["stopping_criteria"] + for req_id in self.context.request_ids.values() + ] + + return inference_output + + def _get_empty_seq_id(self): + if len(self.empty_seq_ids) == 0: + # clean up dead req_ids due to client disconnction + self._clean_dead_reqs() + + assert len(self.empty_seq_ids) > 0 + return self.empty_seq_ids.pop() + + def _add_empty_seq_id(self, seq_id): + self.empty_seq_ids.append(seq_id) + + def _get_seq_id(self, req_id): + seq_id = None + cache = self.context.cache.get(req_id, None) + if cache: + seq_id = cache["seq_id"] + assert seq_id is not None, "{req_id} must have seq_id" + return seq_id + + def _get_req_id(self, seq_id): + req_id = self.seq_id_to_req_id.get(seq_id, None) + assert req_id is not None + return req_id + + def _pad_to_max(self, x): + z = torch.empty(x.shape[0], self.max_length, dtype=torch.int64) + for idx, item in enumerate(x): + pad = torch.zeros(self.max_length - len(x[idx]), dtype=torch.int) + z[idx] = torch.cat((x[idx], pad)) + return z + + def _run_prefill(self, tokens, seq_ids): + input_ids, attention_mask = tokens["input_ids"], tokens["attention_mask"] + + input_ids = self._pad_to_max(input_ids) + attention_mask = self._pad_to_max(attention_mask) + + n_active_seqs, context_len = input_ids.shape + cache_ids = ( + torch.arange(context_len) + .reshape(1, context_len) + .expand(n_active_seqs, context_len) + .mul(attention_mask) + ) + with torch.inference_mode(): + logits = self.model( + input_ids, cache_ids=cache_ids, start_ids=torch.as_tensor(seq_ids) + ) + next_tokens = select_tokens(logits) + + return next_tokens, cache_ids.max(dim=1, keepdim=True).values + 1 + + def _run_decode(self, next_tokens, cache_ids, seq_ids): + with torch.inference_mode(): + logits = self.model(next_tokens, cache_ids=cache_ids, start_ids=seq_ids) + next_tokens = select_tokens(logits) + + return next_tokens + + def _clean_up(self, seq_id, req_id): + # clean up + del self.seq_id_to_req_id[seq_id] + del self.context.cache[req_id] + self.decode_seq_ids[seq_id] = -1 + self.decode_cache_ids[seq_id, :] = torch.zeros(1, dtype=torch.int64) + self.decode_next_tokens[seq_id, :] = torch.tensor( + [self.tokenizer.eos_token_id], dtype=torch.int64 + ) + # add seq_id back to self.empty_seq_ids + self._add_empty_seq_id(seq_id) + + def _clean_dead_reqs(self): + local_decode_seq_ids = torch.cat(torch.where(self.decode_seq_ids > -1)) + for _, seq_id in enumerate(local_decode_seq_ids): + seq_id_value = seq_id.item() + req_id = self._get_req_id(seq_id_value) + if req_id not in self.context.request_ids: + self._clean_up(seq_id_value, req_id) + + def _update_results( + self, + results, + seq_id, + idx, + cache_ids, + next_tokens, + prefill_tokens=None, + prefill_input_text=None, + ): + # 0: this seq_id is used for decoding if this slot is set 0 + self.decode_seq_ids[seq_id] = 0 + self.decode_cache_ids[seq_id, :] = cache_ids[idx, :] + self.decode_next_tokens[seq_id, :] = next_tokens[idx, :] + req_id = self._get_req_id(seq_id) + cur_text = self.tokenizer.decode(next_tokens[idx, :], skip_special_tokens=False) + if not (cur_text.startswith(" ") or cur_text.endswith(" ")): + if prefill_tokens is None: + previous_tokens = self.decode_next_tokens[seq_id, -1] + else: + previous_tokens = prefill_tokens["input_ids"][idx, -1] + + text = self.tokenizer.decode( + 
torch.cat((torch.tensor([previous_tokens]), next_tokens[idx, :])), + skip_special_tokens=False, + ) + if text[: -len(cur_text)].endswith(" "): + cur_text = " " + cur_text + + results[req_id] = { + "text": cur_text + if prefill_input_text is None + else prefill_input_text[idx] + cur_text, + "tokens": [next_tokens[idx, -1].item()], + } + + def _create_stopping_criteria(self, req_id, seq_id, max_new_tokens): + class StoppingCriteria(object): + def __init__( + self, + outer, + req_id, + seq_id, + stop_token, + max_new_tokens, + ): + self.req_id = req_id + self.seq_id = seq_id + self.outer = outer + self.max_new_tokens = max_new_tokens + self.stop_token = stop_token + + def __call__(self, res): + self.max_new_tokens -= 1 + + if self.max_new_tokens == 0 or res["tokens"][-1] == self.stop_token: + self.outer._clean_up(self.seq_id, self.req_id) + return True + return False + + return StoppingCriteria( + outer=self, + req_id=req_id, + seq_id=seq_id, + stop_token=self.tokenizer.eos_token_id, + max_new_tokens=max_new_tokens, + ) + + def _set_class(self, ctx): + handler_config = ctx.model_yaml_config.get("handler", {}) + model_class_name = handler_config.get("model_class_name", None) + + assert ( + model_class_name + ), "model_class_name not found in the section of handler in model config yaml file" + model_module_prefix = handler_config.get("model_module_prefix", None) + self.model_class = import_class( + class_name=model_class_name, + module_prefix=model_module_prefix, + ) + + tokenizer_class_name = handler_config.get("tokenizer_class_name", None) + assert ( + tokenizer_class_name + ), "tokenizer_class_name not found in the section of handler in model config yaml file" + + tokenizer_module_prefix = handler_config.get("tokenizer_module_prefix", None) + + self.tokenizer_class = import_class( + class_name=tokenizer_class_name, module_prefix=tokenizer_module_prefix + ) diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 4077b5baa9..9a8bcf87e8 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1164,6 +1164,20 @@ compilable nightlies torchexportaotcompile autotune +BloomForSampling +ForSampling +GPTJForSampling +GPTNeoXForSampling +LlamaForSampling +MistralForSampling +OPTForSampling +gptj +gptneox +neox +LlamaTokenizer +bdaacb +de +fcea SDXL SDPA bfloat @@ -1184,7 +1198,6 @@ Maher's warmup SOTA FxGraphCache -TorchInductor fx locustapache resnetcppaot
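As a quick reference for the `model_class_name`/`tokenizer_class_name` values used in the model config files above, the snippet below shows how the new `import_class` helper in `ts/handler_utils/utils.py` (exercised by the unit tests in this PR and by `_set_class` in the continuous batching handler) resolves those strings into classes. It is illustrative only: resolving the Neuron model class requires an environment where `transformers-neuronx` is installed.

```python
from ts.handler_utils.utils import import_class

# Fully qualified class name, no module prefix needed.
tokenizer_cls = import_class(class_name="transformers.LlamaTokenizer")

# The model class is resolved under model_module_prefix, i.e. this imports
# transformers_neuronx.llama.model and returns the LlamaForSampling class.
model_cls = import_class(
    class_name="llama.model.LlamaForSampling",
    module_prefix="transformers_neuronx",
)
```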
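Finally, to summarize the slot bookkeeping that `BaseNeuronXContinuousBatchingHandler` implements, here is a toy, framework-free illustration (hypothetical names, no Neuron dependencies, not the handler's actual code): prefill binds a request to a free `seq_id`, every decode cycle advances all active slots by one token, and a finished slot is returned to the free list so a new request can join mid-stream.

```python
from collections import deque

BATCH_SIZE = 4
empty_seq_ids = deque(reversed(range(BATCH_SIZE)))  # free slots, like self.empty_seq_ids
active = {}  # seq_id -> remaining new tokens for that request


def admit(req_id: str, max_new_tokens: int) -> int:
    """Prefill: bind an incoming request to a free sequence slot."""
    seq_id = empty_seq_ids.pop()
    active[seq_id] = max_new_tokens
    print(f"prefill {req_id} -> seq_id={seq_id}")
    return seq_id


def decode_step() -> None:
    """Decode: every active slot emits one token; finished slots are recycled."""
    for seq_id in list(active):
        active[seq_id] -= 1
        if active[seq_id] == 0:  # stopping criteria fired
            del active[seq_id]
            empty_seq_ids.append(seq_id)  # slot becomes free for a new request


admit("req-1", max_new_tokens=2)
admit("req-2", max_new_tokens=3)
decode_step()  # both slots emit a token
decode_step()  # req-1 finishes and frees its slot while req-2 keeps decoding
```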