forked from pytorch/serve
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TensorRT-LLM Engine integration (pytorch#3228)
* TensorRT-LLM Engine integration * TensorRT-LLM Engine integration * review comments * review comments * review comments * Update README.md --------- Co-authored-by: Matthias Reso <[email protected]>
- Loading branch information
Showing
6 changed files
with
223 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Llama TensorRT-LLM Engine integration with TorchServe | ||
|
||
[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) provides users with an option to build TensorRT engines for LLMs that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. | ||
|
||
## Pre-requisites | ||
|
||
TRT-LLM requires Python 3.10 | ||
This example is tested with CUDA 12.1 | ||
Once TorchServe is installed, install TensorRT-LLM using the following. | ||
This will downgrade the versions of PyTorch & Triton but this doesn't cause any issue. | ||
|
||
``` | ||
pip install tensorrt_llm==0.10.0 --extra-index-url https://pypi.nvidia.com | ||
pip install tensorrt-cu12==10.1.0 | ||
python -c "import tensorrt_llm" | ||
``` | ||
shows | ||
``` | ||
[TensorRT-LLM] TensorRT-LLM version: 0.10.0 | ||
``` | ||
|
||
## Download model from HuggingFace | ||
``` | ||
huggingface-cli login | ||
# or using an environment variable | ||
huggingface-cli login --token $HUGGINGFACE_TOKEN | ||
``` | ||
``` | ||
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct | ||
``` | ||
|
||
## Create TensorRT-LLM Engine | ||
Clone TensorRT-LLM which will be used to create the TensorRT-LLM Engine | ||
|
||
``` | ||
git clone -b v0.10.0 https://github.com/NVIDIA/TensorRT-LLM.git | ||
``` | ||
|
||
Compile the model into a TensorRT engine with model weights and a model definition written in the TensorRT-LLM Python API. | ||
|
||
``` | ||
python TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16 | ||
``` | ||
``` | ||
trtllm-build --checkpoint_dir tllm_checkpoint_1gpu_bf16 --gemm_plugin bfloat16 --gpt_attention_plugin bfloat16 --output_dir ./llama-3-8b-engine | ||
``` | ||
|
||
You can test if TensorRT-LLM Engine has been compiled correctly by running the following | ||
``` | ||
python TensorRT-LLM/examples/run.py --engine_dir ./llama-3-8b-engine --max_output_len 100 --tokenizer_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --input_text "How do I count to nine in French?" | ||
``` | ||
|
||
You should see an output as follows | ||
``` | ||
Input [Text 0]: "<|begin_of_text|>How do I count to nine in French?" | ||
Output [Text 0 Beam 0]: " Counting to nine in French is easy and fun. Here's how you can do it: | ||
One: Un | ||
Two: Deux | ||
Three: Trois | ||
Four: Quatre | ||
Five: Cinq | ||
Six: Six | ||
Seven: Sept | ||
Eight: Huit | ||
Nine: Neuf | ||
That's it! You can now count to nine in French. Just remember that the numbers one to five are similar to their English counterparts, but the numbers six to nine have different pronunciations" | ||
``` | ||
|
||
## Create model archive | ||
|
||
``` | ||
mkdir model_store | ||
torch-model-archiver --model-name llama3-8b --version 1.0 --handler trt_llm_handler.py --config-file model-config.yaml --archive-format no-archive --export-path model_store -f | ||
mv model model_store/llama3-8b/. | ||
mv llama-3-8b-engine model_store/llama3-8b/. | ||
``` | ||
|
||
## Start TorchServe | ||
``` | ||
torchserve --start --ncs --model-store model_store --models llama3-8b --disable-token-auth | ||
``` | ||
|
||
## Run Inference | ||
``` | ||
python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 -m llama3-8b --prompt-text "@prompt.json" --prompt-json | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# TorchServe frontend parameters | ||
minWorkers: 1 | ||
maxWorkers: 1 | ||
maxBatchDelay: 100 | ||
responseTimeout: 1200 | ||
deviceType: "gpu" | ||
asyncCommunication: true | ||
|
||
handler: | ||
tokenizer_dir: "model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/" | ||
trt_llm_engine_config: | ||
engine_dir: "llama-3-8b-engine" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{"prompt": "How is the climate in San Francisco?", | ||
"temperature":0.5, | ||
"max_new_tokens": 200} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import json | ||
import logging | ||
import time | ||
|
||
import torch | ||
from tensorrt_llm.runtime import ModelRunner | ||
from transformers import AutoTokenizer | ||
|
||
from ts.handler_utils.utils import send_intermediate_predict_response | ||
from ts.torch_handler.base_handler import BaseHandler | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TRTLLMHandler(BaseHandler): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
self.trt_llm_engine = None | ||
self.tokenizer = None | ||
self.model = None | ||
self.model_dir = None | ||
self.lora_ids = {} | ||
self.adapters = None | ||
self.initialized = False | ||
|
||
def initialize(self, ctx): | ||
self.model_dir = ctx.system_properties.get("model_dir") | ||
|
||
trt_llm_engine_config = ctx.model_yaml_config.get("handler").get( | ||
"trt_llm_engine_config" | ||
) | ||
|
||
tokenizer_dir = ctx.model_yaml_config.get("handler").get("tokenizer_dir") | ||
self.tokenizer = AutoTokenizer.from_pretrained( | ||
tokenizer_dir, | ||
legacy=False, | ||
padding_side="left", | ||
truncation_side="left", | ||
trust_remote_code=True, | ||
use_fast=True, | ||
) | ||
|
||
if self.tokenizer.pad_token_id is None: | ||
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id | ||
|
||
self.trt_llm_engine = ModelRunner.from_dir(**trt_llm_engine_config) | ||
self.initialized = True | ||
|
||
async def handle(self, data, context): | ||
start_time = time.time() | ||
|
||
metrics = context.metrics | ||
|
||
data_preprocess = await self.preprocess(data) | ||
output, input_batch = await self.inference(data_preprocess, context) | ||
output = await self.postprocess(output, input_batch, context) | ||
|
||
stop_time = time.time() | ||
metrics.add_time( | ||
"HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" | ||
) | ||
return output | ||
|
||
async def preprocess(self, requests): | ||
input_batch = [] | ||
assert len(requests) == 1, "Expecting batch_size = 1" | ||
for req_data in requests: | ||
data = req_data.get("data") or req_data.get("body") | ||
if isinstance(data, (bytes, bytearray)): | ||
data = data.decode("utf-8") | ||
|
||
prompt = data.get("prompt") | ||
temperature = data.get("temperature", 1.0) | ||
max_new_tokens = data.get("max_new_tokens", 50) | ||
input_ids = self.tokenizer.encode( | ||
prompt, add_special_tokens=True, truncation=True | ||
) | ||
input_batch.append(input_ids) | ||
|
||
input_batch = [torch.tensor(x, dtype=torch.int32) for x in input_batch] | ||
|
||
return (input_batch, temperature, max_new_tokens) | ||
|
||
async def inference(self, input_batch, context): | ||
input_ids_batch, temperature, max_new_tokens = input_batch | ||
|
||
with torch.no_grad(): | ||
outputs = self.trt_llm_engine.generate( | ||
batch_input_ids=input_ids_batch, | ||
temperature=temperature, | ||
max_new_tokens=max_new_tokens, | ||
end_id=self.tokenizer.eos_token_id, | ||
pad_id=self.tokenizer.pad_token_id, | ||
output_sequence_lengths=True, | ||
streaming=True, | ||
return_dict=True, | ||
) | ||
return outputs, input_ids_batch | ||
|
||
async def postprocess(self, inference_outputs, input_batch, context): | ||
for inference_output in inference_outputs: | ||
output_ids = inference_output["output_ids"] | ||
sequence_lengths = inference_output["sequence_lengths"] | ||
|
||
batch_size, _, _ = output_ids.size() | ||
for batch_idx in range(batch_size): | ||
output_end = sequence_lengths[batch_idx][0] | ||
outputs = output_ids[batch_idx][0][output_end - 1 : output_end].tolist() | ||
output_text = self.tokenizer.decode(outputs) | ||
send_intermediate_predict_response( | ||
[json.dumps({"text": output_text})], | ||
context.request_ids, | ||
"Result", | ||
200, | ||
context, | ||
) | ||
return [""] * len(input_batch) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1256,6 +1256,7 @@ parallelLevel | |
parallelType | ||
parallelization | ||
pptp | ||
TRT | ||
torchcompile | ||
HPU | ||
hpu | ||
|