Changes to support torch._export.aot_compile
agunapal committed Dec 5, 2023
1 parent f3a2267 commit 8754eb2
Showing 8 changed files with 237 additions and 0 deletions.
19 changes: 19 additions & 0 deletions examples/pt2/README.md
@@ -107,3 +107,22 @@ print(extra_files['foo.txt'])
# from inference()
print(ep(torch.randn(5)))
```

## torch._export.aot_compile

Using `torch.compile` to wrap your existing eager PyTorch model can result in out-of-the-box speedups. However, `torch.compile` is a JIT compiler. TorchServe has supported `torch.compile` since the PyTorch 2.0 release. In a production setting with multiple TorchServe instances, each instance would `torch.compile` the model at inference time. Because the compilation happens just in time, TorchServe's model archiver cannot truly guarantee reproducibility.

In addition, the first inference request with `torch.compile` is slow, as the model needs to compile first.

To solve this problem, `torch.export` provides an experimental API, `torch._export.aot_compile`, which can `torch.export` a `torch.compile`-able model with no graph breaks and compile it ahead of time with AOTInductor.

You can find more details [here](https://pytorch.org/docs/main/torch.compiler_aot_inductor.html).
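A minimal sketch of the API (the model, input shape, and output path here are illustrative; see the linked example below for a complete script):

```python
import os

import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
example_inputs = (torch.randn(2, 3, 224, 224),)

# Export the model and compile it ahead of time into a shared library with AOTInductor
so_path = torch._export.aot_compile(
    model,
    example_inputs,
    options={"aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so")},
)
```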


This is an experimental API and needs the PyTorch 2.2 nightlies.
To enable it in TorchServe, add the following config to your `model-config.yaml`:

```yaml
pt2: "export"
```
You can find an example [here](./torch_export_aot_compile/README.md).
69 changes: 69 additions & 0 deletions examples/pt2/torch_export_aot_compile/README.md
@@ -0,0 +1,69 @@
# TorchServe inference with torch._export.aot_compile

This example shows how to run TorchServe with a `torch.export`-ed model compiled ahead of time with AOTInductor.

Using `torch.compile` to wrap your existing eager PyTorch model can result in out-of-the-box speedups. However, `torch.compile` is a JIT compiler. TorchServe has supported `torch.compile` since the PyTorch 2.0 release. In a production setting with multiple TorchServe instances, each instance would `torch.compile` the model at inference time. Because the compilation happens just in time, TorchServe's model archiver cannot truly guarantee reproducibility.

In addition, the first inference request with `torch.compile` is slow, as the model needs to compile first.

To solve this problem, `torch.export` provides an experimental API, `torch._export.aot_compile`, which can `torch.export` a `torch.compile`-able model with no graph breaks and compile it ahead of time with AOTInductor.

You can find more details [here](https://pytorch.org/docs/main/torch.compiler_aot_inductor.html).



### Pre-requisites

- `PyTorch >= 2.2.0`
- `CUDA 12.1`

Change directory to the example directory, e.g. `cd examples/pt2/torch_export_aot_compile`.

Install the PyTorch 2.2 nightlies by running
```
chmod +x install_pytorch_nightlies.sh
source install_pytorch_nightlies.sh
```

### Create a Torch exported model with AOTInductor

The exported model is saved as a shared library with a `.so` extension.
Here we `torch.export` the model and compile it ahead of time with AOTInductor in `max_autotune` mode.
This also makes use of `dynamic_shapes` to support batch sizes from 1 to 32.
In the code, the minimum batch size is specified as 2 instead of 1. You can find an explanation for this [here](https://pytorch.org/docs/main/export.html#expressing-dynamism).
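The relevant portion of `resnet18_torch_export.py` looks like this:

```python
batch_dim = torch.export.Dim("batch", min=2, max=32)
so_path = torch._export.aot_compile(
    model,
    example_inputs,
    # Mark the first (batch) dimension of the input "x" as dynamic
    dynamic_shapes={"x": {0: batch_dim}},
    options={
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
        "max_autotune": True,
    },
)
```

Run the export script with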

```
python resnet18_torch_export.py
```

### Create model archive

```
torch-model-archiver --model-name res18-pt2 --handler image_classifier --version 1.0 --serialized-file resnet18_pt2.so --config-file model-config.yaml --extra-files ../../image_classifier/index_to_name.json
mkdir model_store
mv res18-pt2.mar model_store/.
```

#### Start TorchServe
```
torchserve --start --model-store model_store --models res18-pt2=res18-pt2.mar --ncs
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/res18-pt2 -T ../../image_classifier/kitten.jpg
```

This produces the following output:

```
{
"tabby": 0.4087875485420227,
"tiger_cat": 0.34661102294921875,
"Egyptian_cat": 0.13007202744483948,
"lynx": 0.024034621194005013,
"bucket": 0.011633828282356262
}
```
13 changes: 13 additions & 0 deletions examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
@@ -0,0 +1,13 @@
#!/bin/bash

# Uninstall torchtext, torchdata, torch, torchvision, and torchaudio
pip uninstall torchtext torchdata torch torchvision torchaudio -y

# Install nightly PyTorch and torchvision from the specified index URL
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed

# Optional: Display the installed PyTorch and torchvision versions
python -c "import torch; print('PyTorch version:', torch.__version__)"
python -c "import torchvision; print('torchvision version:', torchvision.__version__)"

echo "PyTorch and torchvision updated successfully!"
1 change: 1 addition & 0 deletions examples/pt2/torch_export_aot_compile/model-config.yaml
@@ -0,0 +1 @@
pt2 : "export"
26 changes: 26 additions & 0 deletions examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
@@ -0,0 +1,26 @@
import os

import torch
from torchvision.models import ResNet18_Weights, resnet18

torch.set_float32_matmul_precision("high")

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.no_grad():
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device=device)
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
batch_dim = torch.export.Dim("batch", min=2, max=32)
so_path = torch._export.aot_compile(
model,
example_inputs,
# Specify the first dimension of the input x as dynamic
dynamic_shapes={"x": {0: batch_dim}},
# Specify the generated shared library path
options={
"aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
"max_autotune": True,
},
)
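As a quick sanity check outside TorchServe, the generated shared library can be loaded back in Python and called with a dynamic batch size. A minimal sketch, assuming `torch._export.aot_load` is available in the installed nightly build:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the AOTInductor-generated shared library produced by the export script above.
# Assumes torch._export.aot_load exists in the installed PyTorch nightly.
compiled = torch._export.aot_load("resnet18_pt2.so", device)

with torch.no_grad():
    # Any batch size within the exported dynamic range (2 to 32) should work
    output = compiled(torch.randn(8, 3, 224, 224, device=device))
    print(output.shape)  # expected: torch.Size([8, 1000])
```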
1 change: 1 addition & 0 deletions requirements/common_gpu.txt
@@ -1,2 +1,3 @@
nvgpu; sys_platform != 'win32'
nvgpu==0.8.0; sys_platform == 'win32'
ninja==1.11.1.1
88 changes: 88 additions & 0 deletions test/pytest/test_torch_export.py
@@ -0,0 +1,88 @@
from pathlib import Path

import torch
from pkg_resources import packaging

from ts.torch_handler.image_classifier import ImageClassifier
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
from ts.utils.util import load_label_mapping
from ts_scripts.utils import try_and_handle

CURR_FILE_PATH = Path(__file__).parent.absolute()
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_export_aot_compile")
TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
MAPPING_DATA = REPO_ROOT_DIR.joinpath(
"examples", "image_classifier", "index_to_name.json"
)
MODEL_SO_FILE = "resnet18_pt2.so"
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")


PT_220_AVAILABLE = (
    True
    if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1")
    else False
)

EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "bucket"]
TEST_CASES = [
("kitten.jpg", EXPECTED_RESULTS[0]),
]


import os

import pytest


@pytest.fixture
def custom_working_directory(tmp_path):
    # Set the custom working directory
    custom_dir = tmp_path / "model_dir"
    custom_dir.mkdir()
    os.chdir(custom_dir)
    yield custom_dir
    # Clean up and return to the original working directory
    os.chdir(tmp_path)


@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
def test_torch_export_aot_compile(custom_working_directory):
    # Get the path to the custom working directory
    model_dir = custom_working_directory

    # Construct the path to the Python script to execute
    script_path = os.path.join(EXAMPLE_ROOT_DIR, "resnet18_torch_export.py")

    # Generate the .so file via torch._export.aot_compile
    cmd = "python " + script_path
    try_and_handle(cmd)

    # Handler for Image classification
    handler = ImageClassifier()

    # Context definition
    ctx = MockContext(
        model_pt_file=MODEL_SO_FILE,
        model_dir=model_dir.as_posix(),
        model_file=None,
        model_yaml_config_file=MODEL_YAML_CFG_FILE,
    )

    torch.manual_seed(42 * 42)
    handler.initialize(ctx)
    handler.context = ctx
    handler.mapping = load_label_mapping(MAPPING_DATA)

    data = {}
    with open(TEST_DATA, "rb") as image:
        image_file = image.read()
        byte_array_type = bytearray(image_file)
        data["body"] = byte_array_type

    result = handler.handle([data], ctx)

    labels = list(result[0].keys())

    assert labels == EXPECTED_RESULTS
20 changes: 20 additions & 0 deletions ts/torch_handler/base_handler.py
@@ -53,6 +53,10 @@
    )
    PT2_AVAILABLE = False

if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
    PT220_AVAILABLE = True
else:
    PT220_AVAILABLE = False

if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
    try:
@@ -151,6 +155,12 @@ def initialize(self, context):
            self.map_location = "cpu"
            self.device = torch.device(self.map_location)

        TORCH_EXPORT_AVAILABLE = False
        if hasattr(self, "model_yaml_config") and "pt2" in self.model_yaml_config:
            pt2_value = self.model_yaml_config["pt2"]
            if pt2_value == "export" and PT220_AVAILABLE:
                TORCH_EXPORT_AVAILABLE = True

        self.manifest = context.manifest

        model_dir = properties.get("model_dir")
@@ -180,6 +190,13 @@ def initialize(self, context):
                self.model = setup_ort_session(self.model_pt_path, self.map_location)
                logger.info("Succesfully setup ort session")

            elif self.model_pt_path.endswith(".so") and TORCH_EXPORT_AVAILABLE:
                from ts.handler_utils.torch_export.load_model import load_exported_model

                self.model = self._load_torch_export_aot_compile(self.model_pt_path)
                logger.warning(
                    "torch._export is an experimental feature! Successfully loaded torch exported model."
                )
            else:
                raise RuntimeError("No model weights could be loaded")

@@ -234,6 +251,9 @@ def initialize(self, context):

        self.initialized = True

    def _load_torch_export_aot_compile(self, model_so_path):
        return load_exported_model(model_so_path, self.map_location)

    def _load_torchscript_model(self, model_pt_path):
        """Loads the PyTorch model and returns the NN model object.
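Note: the `load_exported_model` helper imported above (from `ts.handler_utils.torch_export.load_model`) is not part of this diff. A hypothetical minimal sketch of such a helper, assuming `torch._export.aot_load` is available in the installed nightly:

```python
# Hypothetical sketch of ts/handler_utils/torch_export/load_model.py (not part of this diff)
import torch


def load_exported_model(model_so_path, device):
    # Load the AOTInductor-generated shared library and bind it to the target device.
    # Assumes torch._export.aot_load exists in the installed PyTorch nightly.
    return torch._export.aot_load(model_so_path, device)
```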
