Changes to support torch._export.aot_compile
Showing 8 changed files with 237 additions and 0 deletions.
@@ -0,0 +1,69 @@
# TorchServe inference with torch._export aot_compile

This example shows how to run TorchServe with a Torch exported model using AOTInductor.
Using `torch.compile` to wrap your existing eager PyTorch model can result in out-of-the-box speedups. However, `torch.compile` is a JIT compiler. TorchServe has supported `torch.compile` since the PyTorch 2.0 release, but in a production setting where you run multiple instances of TorchServe, each instance would `torch.compile` the model at inference time. Because compilation happens just in time, TorchServe's model archiver cannot truly guarantee reproducibility.

In addition, the first inference request with `torch.compile` will be slow, as the model needs to compile.

To solve this problem, `torch.export` has an experimental API, `torch._export.aot_compile`, which exports a `torch.compile`-able model with no graph breaks and compiles it ahead of time with AOTInductor.

You can find more details [here](https://pytorch.org/docs/main/torch.compiler_aot_inductor.html)
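As a quick illustration (not part of this commit), the `.so` artifact produced later in this example can be loaded back in Python and called like a regular module. This is a minimal sketch; it assumes the experimental `torch._export.aot_load` API available in recent PyTorch 2.2+ builds, which may change:

```
import torch

# Hypothetical loading sketch; torch._export.aot_load is experimental
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch._export.aot_load("resnet18_pt2.so", device=device)

with torch.no_grad():
    # Any batch size within the exported dynamic range works
    x = torch.randn(4, 3, 224, 224, device=device)
    out = model(x)
    print(out.shape)  # expected: torch.Size([4, 1000])
```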
### Pre-requisites

- `PyTorch >= 2.2.0`
- `CUDA 12.1`
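To confirm your environment matches these requirements, you can run this quick check (optional, not part of the example scripts):

```
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
```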
Change directory to the examples directory
Ex: `cd` to `examples/pt2/torch_export_aot_compile`

Install the PyTorch 2.2 nightlies by running
```
chmod +x install_pytorch_nightlies.sh
source install_pytorch_nightlies.sh
```
### Create a Torch exported model with AOTInductor

The model is saved with a `.so` extension.
Here we torch export the model and compile it with AOTInductor in `max_autotune` mode.
This also makes use of `dynamic_shapes` to support batch sizes from 1 to 32.
In the code, the min batch_size is set to 2 instead of 1. You can find an explanation for this [here](https://pytorch.org/docs/main/export.html#expressing-dynamism)
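For reference, the dynamic batch dimension is declared in `resnet18_torch_export.py` (shown in full further down) like this:

```
batch_dim = torch.export.Dim("batch", min=2, max=32)
dynamic_shapes = {"x": {0: batch_dim}}  # dimension 0 of input "x" is dynamic
```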
To generate `resnet18_pt2.so`, run

```
python resnet18_torch_export.py
```
### Create model archive

```
torch-model-archiver --model-name res18-pt2 --handler image_classifier --version 1.0 --serialized-file resnet18_pt2.so --config-file model-config.yaml --extra-files ../../image_classifier/index_to_name.json
mkdir model_store
mv res18-pt2.mar model_store/.
```
#### Start TorchServe

```
torchserve --start --model-store model_store --models res18-pt2=res18-pt2.mar --ncs
```
#### Run Inference

```
curl http://127.0.0.1:8080/predictions/res18-pt2 -T ../../image_classifier/kitten.jpg
```
produces the output

```
{
  "tabby": 0.4087875485420227,
  "tiger_cat": 0.34661102294921875,
  "Egyptian_cat": 0.13007202744483948,
  "lynx": 0.024034621194005013,
  "bucket": 0.011633828282356262
}
```
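Once you are done, you can stop the server with

```
torchserve --stop
```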
13 changes: 13 additions & 0 deletions
examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
@@ -0,0 +1,13 @@
#!/bin/bash

# Uninstall torchtext, torchdata, torch, torchvision, and torchaudio
pip uninstall torchtext torchdata torch torchvision torchaudio -y

# Install nightly PyTorch and torchvision from the specified index URL
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed

# Optional: Display the installed PyTorch and torchvision versions
python -c "import torch; print('PyTorch version:', torch.__version__)"
python -c "import torchvision; print('torchvision version:', torchvision.__version__)"

echo "PyTorch and torchvision updated successfully!"
1 change: 1 addition & 0 deletions
examples/pt2/torch_export_aot_compile/model-config.yaml
@@ -0,0 +1 @@
pt2 : "export"
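This single line is presumably what signals the TorchServe handler to load the `--serialized-file` (`resnet18_pt2.so`) via the `torch.export`/AOTInductor path rather than as TorchScript or eager weights; the handler-side support belongs to the same commit.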
26 changes: 26 additions & 0 deletions
examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
@@ -0,0 +1,26 @@
import os

import torch
from torchvision.models import ResNet18_Weights, resnet18

torch.set_float32_matmul_precision("high")

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.no_grad():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
        # Specify the generated shared library path
        options={
            "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
            "max_autotune": True,
        },
    )
@@ -1,2 +1,3 @@
nvgpu; sys_platform != 'win32'
nvgpu==0.8.0; sys_platform == 'win32'
ninja==1.11.1.1
@@ -0,0 +1,88 @@
import os
from pathlib import Path

import pytest
import torch
from pkg_resources import packaging

from ts.torch_handler.image_classifier import ImageClassifier
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
from ts.utils.util import load_label_mapping
from ts_scripts.utils import try_and_handle

CURR_FILE_PATH = Path(__file__).parent.absolute()
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_export_aot_compile")
TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
MAPPING_DATA = REPO_ROOT_DIR.joinpath(
    "examples", "image_classifier", "index_to_name.json"
)
MODEL_SO_FILE = "resnet18_pt2.so"
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")

PT_220_AVAILABLE = packaging.version.parse(torch.__version__) > packaging.version.parse(
    "2.1.1"
)

EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "bucket"]
TEST_CASES = [
    ("kitten.jpg", EXPECTED_RESULTS[0]),
]


@pytest.fixture
def custom_working_directory(tmp_path):
    # Set the custom working directory
    custom_dir = tmp_path / "model_dir"
    custom_dir.mkdir()
    os.chdir(custom_dir)
    yield custom_dir
    # Clean up and return to the original working directory
    os.chdir(tmp_path)


@pytest.mark.skipif(not PT_220_AVAILABLE, reason="torch version is < 2.2.0")
def test_torch_export_aot_compile(custom_working_directory):
    # Get the path to the custom working directory
    model_dir = custom_working_directory

    # Construct the path to the Python script to execute
    script_path = os.path.join(EXAMPLE_ROOT_DIR, "resnet18_torch_export.py")

    # Generate the .so file by running the export script
    cmd = "python " + script_path
    try_and_handle(cmd)

    # Handler for image classification
    handler = ImageClassifier()

    # Context definition
    ctx = MockContext(
        model_pt_file=MODEL_SO_FILE,
        model_dir=model_dir.as_posix(),
        model_file=None,
        model_yaml_config_file=MODEL_YAML_CFG_FILE,
    )

    torch.manual_seed(42 * 42)
    handler.initialize(ctx)
    handler.context = ctx
    handler.mapping = load_label_mapping(MAPPING_DATA)

    data = {}
    with open(TEST_DATA, "rb") as image:
        image_file = image.read()
        byte_array_type = bytearray(image_file)
        data["body"] = byte_array_type

    result = handler.handle([data], ctx)

    labels = list(result[0].keys())

    assert labels == EXPECTED_RESULTS