Skip to content

Commit

Permalink
更新文档
Browse files Browse the repository at this point in the history
  • Loading branch information
liuxinwei committed Jan 1, 2025
1 parent 6998d95 commit 490db6b
Show file tree
Hide file tree
Showing 36 changed files with 2,714 additions and 944 deletions.
13 changes: 0 additions & 13 deletions doc/test/tutorials/msc/index.md

This file was deleted.

3 changes: 2 additions & 1 deletion doc/tutorials/msc/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ graph/index
runner/index
pipeline/index
tools/index
plugin
plugin/index
tests/index
```
178 changes: 178 additions & 0 deletions doc/tutorials/msc/pipeline/MSCManager.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# {class}`~tvm.contrib.msc.pipeline.MSCManager`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc\n"
]
}
],
"source": [
"%cd ..\n",
"import set_env\n",
"from pathlib import Path\n",
"\n",
"temp_dir = Path(\".temp\")\n",
"temp_dir.mkdir(exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def _get_torch_model(name, training=False):\n",
" \"\"\"Get model from torch vision\"\"\"\n",
"\n",
" # pylint: disable=import-outside-toplevel\n",
" try:\n",
" import torchvision\n",
"\n",
" model = getattr(torchvision.models, name)()\n",
" if training:\n",
" model = model.train()\n",
" else:\n",
" model = model.eval()\n",
" return model\n",
" except: # pylint: disable=bare-except\n",
" print(\"please install torchvision package\")\n",
" return None"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from tvm.contrib.msc.core import utils as msc_utils\n",
"\n",
"def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1):\n",
" \"\"\"Get msc config\"\"\"\n",
"\n",
" path = f'test_pipe_{model_type}_{compile_type}_{\"dynamic\" if dynamic else \"static\"}'\n",
" return {\n",
" \"workspace\": msc_utils.msc_dir(f\"{temp_dir}/{path}\", keep_history=False),\n",
" \"verbose\": \"critical\",\n",
" \"model_type\": model_type,\n",
" \"inputs\": inputs,\n",
" \"outputs\": outputs,\n",
" \"dataset\": {\"prepare\": {\"loader\": \"from_random\", \"max_iter\": 5}},\n",
" \"prepare\": {\"profile\": {\"benchmark\": {\"repeat\": 10}}},\n",
" \"baseline\": {\n",
" \"run_type\": model_type,\n",
" \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n",
" },\n",
" \"compile\": {\n",
" \"run_type\": compile_type,\n",
" \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n",
" },\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n",
"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'inputs': [{'name': 'input_0', 'shape': [1, 3, 224, 224], 'dtype': 'float32', 'layout': 'NCHW'}], 'outputs': [{'name': 'output', 'shape': [1, 1000], 'dtype': 'float32', 'layout': 'NW'}], 'nodes': {'total': 229, 'input': 1, 'nn.conv2d': 53, 'nn.batch_norm': 53, 'get_item': 53, 'nn.relu': 49, 'nn.max_pool2d': 1, 'add': 16, 'nn.adaptive_avg_pool2d': 1, 'reshape': 1, 'msc.linear_bias': 1}}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'inputs': [{'name': 'input_0', 'shape': [1, 3, 224, 224], 'dtype': 'float32', 'layout': 'NCHW'}], 'outputs': [{'name': 'output', 'shape': [1, 1000], 'dtype': 'float32', 'layout': 'NW'}], 'nodes': {'total': 229, 'input': 1, 'nn.conv2d': 53, 'nn.batch_norm': 53, 'get_item': 53, 'nn.relu': 49, 'nn.max_pool2d': 1, 'add': 16, 'nn.adaptive_avg_pool2d': 1, 'reshape': 1, 'msc.linear_bias': 1}}\n"
]
}
],
"source": [
"from tvm.contrib.msc.pipeline import MSCManager\n",
"from tvm.contrib.msc.core.utils.namespace import MSCFramework\n",
"import torch\n",
"\n",
"for compile_type in [MSCFramework.TORCH, MSCFramework.TVM]:\n",
" torch_model = _get_torch_model(\"resnet50\", False)\n",
" if torch.cuda.is_available():\n",
" torch_model = torch_model.to(torch.device(\"cuda:0\"))\n",
" config = _get_config(\n",
" MSCFramework.TORCH,\n",
" compile_type,\n",
" inputs=[[\"input_0\", [1, 3, 224, 224], \"float32\"]],\n",
" outputs=[\"output\"],\n",
" dynamic = True,\n",
" atol = 1e-1,\n",
" rtol = 1e-1,\n",
" )\n",
" pipeline = MSCManager(torch_model, config)\n",
" pipeline.run_pipe() # 运行管道\n",
" print(pipeline.get_runtime().model_info) # 打印模型信息\n",
" pipeline.destory() # 销毁管道"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "xxx",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
177 changes: 177 additions & 0 deletions doc/tutorials/msc/pipeline/TorchDynamic.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# {class}`~tvm.contrib.msc.pipeline.TorchDynamic`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc\n"
]
}
],
"source": [
"%cd ..\n",
"import set_env\n",
"from pathlib import Path\n",
"\n",
"temp_dir = Path(\".temp\")\n",
"temp_dir.mkdir(exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def _get_torch_model(name, training=False):\n",
" \"\"\"Get model from torch vision\"\"\"\n",
"\n",
" # pylint: disable=import-outside-toplevel\n",
" try:\n",
" import torchvision\n",
"\n",
" model = getattr(torchvision.models, name)()\n",
" if training:\n",
" model = model.train()\n",
" else:\n",
" model = model.eval()\n",
" return model\n",
" except: # pylint: disable=bare-except\n",
" print(\"please install torchvision package\")\n",
" return None"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from tvm.contrib.msc.core import utils as msc_utils\n",
"\n",
"def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1):\n",
" \"\"\"Get msc config\"\"\"\n",
"\n",
" path = f'test_pipe_{model_type}_{compile_type}_{\"dynamic\" if dynamic else \"static\"}'\n",
" return {\n",
" \"workspace\": msc_utils.msc_dir(f\"{temp_dir}/{path}\", keep_history=False),\n",
" \"verbose\": \"critical\",\n",
" \"model_type\": model_type,\n",
" \"inputs\": inputs,\n",
" \"outputs\": outputs,\n",
" \"dataset\": {\"prepare\": {\"loader\": \"from_random\", \"max_iter\": 5}},\n",
" \"prepare\": {\"profile\": {\"benchmark\": {\"repeat\": 10}}},\n",
" \"baseline\": {\n",
" \"run_type\": model_type,\n",
" \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n",
" },\n",
" \"compile\": {\n",
" \"run_type\": compile_type,\n",
" \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n",
" },\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.\n",
" param_schemas = callee.param_schemas()\n",
"/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.\n",
" param_schemas = callee.param_schemas()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'success': False, 'info': {'prepare': {'profile': {'jit_0': '46.75 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '6.19 s(49.31%)', 'parse': '0.09 s(0.68%)', 'total': '12.55 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 162, in run_pipe\\n self.parse()\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 226, in parse\\n info, report = self._parse()\\n ^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py\", line 157, in _parse\\n info[name], report[name] = w_ctx[\"worker\"].parse()\\n ^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py\", line 320, in parse\\n self._relax_mod, _ = stage_config[\"parser\"](self._model, **parse_config)\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py\", line 119, in from_torch\\n relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 960, in from_fx\\n return TorchFXImporter().from_fx(\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 837, in from_fx\\n assert (\\nAssertionError: Unsupported function type batch_norm\\n'}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[12:54:57] /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'success': False, 'info': {'prepare': {'profile': {'jit_0': '42.50 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '4.81 s(49.18%)', 'parse': '0.08 s(0.82%)', 'total': '9.78 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 162, in run_pipe\\n self.parse()\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 226, in parse\\n info, report = self._parse()\\n ^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py\", line 157, in _parse\\n info[name], report[name] = w_ctx[\"worker\"].parse()\\n ^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py\", line 320, in parse\\n self._relax_mod, _ = stage_config[\"parser\"](self._model, **parse_config)\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py\", line 119, in from_torch\\n relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 960, in from_fx\\n return TorchFXImporter().from_fx(\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 837, in from_fx\\n assert (\\nAssertionError: Unsupported function type batch_norm\\n'}\n"
]
}
],
"source": [
"from tvm.contrib.msc.core.utils.namespace import MSCFramework\n",
"from tvm.contrib.msc.pipeline import TorchDynamic\n",
"import torch\n",
"\n",
"for compile_type in [MSCFramework.TORCH, MSCFramework.TVM]:\n",
" torch_model = _get_torch_model(\"resnet50\", False)\n",
" if torch.cuda.is_available():\n",
" torch_model = torch_model.to(torch.device(\"cuda:0\"))\n",
" config = _get_config(\n",
" MSCFramework.TORCH,\n",
" compile_type,\n",
" inputs=[[\"input_0\", [1, 3, 224, 224], \"float32\"]],\n",
" outputs=[\"output\"],\n",
" dynamic = True,\n",
" atol = 1e-1,\n",
" rtol = 1e-1,\n",
" )\n",
" pipeline = TorchDynamic(torch_model, config)\n",
" pipeline.run_pipe() # 运行管道\n",
" print(pipeline.report) # 打印模型信息\n",
" pipeline.destory() # 销毁管道"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "ai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 2 additions & 1 deletion doc/tutorials/msc/pipeline/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
:hidden:
intro
manager
MSCManager
TorchDynamic
```
Loading

0 comments on commit 490db6b

Please sign in to comment.