diff --git a/doc/test/tutorials/msc/index.md b/doc/test/tutorials/msc/index.md deleted file mode 100644 index f80c0e30..00000000 --- a/doc/test/tutorials/msc/index.md +++ /dev/null @@ -1,13 +0,0 @@ -# MSC - -```{toctree} - - -translate-relay -translate-relax -translate-torch -translate-tensorflow -translate-tensorrt -translate -transform -``` diff --git a/doc/tutorials/msc/index.md b/doc/tutorials/msc/index.md index 33d62450..526c080f 100644 --- a/doc/tutorials/msc/index.md +++ b/doc/tutorials/msc/index.md @@ -8,5 +8,6 @@ graph/index runner/index pipeline/index tools/index -plugin +plugin/index +tests/index ``` diff --git a/doc/tutorials/msc/pipeline/MSCManager.ipynb b/doc/tutorials/msc/pipeline/MSCManager.ipynb new file mode 100644 index 00000000..788dae7a --- /dev/null +++ b/doc/tutorials/msc/pipeline/MSCManager.ipynb @@ -0,0 +1,178 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# {class}`~tvm.contrib.msc.pipeline.MSCManager`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc\n" + ] + } + ], + "source": [ + "%cd ..\n", + "import set_env\n", + "from pathlib import Path\n", + "\n", + "temp_dir = Path(\".temp\")\n", + "temp_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def _get_torch_model(name, training=False):\n", + " \"\"\"Get model from torch vision\"\"\"\n", + "\n", + " # pylint: disable=import-outside-toplevel\n", + " try:\n", + " import torchvision\n", + "\n", + " model = getattr(torchvision.models, name)()\n", + " if training:\n", + " model = model.train()\n", + " else:\n", + " model = model.eval()\n", + " return model\n", + " except: # pylint: disable=bare-except\n", + " print(\"please install torchvision package\")\n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from tvm.contrib.msc.core import utils as msc_utils\n", + "\n", + "def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1):\n", + " \"\"\"Get msc config\"\"\"\n", + "\n", + " path = f'test_pipe_{model_type}_{compile_type}_{\"dynamic\" if dynamic else \"static\"}'\n", + " return {\n", + " \"workspace\": msc_utils.msc_dir(f\"{temp_dir}/{path}\", keep_history=False),\n", + " \"verbose\": \"critical\",\n", + " \"model_type\": model_type,\n", + " \"inputs\": inputs,\n", + " \"outputs\": outputs,\n", + " \"dataset\": {\"prepare\": {\"loader\": \"from_random\", \"max_iter\": 5}},\n", + " \"prepare\": {\"profile\": {\"benchmark\": {\"repeat\": 10}}},\n", + " \"baseline\": {\n", + " \"run_type\": model_type,\n", + " \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n", + " },\n", + " \"compile\": {\n", + " \"run_type\": compile_type,\n", + " \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n", + " },\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. 
It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n", + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'inputs': [{'name': 'input_0', 'shape': [1, 3, 224, 224], 'dtype': 'float32', 'layout': 'NCHW'}], 'outputs': [{'name': 'output', 'shape': [1, 1000], 'dtype': 'float32', 'layout': 'NW'}], 'nodes': {'total': 229, 'input': 1, 'nn.conv2d': 53, 'nn.batch_norm': 53, 'get_item': 53, 'nn.relu': 49, 'nn.max_pool2d': 1, 'add': 16, 'nn.adaptive_avg_pool2d': 1, 'reshape': 1, 'msc.linear_bias': 1}}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'inputs': [{'name': 'input_0', 'shape': [1, 3, 224, 224], 'dtype': 'float32', 'layout': 'NCHW'}], 'outputs': [{'name': 'output', 'shape': [1, 1000], 'dtype': 'float32', 'layout': 'NW'}], 'nodes': {'total': 229, 'input': 1, 'nn.conv2d': 53, 'nn.batch_norm': 53, 'get_item': 53, 'nn.relu': 49, 'nn.max_pool2d': 1, 'add': 16, 'nn.adaptive_avg_pool2d': 1, 'reshape': 1, 'msc.linear_bias': 1}}\n" + ] + } + ], + "source": [ + "from tvm.contrib.msc.pipeline import MSCManager\n", + "from tvm.contrib.msc.core.utils.namespace import MSCFramework\n", + "import torch\n", + "\n", + "for compile_type in [MSCFramework.TORCH, MSCFramework.TVM]:\n", + " torch_model = _get_torch_model(\"resnet50\", False)\n", + " if torch.cuda.is_available():\n", + " torch_model = torch_model.to(torch.device(\"cuda:0\"))\n", + " config = _get_config(\n", + " MSCFramework.TORCH,\n", + " compile_type,\n", + " inputs=[[\"input_0\", [1, 3, 224, 224], \"float32\"]],\n", + " outputs=[\"output\"],\n", + " dynamic=True,\n", + " atol=1e-1,\n", + " rtol=1e-1,\n", + " )\n", + " pipeline = MSCManager(torch_model, config)\n", + " pipeline.run_pipe() # run the pipeline\n", + " print(pipeline.get_runtime().model_info) # print the model info\n", + " pipeline.destory() # destroy the pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "xxx", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/tutorials/msc/pipeline/TorchDynamic.ipynb b/doc/tutorials/msc/pipeline/TorchDynamic.ipynb new file mode 100644 index 00000000..cd0a1763 --- /dev/null +++ b/doc/tutorials/msc/pipeline/TorchDynamic.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# {class}`~tvm.contrib.msc.pipeline.TorchDynamic`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc\n" + ] + } + ], + "source": [ + "%cd ..\n", + "import set_env\n", + "from pathlib import Path\n", + "\n", + "temp_dir = Path(\".temp\")\n", + "temp_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def _get_torch_model(name, training=False):\n", + " \"\"\"Get model from torch vision\"\"\"\n", + "\n", + " # pylint: disable=import-outside-toplevel\n", + " try:\n", + " import torchvision\n", + "\n", + " model = getattr(torchvision.models, name)()\n", + " if training:\n", + " model = model.train()\n", + " else:\n", + " model = model.eval()\n", + " return model\n", + " except: # pylint: disable=bare-except\n", + " print(\"please install torchvision package\")\n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from tvm.contrib.msc.core import utils as msc_utils\n", + "\n", + "def _get_config(model_type, 
compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1):\n", + " \"\"\"Get msc config\"\"\"\n", + "\n", + " path = f'test_pipe_{model_type}_{compile_type}_{\"dynamic\" if dynamic else \"static\"}'\n", + " return {\n", + " \"workspace\": msc_utils.msc_dir(f\"{temp_dir}/{path}\", keep_history=False),\n", + " \"verbose\": \"critical\",\n", + " \"model_type\": model_type,\n", + " \"inputs\": inputs,\n", + " \"outputs\": outputs,\n", + " \"dataset\": {\"prepare\": {\"loader\": \"from_random\", \"max_iter\": 5}},\n", + " \"prepare\": {\"profile\": {\"benchmark\": {\"repeat\": 10}}},\n", + " \"baseline\": {\n", + " \"run_type\": model_type,\n", + " \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n", + " },\n", + " \"compile\": {\n", + " \"run_type\": compile_type,\n", + " \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n", + " },\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.\n", + " param_schemas = callee.param_schemas()\n", + "/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.\n", + " param_schemas = callee.param_schemas()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'success': False, 'info': {'prepare': {'profile': {'jit_0': '46.75 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '6.19 s(49.31%)', 'parse': '0.09 s(0.68%)', 'total': '12.55 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 162, in run_pipe\\n self.parse()\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 226, in parse\\n info, report = self._parse()\\n ^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py\", line 157, in _parse\\n info[name], report[name] = w_ctx[\"worker\"].parse()\\n ^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py\", line 320, in parse\\n self._relax_mod, _ = stage_config[\"parser\"](self._model, **parse_config)\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py\", line 119, in from_torch\\n relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 960, in from_fx\\n return TorchFXImporter().from_fx(\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 837, in from_fx\\n assert (\\nAssertionError: Unsupported function type batch_norm\\n'}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[12:54:57] 
/media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'success': False, 'info': {'prepare': {'profile': {'jit_0': '42.50 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '4.81 s(49.18%)', 'parse': '0.08 s(0.82%)', 'total': '9.78 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 162, in run_pipe\\n self.parse()\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 226, in parse\\n info, report = self._parse()\\n ^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py\", line 157, in _parse\\n info[name], report[name] = w_ctx[\"worker\"].parse()\\n ^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py\", line 320, in parse\\n self._relax_mod, _ = stage_config[\"parser\"](self._model, **parse_config)\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py\", line 119, in from_torch\\n relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 960, in from_fx\\n return TorchFXImporter().from_fx(\\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\\n File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 837, in from_fx\\n assert (\\nAssertionError: Unsupported function type batch_norm\\n'}\n" + ] + } + ], + "source": [ + "from tvm.contrib.msc.core.utils.namespace import MSCFramework\n", + "from tvm.contrib.msc.pipeline import TorchDynamic\n", + "import torch\n", + "\n", + "for compile_type in [MSCFramework.TORCH, MSCFramework.TVM]:\n", + " torch_model = _get_torch_model(\"resnet50\", False)\n", + " if torch.cuda.is_available():\n", + " torch_model = torch_model.to(torch.device(\"cuda:0\"))\n", + " config = _get_config(\n", + " MSCFramework.TORCH,\n", + " compile_type,\n", + " inputs=[[\"input_0\", [1, 3, 224, 224], \"float32\"]],\n", + " outputs=[\"output\"],\n", + " dynamic=True,\n", + " atol=1e-1,\n", + " rtol=1e-1,\n", + " )\n", + " pipeline = TorchDynamic(torch_model, config)\n", + " pipeline.run_pipe() # run the pipeline\n", + " print(pipeline.report) # print the pipeline report\n", + " pipeline.destory() # destroy the pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ai", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/tutorials/msc/pipeline/index.md b/doc/tutorials/msc/pipeline/index.md index d4186314..094e233a 100644 --- a/doc/tutorials/msc/pipeline/index.md +++ b/doc/tutorials/msc/pipeline/index.md @@ -6,5 +6,6 @@ :hidden: intro -manager +MSCManager +TorchDynamic ```
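Both pipeline notebooks follow the same pattern: build a config, run the pipe, then gate everything on the returned report. A minimal sketch of that pattern (reusing a `torch_model` and `config` built as above; `run_pipe`, `get_runtime` and `destory` are the MSC pipeline methods shown in these notebooks):

```python
from tvm.contrib.msc.pipeline import MSCManager

manager = MSCManager(torch_model, config)
report = manager.run_pipe()  # dict with "success", per-stage durations, and "err_msg" on failure
if report["success"]:
    print(manager.get_runtime().model_info)  # inspect the compiled model
else:
    print("pipeline failed:", report["err_msg"])
manager.destory()  # release the pipeline and its workspace
```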
diff --git a/doc/tutorials/msc/pipeline/manager.ipynb b/doc/tutorials/msc/pipeline/manager.ipynb deleted file mode 100644 index 98cfa152..00000000 --- a/doc/tutorials/msc/pipeline/manager.ipynb +++ /dev/null @@ -1,182 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# {class}`~tvm.contrib.msc.pipeline.manager.MSCManager`\n", - "\n", - "{class}`~tvm.contrib.msc.pipeline.manager.MSCManager` connects MSCGraph(s) with different frameworks; it wraps some commonly used methods and manages the MSCTools." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc\n" - ] - } - ], - "source": [ - "%cd ..\n", - "import set_env\n", - "from pathlib import Path\n", - "\n", - "temp_dir = Path(\".temp\")\n", - "temp_dir.mkdir(exist_ok=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Build the frontend model:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n", - "from torch import fx\n", - "import tvm\n", - "from tvm import relax\n", - "from tvm.relax.frontend.torch import from_fx\n", - "\n", - "class M(torch.nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.conv = torch.nn.Conv2d(3, 6, 1, bias=False)\n", - " self.relu = torch.nn.ReLU()\n", - "\n", - " def forward(self, data):\n", - " x = self.conv(data)\n", - " return self.relu(x)\n", - "\n", - "input_info = [([1, 3, 224, 224], \"float32\")]\n", - "with torch.no_grad():\n", - " torch_fx_model = fx.symbolic_trace(M())\n", - " mod = from_fx(torch_fx_model, input_info, keep_params_as_input=False)\n", - "mod, params = relax.frontend.detach_params(mod)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "from tvm.contrib.msc.core.transform import msc_transform\n", - "from tvm.contrib.msc.core.runtime import create_runtime_manager\n", - "from tvm.contrib.msc.core.tools import create_tool, MSC_TOOL\n", - "\n", - "# build runtime manager from module and mscgraphs\n", - "optimized_mod, msc_graph, msc_config = msc_transform(mod, params)\n", - "rt_manager = create_runtime_manager(optimized_mod, params, msc_config)\n", - "rt_manager.create_tool(MSC_TOOL.QUANTIZE, quantize_config)\n", - "quantizer = rt_manager.get_tool(MSC_TOOL.QUANTIZE)\n", - "\n", - "rt_manager.load_model()\n", - "# calibrate the datas with float model\n", - "while not quantizer.calibrated:\n", - " for datas in calibrate_datas:\n", - " rt_manager.run(datas)\n", - " quantizer.calibrate()\n", - "quantizer.save_strategy(strategy_file)\n", - "\n", - "# load again the quantized model, without loading the weights\n", - "rt_manager.load_model(reuse_weights=True)\n", - "outputs = rt_manager.run(sample_datas)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "MSCManager wraps the compilation flow and exposes a user-facing interface. Usage looks like:\n", - "```python\n", - "import torchvision\n", - "from tvm.contrib.msc.pipeline import MSCManager\n", - "\n", - "model = torchvision.models.resnet50()\n", - "# define your config\n", - "manager = MSCManager(model, config)\n", - "runner = manager.run_pipe()\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## MSCWrapper\n", - "\n", - "MSCWrapper is a further wrapper around MSCManager; its main purpose is to turn the MSC compilation process into a user-friendly toolchain interface. It is used in much the same way as MSCManager, e.g.\n", - "```python\n", - "model = TorchWrapper(model, config)\n", - "\n", - "# export to dump meta model\n", - "# model.export()\n", - "\n", - "# optimize the model with 
quantizer(PTQ)\n", - "model.optimize()\n", - "acc = eval_model(model, testloader, max_iter=args.test_iter)\n", - "\n", - "# train the model with quantizer(QAT)\n", - "optimizer = optim.Adam(model.parameters(), lr=0.0000001, weight_decay=0.08)\n", - "for ep in range(args.train_epoch):\n", - " train_model(model, trainloader, optimizer, max_iter=args.train_iter)\n", - " acc = eval_model(model, testloader, max_iter=args.test_iter)\n", - "\n", - "# export to dump checkpoint model\n", - "# model.export()\n", - "\n", - "# compile the model\n", - "model.compile(bind_params=True)\n", - "acc = eval_model(model, testloader, max_iter=args.test_iter)\n", - "\n", - "# export to dump compiled model\n", - "# model.export()\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The usage example has not been merged yet; this document will be updated once it lands.\n", - "\n", - "A model wrapped by MSCWrapper keeps all the methods of the original model, so it can still be used for training or evaluation. However, after calling MSCWrapper.optimize or MSCWrapper.compile the model has been replaced by the optimized or compiled one; only the input/output formats are adapted so that data types matching the original model are still supported." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "xxx", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/doc/tutorials/msc/plugin.ipynb b/doc/tutorials/msc/plugin.ipynb deleted file mode 100644 index 91f08f9f..00000000 --- a/doc/tutorials/msc/plugin.ipynb +++ /dev/null @@ -1,406 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# MSC Plugin" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Reference: [【我与TVM二三事 后篇(5)】MSC之PluginBuilder](https://zhuanlan.zhihu.com/p/681450076)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Licensed to the Apache Software Foundation (ASF) under one\n", - "# or more contributor license agreements. See the NOTICE file\n", - "# distributed with this work for additional information\n", - "# regarding copyright ownership. The ASF licenses this file\n", - "# to you under the Apache License, Version 2.0 (the\n", - "# \"License\"); you may not use this file except in compliance\n", - "# with the License. You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing,\n", - "# software distributed under the License is distributed on an\n", - "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", - "# KIND, either express or implied. See the License for the\n", - "# specific language governing permissions and limitations\n", - "# under the License.\n", - "\n", - "\"\"\" Test Plugin in MSC. 
\"\"\"\n", - "\n", - "import numpy as np\n", - "\n", - "import torch\n", - "from torch import nn\n", - "\n", - "import tvm.testing\n", - "from tvm import relax\n", - "from tvm.relax.transform import BindParams\n", - "from tvm.script import relax as R\n", - "from tvm.contrib.msc.pipeline import MSCManager\n", - "from tvm.contrib.msc.plugin import build_plugins\n", - "from tvm.contrib.msc.core.utils.namespace import MSCFramework\n", - "from tvm.contrib.msc.core import utils as msc_utils\n", - "\n", - "\n", - "def _get_externs_header():\n", - " \"\"\"Get the header source for externs\"\"\"\n", - "\n", - " return \"\"\"#ifndef EXTERNS_H_\n", - "#define EXTERNS_H_\n", - "\n", - "#include \"plugin_base.h\"\n", - "\n", - "#ifdef PLUGIN_ENABLE_CUDA\n", - "#include \n", - "#endif\n", - "\n", - "namespace tvm {\n", - "namespace contrib {\n", - "namespace msc {\n", - "namespace plugin {\n", - "\n", - "template \n", - "std::vector my_relu_infer(const std::vector& inputs, const TAttr& attrs,\n", - " bool is_runtime) {\n", - " std::vector outputs;\n", - " outputs.push_back(MetaTensor(inputs[0].shape(), inputs[0].data_type(), inputs[0].layout()));\n", - " return outputs;\n", - "}\n", - "\n", - "template \n", - "void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, T max_val);\n", - "\n", - "template \n", - "void my_relu_cpu_compute(const DataTensor& input, DataTensor& output, const TAttr& attrs) {\n", - " my_relu_cpu_kernel(input, output, T(attrs.max_val));\n", - "}\n", - "\n", - "#ifdef PLUGIN_ENABLE_CUDA\n", - "template \n", - "void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, T max_val,\n", - " const cudaStream_t& stream);\n", - "\n", - "template \n", - "void my_relu_cuda_compute(const DataTensor& input, DataTensor& output, const TAttr& attrs,\n", - " const cudaStream_t& stream) {\n", - " my_relu_cuda_kernel(input, output, T(attrs.max_val), stream);\n", - "}\n", - "#endif\n", - "\n", - "} // namespace plugin\n", - "} // namespace msc\n", - "} // namespace contrib\n", - "} // namespace tvm\n", - "#endif // EXTERNS_H_\n", - "\"\"\"\n", - "\n", - "\n", - "def _get_externs_cc():\n", - " \"\"\"Get externs cc source\"\"\"\n", - " return \"\"\"#include \"externs.h\"\n", - "\n", - "namespace tvm {\n", - "namespace contrib {\n", - "namespace msc {\n", - "namespace plugin {\n", - "\n", - "template \n", - "void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, T max_val) {\n", - " const T* input_data = input.const_data();\n", - " T* output_data = output.data();\n", - " for (size_t i = 0; i < output.size(); i++) {\n", - " if (input_data[i] >= max_val) {\n", - " output_data[i] = max_val;\n", - " } else if (input_data[i] <= 0) {\n", - " output_data[i] = 0;\n", - " } else {\n", - " output_data[i] = input_data[i];\n", - " }\n", - " }\n", - "}\n", - "\n", - "template void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output,\n", - " float max_val);\n", - "\n", - "} // namespace plugin\n", - "} // namespace msc\n", - "} // namespace contrib\n", - "} // namespace tvm\n", - "\"\"\"\n", - "\n", - "\n", - "def _get_externs_cu():\n", - " \"\"\"Get externs cu source\"\"\"\n", - "\n", - " return \"\"\"#include \"externs.h\"\n", - "\n", - "#define CU1DBLOCK 256\n", - "#define KERNEL_LOOP(i, n) \\\n", - " for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)\n", - "\n", - "namespace tvm {\n", - "namespace contrib {\n", - "namespace msc {\n", - "namespace plugin {\n", - "\n", - "inline int n_blocks(int size, int block_size) {\n", - " 
return size / block_size + (size % block_size == 0 ? 0 : 1);\n", - "}\n", - "\n", - "template \n", - "__global__ static void _my_relu(const T* src, T* dst, T max_val, int n) {\n", - " KERNEL_LOOP(i, n) {\n", - " if (src[i] >= max_val) {\n", - " dst[i] = max_val;\n", - " } else if (src[i] <= 0) {\n", - " dst[i] = 0;\n", - " } else {\n", - " dst[i] = src[i];\n", - " }\n", - " }\n", - "}\n", - "\n", - "template \n", - "void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, T max_val,\n", - " const cudaStream_t& stream) {\n", - " const T* input_data = input.const_data();\n", - " T* output_data = output.data();\n", - " dim3 Bl(CU1DBLOCK);\n", - " dim3 Gr(n_blocks(output.size(), CU1DBLOCK));\n", - " _my_relu<<>>(input_data, output_data, max_val, output.size());\n", - "}\n", - "\n", - "template void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output,\n", - " float max_val, const cudaStream_t& stream);\n", - "\n", - "} // namespace plugin\n", - "} // namespace msc\n", - "} // namespace contrib\n", - "} // namespace tvm\n", - "\"\"\"\n", - "\n", - "\n", - "def _create_plugin(externs_dir):\n", - " \"\"\"Create sources under source folder\"\"\"\n", - " with open(externs_dir.relpath(\"externs.h\"), \"w\") as f:\n", - " f.write(_get_externs_header())\n", - " with open(externs_dir.relpath(\"externs.cc\"), \"w\") as f:\n", - " f.write(_get_externs_cc())\n", - " with open(externs_dir.relpath(\"externs.cu\"), \"w\") as f:\n", - " f.write(_get_externs_cu())\n", - " return {\n", - " \"MyRelu\": {\n", - " \"inputs\": [{\"name\": \"input\", \"dtype\": \"T\"}],\n", - " \"outputs\": [{\"name\": \"output\", \"dtype\": \"T\"}],\n", - " \"attrs\": [{\"name\": \"max_val\", \"type\": \"float\"}],\n", - " \"support_dtypes\": {\"T\": [\"float\"]},\n", - " \"externs\": {\n", - " \"infer_output\": {\"name\": \"my_relu_infer\", \"header\": \"externs.h\"},\n", - " \"cpu_compute\": {\n", - " \"name\": \"my_relu_cpu_compute\",\n", - " \"header\": \"externs.h\",\n", - " \"source\": \"externs.cc\",\n", - " },\n", - " \"cuda_compute\": {\n", - " \"name\": \"my_relu_cuda_compute\",\n", - " \"header\": \"externs.h\",\n", - " \"source\": \"externs.cu\",\n", - " },\n", - " },\n", - " }\n", - " }\n", - "\n", - "\n", - "def _get_torch_model(torch_manager):\n", - " \"\"\"Build model with plugin\"\"\"\n", - "\n", - " class MyModel(nn.Module):\n", - " \"\"\"Test model with plugin\"\"\"\n", - "\n", - " def __init__(self):\n", - " super(MyModel, self).__init__()\n", - " self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)\n", - " self.relu = torch_manager.MyRelu(max_val=0.5)\n", - " self.maxpool = nn.MaxPool2d(kernel_size=[1, 1])\n", - "\n", - " def forward(self, data):\n", - " data = self.conv(data)\n", - " data = self.relu(data)\n", - " return self.maxpool(data)\n", - "\n", - " return MyModel()\n", - "\n", - "\n", - "def _get_tvm_model(tvm_manager):\n", - " \"\"\"Build model with plugin\"\"\"\n", - "\n", - " block_builder = relax.BlockBuilder()\n", - " weights = np.random.rand(6, 3, 7, 7).astype(\"float32\")\n", - " data = relax.Var(\"data\", R.Tensor((1, 3, 224, 224), \"float32\"))\n", - " weight = relax.Var(\"weight\", R.Tensor(weights.shape, weights.dtype.name))\n", - " inputs = [data, weight]\n", - " with block_builder.function(name=\"main\", params=inputs.copy()):\n", - " with block_builder.dataflow():\n", - " data = relax.op.nn.conv2d(data, weight)\n", - " data = block_builder.emit(data, \"conv2d\")\n", - " data = tvm_manager.MyRelu(data, max_val=0.5)\n", - " data = block_builder.emit(data, \"relu\")\n", - " 
data = relax.op.nn.max_pool2d(data)\n", - " data = block_builder.emit(data, \"max_pool2d\")\n", - " data = block_builder.emit_output(data)\n", - " block_builder.emit_func_output(data)\n", - " mod = block_builder.finalize()\n", - " return BindParams(\"main\", {\"weight\": tvm.nd.array(weights)})(mod)\n", - "\n", - "\n", - "def _build_plugin(frameworks, plugin_root):\n", - " externs_dir = plugin_root.create_dir(\"externs\")\n", - " install_dir = plugin_root.create_dir(\"install\")\n", - " plugin = _create_plugin(externs_dir)\n", - " managers = build_plugins(plugin, frameworks, install_dir, externs_dir=externs_dir)\n", - " return managers\n", - "\n", - "\n", - "def _run_relax(relax_mod, target_name, data):\n", - " target = tvm.target.Target(target_name)\n", - " relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)\n", - " if target_name == \"cuda\":\n", - " with target:\n", - " relax_mod = tvm.tir.transform.DefaultGPUSchedule()(relax_mod)\n", - " device = tvm.cuda()\n", - " else:\n", - " device = tvm.cpu()\n", - " with tvm.transform.PassContext(opt_level=3):\n", - " relax_exec = tvm.relax.build(relax_mod, target)\n", - " runnable = tvm.relax.VirtualMachine(relax_exec, device)\n", - " data = tvm.nd.array(data, device)\n", - " return runnable[\"main\"](data).asnumpy()\n", - "\n", - "\n", - "def _test_tvm_plugin(manager, target):\n", - " \"\"\"Test plugin in tvm\"\"\"\n", - "\n", - " model = _get_tvm_model(manager)\n", - " data = np.random.rand(1, 3, 224, 224).astype(\"float32\")\n", - " outputs = _run_relax(model, target, data)\n", - " assert outputs.min() >= 0 and outputs.max() <= 0.5\n", - "\n", - "\n", - "def _test_torch_plugin(manager):\n", - " \"\"\"Test plugin in torch\"\"\"\n", - "\n", - " model = _get_torch_model(manager)\n", - " torch_data = torch.from_numpy(np.random.rand(1, 3, 224, 224).astype(\"float32\"))\n", - " if torch.cuda.is_available():\n", - " model = model.to(torch.device(\"cuda:0\"))\n", - " torch_data = torch_data.to(torch.device(\"cuda:0\"))\n", - " outputs = model(torch_data)\n", - " assert outputs.min() >= 0 and outputs.max() <= 0.5\n", - "\n", - "\n", - "def _test_with_manager(plugins, compile_type, expected_info):\n", - " \"\"\"Test the plugin with manager\"\"\"\n", - "\n", - " path = \"test_plugin_\" + compile_type\n", - " model = _get_torch_model(plugins[MSCFramework.TORCH])\n", - " if torch.cuda.is_available():\n", - " model = model.to(torch.device(\"cuda:0\"))\n", - " config = {\n", - " \"workspace\": msc_utils.msc_dir(path),\n", - " \"model_type\": MSCFramework.TORCH,\n", - " \"verbose\": \"critical\",\n", - " \"inputs\": [[\"input_0\", [1, 3, 224, 224], \"float32\"]],\n", - " \"outputs\": [\"output\"],\n", - " \"dataset\": {\"prepare\": {\"loader\": \"from_random\", \"max_iter\": 5}},\n", - " \"prepare\": {\"profile\": {\"benchmark\": {\"repeat\": 10}}},\n", - " \"baseline\": {\n", - " \"profile\": {\"check\": {\"atol\": 1e-2, \"rtol\": 1e-2}, \"benchmark\": {\"repeat\": 10}},\n", - " },\n", - " \"compile\": {\n", - " \"run_type\": compile_type,\n", - " \"profile\": {\"check\": {\"atol\": 1e-2, \"rtol\": 1e-2}, \"benchmark\": {\"repeat\": 10}},\n", - " },\n", - " }\n", - " manager = MSCManager(model, config, plugins=plugins)\n", - " report = manager.run_pipe()\n", - " model_info = manager.get_runtime().model_info\n", - " manager.destory()\n", - " assert report[\"success\"], \"Failed to run pipe for torch -> {}\".format(compile_type)\n", - " assert msc_utils.dict_equal(\n", - " model_info, expected_info\n", - " ), \"Model info {} mismatch with expected 
{}\".format(model_info, expected_info)\n", - "\n", - "\n", - "def test_plugin():\n", - " \"\"\"Test the plugins\"\"\"\n", - "\n", - " frameworks = [MSCFramework.TORCH, MSCFramework.TVM]\n", - " if tvm.get_global_func(\"relax.ext.tensorrt\", True) is not None:\n", - " frameworks.append(MSCFramework.TENSORRT)\n", - " plugin_root = msc_utils.msc_dir(\"msc_plugin\")\n", - " managers = _build_plugin(frameworks, plugin_root)\n", - "\n", - " # test the plugin load\n", - " _test_tvm_plugin(managers[MSCFramework.TVM], \"llvm\")\n", - " if tvm.cuda().exist:\n", - " _test_tvm_plugin(managers[MSCFramework.TVM], \"cuda\")\n", - " _test_torch_plugin(managers[MSCFramework.TORCH])\n", - "\n", - " # test the plugin with manager\n", - " model_info = {\n", - " \"inputs\": [\n", - " {\"name\": \"input_0\", \"shape\": [1, 3, 224, 224], \"dtype\": \"float32\", \"layout\": \"NCHW\"}\n", - " ],\n", - " \"outputs\": [\n", - " {\"name\": \"output\", \"shape\": [1, 6, 218, 218], \"dtype\": \"float32\", \"layout\": \"NCHW\"}\n", - " ],\n", - " \"nodes\": {\"total\": 4, \"input\": 1, \"msc.conv2d_bias\": 1, \"MyRelu\": 1, \"nn.max_pool2d\": 1},\n", - " }\n", - " _test_with_manager(managers, MSCFramework.TORCH, model_info)\n", - " _test_with_manager(managers, MSCFramework.TVM, model_info)\n", - " if tvm.get_global_func(\"relax.ext.tensorrt\", True) is not None:\n", - " byoc_info = {\n", - " \"inputs\": [\n", - " {\"name\": \"input_0\", \"shape\": [1, 3, 224, 224], \"dtype\": \"float32\", \"layout\": \"NCHW\"}\n", - " ],\n", - " \"outputs\": [\n", - " {\"name\": \"output\", \"shape\": [1, 6, 218, 218], \"dtype\": \"float32\", \"layout\": \"\"}\n", - " ],\n", - " \"nodes\": {\"total\": 2, \"input\": 1, \"msc_tensorrt\": 1},\n", - " }\n", - " _test_with_manager(managers, MSCFramework.TENSORRT, byoc_info)\n", - "\n", - " plugin_root.destory()\n", - "\n", - "\n", - "if __name__ == \"__main__\":\n", - " tvm.testing.main()\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "xxx", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/doc/tutorials/msc/plugin/index.md b/doc/tutorials/msc/plugin/index.md new file mode 100644 index 00000000..dbb30b13 --- /dev/null +++ b/doc/tutorials/msc/plugin/index.md @@ -0,0 +1,5 @@ +# 插件 + +```{toctree} +test +``` diff --git a/doc/tutorials/msc/plugin/msc_plugin/externs/externs.cc b/doc/tutorials/msc/plugin/msc_plugin/externs/externs.cc new file mode 100644 index 00000000..b97e6ed2 --- /dev/null +++ b/doc/tutorials/msc/plugin/msc_plugin/externs/externs.cc @@ -0,0 +1,29 @@ +#include "externs.h" + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +template +void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, T max_val) { + const T* input_data = input.const_data(); + T* output_data = output.data(); + for (size_t i = 0; i < output.size(); i++) { + if (input_data[i] >= max_val) { + output_data[i] = max_val; + } else if (input_data[i] <= 0) { + output_data[i] = 0; + } else { + output_data[i] = input_data[i]; + } + } +} + +template void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, + float max_val); + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/doc/tutorials/msc/plugin/msc_plugin/externs/externs.cu b/doc/tutorials/msc/plugin/msc_plugin/externs/externs.cu new file mode 100644 index 00000000..a4d816c0 --- /dev/null +++ 
diff --git a/doc/tutorials/msc/plugin/msc_plugin/externs/externs.cu b/doc/tutorials/msc/plugin/msc_plugin/externs/externs.cu new file mode 100644 index 00000000..a4d816c0 --- /dev/null +++ b/doc/tutorials/msc/plugin/msc_plugin/externs/externs.cu @@ -0,0 +1,44 @@ +#include "externs.h" + +#define CU1DBLOCK 256 +#define KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +inline int n_blocks(int size, int block_size) { + return size / block_size + (size % block_size == 0 ? 0 : 1); +} + +template <typename T> +__global__ static void _my_relu(const T* src, T* dst, T max_val, int n) { + KERNEL_LOOP(i, n) { + if (src[i] >= max_val) { + dst[i] = max_val; + } else if (src[i] <= 0) { + dst[i] = 0; + } else { + dst[i] = src[i]; + } + } +} + +template <typename T> +void my_relu_cuda_kernel(const DataTensor<T>& input, DataTensor<T>& output, T max_val, + const cudaStream_t& stream) { + const T* input_data = input.const_data(); + T* output_data = output.data(); + dim3 Bl(CU1DBLOCK); + dim3 Gr(n_blocks(output.size(), CU1DBLOCK)); + _my_relu<T><<<Gr, Bl, 0, stream>>>(input_data, output_data, max_val, output.size()); +} + +template void my_relu_cuda_kernel<float>(const DataTensor<float>& input, DataTensor<float>& output, + float max_val, const cudaStream_t& stream); + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/doc/tutorials/msc/plugin/msc_plugin/externs/externs.h b/doc/tutorials/msc/plugin/msc_plugin/externs/externs.h new file mode 100644 index 00000000..aeebacb0 --- /dev/null +++ b/doc/tutorials/msc/plugin/msc_plugin/externs/externs.h @@ -0,0 +1,47 @@ +#ifndef EXTERNS_H_ +#define EXTERNS_H_ + +#include "plugin_base.h" + +#ifdef PLUGIN_ENABLE_CUDA +#include <cuda_runtime.h> +#endif + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +template <typename TAttr> +std::vector<MetaTensor> my_relu_infer(const std::vector<MetaTensor>& inputs, const TAttr& attrs, + bool is_runtime) { + std::vector<MetaTensor> outputs; + outputs.push_back(MetaTensor(inputs[0].shape(), inputs[0].data_type(), inputs[0].layout())); + return outputs; +} + +template <typename T> +void my_relu_cpu_kernel(const DataTensor<T>& input, DataTensor<T>& output, T max_val); + +template <typename T, typename TAttr> +void my_relu_cpu_compute(const DataTensor<T>& input, DataTensor<T>& output, const TAttr& attrs) { + my_relu_cpu_kernel<T>(input, output, T(attrs.max_val)); +} + +#ifdef PLUGIN_ENABLE_CUDA +template <typename T> +void my_relu_cuda_kernel(const DataTensor<T>& input, DataTensor<T>& output, T max_val, + const cudaStream_t& stream); + +template <typename T, typename TAttr> +void my_relu_cuda_compute(const DataTensor<T>& input, DataTensor<T>& output, const TAttr& attrs, + const cudaStream_t& stream) { + my_relu_cuda_kernel<T>(input, output, T(attrs.max_val), stream); +} +#endif + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // EXTERNS_H_ diff --git a/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/MyRelu_attr.h b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/MyRelu_attr.h new file mode 100644 index 00000000..cfc9b7fe --- /dev/null +++ b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/MyRelu_attr.h @@ -0,0 +1,34 @@ +#ifndef TVM_CONTRIB_MSC_MYRELU_ATTR_H_ +#define TVM_CONTRIB_MSC_MYRELU_ATTR_H_ + +#include "plugin_utils.h" + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +struct MyReluMetaAttr { + // define attributes + float max_val; + + // print method + friend std::ostream& operator<<(std::ostream& out, const MyReluMetaAttr& attrs) { + out << "[MyReluMetaAttr] : "; + out << "| max_val(float)=" << attrs.max_val; + return out; + } + +}; // struct MyReluMetaAttr + +// serialize method +std::vector<std::string> MyReluMetaAttr_serialize(const MyReluMetaAttr& meta_attr); + 
+// deserialize method +void MyReluMetaAttr_deserialize(const std::vector<std::string>& attrs, MyReluMetaAttr& meta_attr); + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_MYRELU_ATTR_H_ \ No newline at end of file diff --git a/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/MyRelu_op.h b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/MyRelu_op.h new file mode 100644 index 00000000..793d400b --- /dev/null +++ b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/MyRelu_op.h @@ -0,0 +1,34 @@ +#ifndef TVM_CONTRIB_MSC_MYRELU_OP_H_ +#define TVM_CONTRIB_MSC_MYRELU_OP_H_ + +#include "MyRelu_attr.h" +#include "externs.h" + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +struct MyRelu : torch::CustomClassHolder { + MyRelu(const std::vector<std::string>& attrs); + + // serialize method + const std::vector<std::string> serialize(); + + // main compute + std::vector<torch::Tensor> compute(const std::vector<torch::Tensor>& input_tensors); + + // members + MyReluMetaAttr meta_attr_; + std::vector<MetaLayout> layouts_; + std::string name_; +}; // struct MyRelu : torch::CustomClassHolder + +// Entry method for plugin MyRelu +std::vector<torch::Tensor> my_relu_entry(const c10::intrusive_ptr<MyRelu>& instance, const torch::Tensor& input, const double& max_val, const std::string& name); + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_MYRELU_OP_H_ \ No newline at end of file diff --git a/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/externs.h b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/externs.h new file mode 100644 index 00000000..aeebacb0 --- /dev/null +++ b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/externs.h @@ -0,0 +1,47 @@ +#ifndef EXTERNS_H_ +#define EXTERNS_H_ + +#include "plugin_base.h" + +#ifdef PLUGIN_ENABLE_CUDA +#include <cuda_runtime.h> +#endif + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +template <typename TAttr> +std::vector<MetaTensor> my_relu_infer(const std::vector<MetaTensor>& inputs, const TAttr& attrs, + bool is_runtime) { + std::vector<MetaTensor> outputs; + outputs.push_back(MetaTensor(inputs[0].shape(), inputs[0].data_type(), inputs[0].layout())); + return outputs; +} + +template <typename T> +void my_relu_cpu_kernel(const DataTensor<T>& input, DataTensor<T>& output, T max_val); + +template <typename T, typename TAttr> +void my_relu_cpu_compute(const DataTensor<T>& input, DataTensor<T>& output, const TAttr& attrs) { + my_relu_cpu_kernel<T>(input, output, T(attrs.max_val)); +} + +#ifdef PLUGIN_ENABLE_CUDA +template <typename T> +void my_relu_cuda_kernel(const DataTensor<T>& input, DataTensor<T>& output, T max_val, + const cudaStream_t& stream); + +template <typename T, typename TAttr> +void my_relu_cuda_compute(const DataTensor<T>& input, DataTensor<T>& output, const TAttr& attrs, + const cudaStream_t& stream) { + my_relu_cuda_kernel<T>(input, output, T(attrs.max_val), stream); +} +#endif + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // EXTERNS_H_ diff --git a/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/plugin_base.h b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/plugin_base.h new file mode 100644 index 00000000..7c0353ff --- /dev/null +++ b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/plugin_base.h @@ -0,0 +1,284 @@ +#ifndef TVM_CONTRIB_MSC_UTILS_PLUGIN_BASE_H_ +#define TVM_CONTRIB_MSC_UTILS_PLUGIN_BASE_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +typedef enum { + kUINT8 = 0, + kINT8 = 1, + kINT16 = 2, + kINT32 = 3,
+ kINT64 = 4, + kFLOAT16 = 5, + kFLOAT32 = 6, + kFLOAT64 = 7, + kUNKNOWN = 8, +} MetaDataType; + +class MetaShape { + public: + MetaShape() { shape_.resize(0); } + + MetaShape(const std::vector<int64_t>& shape) { + for (auto d : shape) { + shape_.push_back(d); + } + } + + template <typename T> + void SetShape(const std::vector<T>& shape) { + for (auto d : shape) { + shape_.push_back(static_cast<int64_t>(d)); + } + } + + template <typename T> + void SetDim(int index, T dim) { + int valid_index = index < 0 ? shape_.size() + index : index; + if (valid_index >= shape_.size()) { + std::string err = + std::to_string(index) + " out of dims size " + std::to_string(shape_.size()); + throw std::runtime_error(err); + } + shape_[valid_index] = dim; + } + + template <typename T> + const std::vector<T> GetShape() const { + std::vector<T> shape; + for (auto d : shape_) { + shape.push_back(d); + } + return shape; + } + + inline int64_t DimAt(int index) const { + int valid_index = index < 0 ? shape_.size() + index : index; + if (valid_index >= shape_.size()) { + std::string err = + std::to_string(index) + " out of dims size " + std::to_string(shape_.size()); + throw std::runtime_error(err); + } + return shape_[valid_index]; + } + + inline size_t ndim() const { return shape_.size(); } + + inline const std::vector<int64_t> shape() const { return shape_; } + + inline size_t size() const { + size_t size = 1; + for (auto d : shape_) { + assert(d > 0 && "Can not compute static size with unknow dim"); + size *= d; + } + return size; + } + + inline int64_t operator[](int index) const { return DimAt(index); } + + friend std::ostream& operator<<(std::ostream& out, const MetaShape& shape) { + for (size_t i = 0; i < shape.ndim(); i++) { + out << shape.DimAt(i) << (1 < shape.ndim() ? "" : ","); + } + return out; + } + + private: + std::vector<int64_t> shape_; +}; + +class MetaLayoutAxis { + public: + MetaLayoutAxis(const char name, size_t factor = 0) : factor_(factor) { + name_ = (factor == 0 ? "" : std::to_string(factor)) + std::string(1, name); + } + + MetaLayoutAxis(const std::string& name) { + if (name.size() == 1) { + factor_ = 0; + name_ = name; + } else { + factor_ = std::stoi(name.substr(1)); + name_ = name.substr(0, 1); + } + } + + inline const std::string name() const { return name_; } + + inline size_t factor() const { return factor_; } + + private: + std::string name_; + size_t factor_; +}; + +class MetaLayout { + public: + MetaLayout() {} + + MetaLayout(const std::string& name) : name_(name) { + int factor = 0; + for (char c : name) { + if (c >= 'A' && c <= 'Z') { + assert(factor == 0 && "Upper layout axis do not accept factor"); + MetaLayoutAxis axis(c); + axes_.push_back(axis); + } else if (c >= 'a' && c <= 'z') { + assert(factor > 0 && "Lower layout axis should has factor"); + MetaLayoutAxis axis(c, factor); + axes_.push_back(axis); + factor = 0; + } else if (c >= '0' && c <= '9') { + assert(factor >= 0 && "Factor number should between 0 and 9"); + factor = factor * 10 + c - '0'; + } else { + throw std::runtime_error("Unexpected layout axis " + name); + } + } + CheckValid(); + } + + MetaLayout(const std::vector<MetaLayoutAxis>& axes) : axes_(axes) { + name_ = ""; + for (auto a : axes_) { + name_ += (a.factor() == 0 ? 
"" : std::to_string(a.factor())) + a.name(); + } + CheckValid(); + }; + + void CheckValid() { + std::set recorded_axes; + for (auto a : axes_) { + auto axis_name = a.name(); + assert(!recorded_axes.count(axis_name) && ("Has duplicate layout axis in " + name_).c_str()); + recorded_axes.insert(axis_name); + } + } + + inline const MetaLayoutAxis AxisAt(int index) const { + int valid_index = index < 0 ? axes_.size() + index : index; + if (valid_index >= axes_.size()) { + std::string err = std::to_string(index) + " out of axes size " + std::to_string(axes_.size()); + throw std::runtime_error(err); + } + return axes_[valid_index]; + } + + inline MetaLayoutAxis operator[](int index) { return AxisAt(index); } + + inline size_t ndim() const { return axes_.size(); } + + inline std::string name() const { return name_; } + + friend std::ostream& operator<<(std::ostream& out, const MetaLayout& layout) { + out << layout.name(); + return out; + } + + private: + std::string name_; + std::vector axes_; +}; + +class MetaTensor { + public: + MetaTensor() {} + + MetaTensor(const MetaShape& shape, const MetaDataType& data_type, + const MetaLayout& layout = MetaLayout()) + : shape_(shape), data_type_(data_type), layout_(layout) {} + + inline const MetaShape shape() const { return shape_; } + + inline MetaDataType data_type() const { return data_type_; } + + inline const std::vector meta_shape() const { return shape_.shape(); } + + inline const MetaLayout layout() const { return layout_; } + + inline const std::string layout_name() const { return layout_.name(); } + + inline size_t ndim() const { return shape_.ndim(); } + + inline size_t size(bool count_batch = true) const { + if (count_batch) { + size_t batch_dim = 0; + for (size_t i = 0; i < layout_.ndim(); i++) { + if (layout_.AxisAt(i).name() == "N") { + batch_dim = i; + } + } + return shape_.size() / shape_.shape()[batch_dim]; + } + return shape_.size(); + } + + inline MetaLayoutAxis AxisAt(int index) const { return layout_.AxisAt(index); } + + inline int AxisOf(const std::string& axis) const { + for (size_t i = 0; i < layout_.ndim(); i++) { + if (layout_.AxisAt(i).name() == axis) { + return i; + } + } + return -1; + } + + inline int64_t DimAt(int index) const { return shape_.DimAt(index); } + + inline int64_t DimAt(const std::string& axis) const { + int idx = AxisOf(axis); + if (idx >= 0) { + return shape_.DimAt(idx); + } + throw std::runtime_error("Can not find dim for " + axis); + } + + friend std::ostream& operator<<(std::ostream& out, const MetaTensor& tensor) { + out << "tensor : <" << tensor.shape() << ">, (" << tensor.layout() << ")"; + return out; + } + + private: + MetaShape shape_; + MetaDataType data_type_; + MetaLayout layout_; +}; + +template +class DataTensor : public MetaTensor { + public: + DataTensor(const MetaShape shape, const MetaDataType& data_type, const MetaLayout layout, T* data) + : MetaTensor(shape, data_type, layout) { + data_ = data; + } + + DataTensor(const MetaShape shape, const MetaDataType& data_type, const MetaLayout layout, + const T* data) + : MetaTensor(shape, data_type, layout) { + data_ = const_cast(data); + } + + T* data() const { return data_; } + + const T* const_data() const { return data_; } + + private: + T* data_{nullptr}; +}; + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_UTILS_PLUGIN_BASE_H_ diff --git a/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/plugin_utils.h 
diff --git a/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/plugin_utils.h b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/plugin_utils.h new file mode 100644 index 00000000..aa43d9e1 --- /dev/null +++ b/doc/tutorials/msc/plugin/msc_plugin/install/torch/include/plugin_utils.h @@ -0,0 +1,764 @@ +#ifndef TVM_CONTRIB_MSC_UTILS_PLUGIN_UTILS_H_ +#define TVM_CONTRIB_MSC_UTILS_PLUGIN_UTILS_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "plugin_base.h" + +#ifdef PLUGIN_ENABLE_CUDA +#include +#include +#endif // PLUGIN_ENABLE_CUDA + +#ifdef PLUGIN_SUPPORT_TVM +#include + +#include "tvm/../../src/contrib/msc/core/transform/layout_utils.h" +#include "tvm/../../src/contrib/msc/core/utils.h" +#ifdef PLUGIN_ENABLE_CUDA +#include "tvm/../../src/runtime/cuda/cuda_common.h" +#endif // PLUGIN_ENABLE_CUDA +#endif // PLUGIN_SUPPORT_TVM + +#ifdef PLUGIN_SUPPORT_TORCH +#include +#include +#ifdef PLUGIN_ENABLE_CUDA +#include +#endif // PLUGIN_ENABLE_CUDA +#endif // PLUGIN_SUPPORT_TORCH + +#ifdef PLUGIN_SUPPORT_TENSORRT +#include "NvInfer.h" +#endif // PLUGIN_SUPPORT_TENSORRT + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +class SerializeUtils { + public: + // Helper function for serializing plugin attrs + template <typename T> + static const std::string ToString(const T& value) { + return std::to_string(value); + } + + static std::string ToString(const std::string& value) { return value; } + + template <typename T> + static std::string ToString(const std::vector<T>& value) { + std::string str = std::to_string(value.size()); + for (const auto& v : value) { + str += "," + std::to_string(v); + } + return str; + } + + static void FromString(const std::string& src, std::string& target) { target = src; } + + static void FromString(const std::string& src, bool& target) { + target = std::stoi(src) > 0 ? true : false; + } + + static void FromString(const std::string& src, int& target) { target = std::stoi(src); } + + static void FromString(const std::string& src, size_t& target) { target = std::stoi(src); } + + static void FromString(const std::string& src, long& target) { target = std::stol(src); } + + static void FromString(const std::string& src, float& target) { target = std::stof(src); } + + static void FromString(const std::string& src, double& target) { target = std::stod(src); } + + template <typename T> + static void FromString(const std::string& src, std::vector<T>& target) { + std::string left_str = src; + int pos = left_str.find(","); + if (pos == std::string::npos) { + return; + } + assert(pos > 0); + size_t src_size; + FromString(left_str.substr(0, pos), src_size); + target.resize(src_size); + for (size_t i = 0; i < src_size; i++) { + pos = left_str.find(","); + left_str = left_str.substr(pos + 1); + FromString(left_str, target[i]); + } + } + + static void FromString(const std::string& src, std::vector<bool>& target) { + std::vector<int> values; + FromString(src, values); + target.resize(values.size()); + for (size_t i = 0; i < values.size(); i++) { + target[i] = values[i] > 0 ? 
true : false; + } + } +}; + +class DataUtils { + public: + static MetaDataType ToMetaType(const std::string& name) { + MetaDataType dtype; + if (name == "int8") { + dtype = MetaDataType::kINT8; + } else if (name == "uint8" || name == "char") { + dtype = MetaDataType::kUINT8; + } else if (name == "int16") { + dtype = MetaDataType::kINT16; + } else if (name == "int32" || name == "int") { + dtype = MetaDataType::kINT32; + } else if (name == "int64" || name == "long") { + dtype = MetaDataType::kINT64; + } else if (name == "float16" || name == "half") { + dtype = MetaDataType::kFLOAT16; + } else if (name == "float32" || name == "float") { + dtype = MetaDataType::kFLOAT32; + } else if (name == "float64" || name == "double") { + dtype = MetaDataType::kFLOAT64; + } else { + dtype = MetaDataType::kUNKNOWN; + } + return dtype; + } + + static bool IsListType(const std::string& dtype) { + int pos = dtype.find("list("); + return pos == 0; + } + + static const std::string GetEleType(const std::string& dtype) { + int pos = dtype.find("list("); + if (pos == 0) { + return dtype.substr(pos + 5, dtype.size() - 6); + } + return ""; + } +}; + +#ifdef PLUGIN_SUPPORT_TVM +using namespace tvm::relax; +using namespace tvm::runtime; +class TVMUtils { + public: + static void AttrFromPrim(const PrimValue& expr, std::string& target) { + ICHECK(expr->IsInstance<StringImmNode>()) << "Expr is not StringImm"; + target = Downcast<StringImm>(expr)->value; + } + + static void AttrFromPrim(const PrimValue& expr, bool& target) { + ICHECK(expr->value->IsInstance<IntImmNode>()) << "Expr value is not IntImm"; + target = Downcast<IntImm>(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, int& target) { + ICHECK(expr->value->IsInstance<IntImmNode>()) << "Expr value is not IntImm"; + target = Downcast<IntImm>(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, size_t& target) { + ICHECK(expr->value->IsInstance<IntImmNode>()) << "Expr value is not IntImm"; + target = Downcast<IntImm>(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, long& target) { + ICHECK(expr->value->IsInstance<IntImmNode>()) << "Expr value is not IntImm"; + target = Downcast<IntImm>(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, float& target) { + ICHECK(expr->value->IsInstance<FloatImmNode>()) << "Expr value is not FloatImm"; + target = Downcast<FloatImm>(expr->value)->value; + } + + static void AttrFromPrim(const PrimValue& expr, double& target) { + ICHECK(expr->value->IsInstance<FloatImmNode>()) << "Expr value is not FloatImm"; + target = Downcast<FloatImm>(expr->value)->value; + } + + template <typename T> + static void AttrFromPrims(const Tuple& tuple, std::vector<T>& target) { + for (size_t i = 0; i < tuple->fields.size(); i++) { + ICHECK(tuple->fields[i]->IsInstance<PrimValueNode>()) << "Field is not PrimValue"; + AttrFromPrim(Downcast<PrimValue>(tuple->fields[i]), target[i]); + } + } + + static void AttrFromArg(const TVMArgValue& arg, std::string& target) { + target = arg.operator std::string(); + } + + static void AttrFromArg(const TVMArgValue& arg, bool& target) { target = arg; } + + static void AttrFromArg(const TVMArgValue& arg, int& target) { target = arg; } + + static void AttrFromArg(const TVMArgValue& arg, size_t& target) { target = int(arg); } + + static void AttrFromArg(const TVMArgValue& arg, long& target) { target = int64_t(arg); } + + static void AttrFromArg(const TVMArgValue& arg, float& target) { target = double(arg); } + + static void AttrFromArg(const TVMArgValue& arg, double& target) { target = arg; } + + template <typename T> + static void AttrFromArgs(const TVMArgs& args, size_t start, size_t num, std::vector<T>& target) { + for (size_t 
i = 0; i < num; i++) { + AttrFromArg(args[start + i], target[i]); + } + } + + static MetaDataType ToMetaType(const DataType& dtype) { + MetaDataType meta_type; + if (dtype.code() == 0 && dtype.bits() == 8) { + meta_type = MetaDataType::kINT8; + } else if (dtype.code() == 0 && dtype.bits() == 16) { + meta_type = MetaDataType::kINT16; + } else if (dtype.code() == 0 && dtype.bits() == 32) { + meta_type = MetaDataType::kINT32; + } else if (dtype.code() == 0 && dtype.bits() == 64) { + meta_type = MetaDataType::kINT64; + } else if (dtype.code() == 1 && dtype.bits() == 8) { + meta_type = MetaDataType::kUINT8; + } else if (dtype.code() == 2 && dtype.bits() == 16) { + meta_type = MetaDataType::kFLOAT16; + } else if (dtype.code() == 2 && dtype.bits() == 32) { + meta_type = MetaDataType::kFLOAT32; + } else if (dtype.code() == 2 && dtype.bits() == 64) { + meta_type = MetaDataType::kFLOAT64; + } else { + meta_type = MetaDataType::kUNKNOWN; + } + return meta_type; + } + + static MetaDataType ToMetaType(const DLDataType& dtype) { + MetaDataType meta_type; + if (dtype.code == 0U && dtype.bits == 8) { + meta_type = MetaDataType::kINT8; + } else if (dtype.code == 0U && dtype.bits == 16) { + meta_type = MetaDataType::kINT16; + } else if (dtype.code == 0U && dtype.bits == 32) { + meta_type = MetaDataType::kINT32; + } else if (dtype.code == 0U && dtype.bits == 64) { + meta_type = MetaDataType::kINT64; + } else if (dtype.code == 1U && dtype.bits == 8) { + meta_type = MetaDataType::kUINT8; + } else if (dtype.code == 2U && dtype.bits == 16) { + meta_type = MetaDataType::kFLOAT16; + } else if (dtype.code == 2U && dtype.bits == 32) { + meta_type = MetaDataType::kFLOAT32; + } else if (dtype.code == 2U && dtype.bits == 64) { + meta_type = MetaDataType::kFLOAT64; + } else { + meta_type = MetaDataType::kUNKNOWN; + } + return meta_type; + } + + static MetaShape ToMetaShape(const Optional>& tvm_shape) { + if (tvm_shape.defined()) { + std::vector shape_data; + for (auto s : tvm_shape.value()) { + if (s->IsInstance()) { + shape_data.push_back(Downcast(s)->value); + } else { + shape_data.push_back(-1); + } + } + return MetaShape(shape_data); + } + return MetaShape(); + } + + static MetaShape ToMetaShape(DLTensor* tensor, bool as_data = true) { + std::vector dims; + if (as_data) { + assert(tensor->ndim == 1); + assert(TVMUtils::ToMetaType(tensor->dtype) == MetaDataType::kINT64); + int64_t* data_ptr = (int64_t*)tensor->data; + for (size_t i = 0; i < tensor->shape[0]; i++) { + dims.push_back(data_ptr[i]); + } + } else { + for (size_t i = 0; i < tensor->ndim; i++) { + dims.push_back(tensor->shape[i]); + } + } + return MetaShape(dims); + } + + static MetaTensor ToMetaTensor(const Expr& expr, + const LayoutDecision& layout_dec = LayoutDecision()) { + const auto* sinfo = GetStructInfoAs(expr); + if (layout_dec.defined() && layout_dec->layout.defined()) { + const auto& layout = MetaLayout(layout_dec->layout.name()); + return MetaTensor(ToMetaShape(sinfo->GetShape()), ToMetaType(sinfo->dtype), layout); + } + const auto& layout = MetaLayout(SpanUtils::GetAttr(expr->span, "layout")); + return MetaTensor(ToMetaShape(sinfo->GetShape()), ToMetaType(sinfo->dtype), layout); + } + + template + static DataTensor ToDataTensor(DLTensor* tensor, bool read_only) { + if (read_only) { + return DataTensor(ToMetaShape(tensor, false), ToMetaType(tensor->dtype), MetaLayout(), + (const T*)(tensor->data)); + } else { + return DataTensor(ToMetaShape(tensor, false), ToMetaType(tensor->dtype), MetaLayout(), + (T*)(tensor->data)); + } + } + + static 
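Both `ToMetaType` overloads above branch on the (type code, bit width) pair, and the codes follow the DLPack convention (0 = int, 1 = uint, 2 = float). A small Python sketch of the same decode, handy for sanity-checking a `DLTensor` coming across the DLPack boundary; note that, exactly as in the C++, anything outside the table (e.g. uint16) falls through to "unknown":

```python
# DLPack type codes: kDLInt = 0, kDLUInt = 1, kDLFloat = 2.
_DL_TABLE = {
    (0, 8): "int8", (0, 16): "int16", (0, 32): "int32", (0, 64): "int64",
    (1, 8): "uint8",
    (2, 16): "float16", (2, 32): "float32", (2, 64): "float64",
}

def dl_to_name(code: int, bits: int) -> str:
    return _DL_TABLE.get((code, bits), "unknown")

assert dl_to_name(2, 32) == "float32"  # kDLFloat, 32 bits
assert dl_to_name(1, 16) == "unknown"  # uint16 is not handled, maps to kUNKNOWN
```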
DataType ToTVMType(const MetaDataType& dtype) { + DataType tvm_type; + if (dtype == MetaDataType::kINT8) { + tvm_type = DataType::Int(8); + } else if (dtype == MetaDataType::kINT16) { + tvm_type = DataType::Int(16); + } else if (dtype == MetaDataType::kINT32) { + tvm_type = DataType::Int(32); + } else if (dtype == MetaDataType::kINT64) { + tvm_type = DataType::Int(64); + } else if (dtype == MetaDataType::kFLOAT16) { + tvm_type = DataType::Float(16); + } else if (dtype == MetaDataType::kFLOAT32) { + tvm_type = DataType::Float(32); + } else if (dtype == MetaDataType::kFLOAT64) { + tvm_type = DataType::Float(64); + } else { + throw std::runtime_error("Unsupported type"); + } + return tvm_type; + } + + static DataType ToTVMType(const std::string& dtype) { + return ToTVMType(DataUtils::ToMetaType(dtype)); + } + + static Array ToTVMShape(const MetaShape& meta_shape) { + Array tvm_shape; + for (size_t i = 0; i < meta_shape.ndim(); i++) { + auto dim = meta_shape.DimAt(i); + if (dim == -1) { + tvm_shape.push_back(tir::Any()); + } else { + tvm_shape.push_back(Integer(dim)); + } + } + return tvm_shape; + } + + static void FillDLShape(const MetaShape& shape, DLTensor* data) { + auto shape_data = static_cast(data->data); + for (size_t i = 0; i < shape.ndim(); i++) { + shape_data[i] = shape.DimAt(i); + } + } + + static TensorStructInfo ToTensorStructInfo(const MetaTensor& tensor, + const Optional& device) { + const auto& t_shape = ToTVMShape(tensor.shape()); + const auto& t_type = ToTVMType(tensor.data_type()); + return TensorStructInfo(ShapeExpr(t_shape), t_type, device); + } + + static TensorStructInfo ToTensorStructInfo(const MetaTensor& tensor, const Expr& expr) { + const auto* sinfo = GetStructInfoAs(expr); + return ToTensorStructInfo(tensor, sinfo->vdevice); + } + + static bool OnDevice(DLTensor* tensor, DLDeviceType device) { + return tensor->device.device_type == device; + } + + static void CheckDevice(DLTensor* tensor, DLDeviceType device) { + ICHECK_EQ(tensor->device.device_type, device); + } + + static Device DefaultCPU() { + Device cpu_dev{kDLCPU, 0}; + return cpu_dev; + } + + static Device DefaultCUDA() { + Device cuda_dev{kDLCUDA, 0}; + return cuda_dev; + } +}; +#endif // PLUGIN_SUPPORT_TVM + +#ifdef PLUGIN_SUPPORT_TORCH +class TorchUtils { + public: + static MetaDataType ToMetaType(const torch::ScalarType& dtype) { + MetaDataType meta_type; + if (dtype == torch::kChar) { + meta_type = MetaDataType::kINT8; + } else if (dtype == torch::kInt) { + meta_type = MetaDataType::kINT32; + } else if (dtype == torch::kInt64) { + meta_type = MetaDataType::kINT64; + } else if (dtype == torch::kLong) { + meta_type = MetaDataType::kINT64; + } else if (dtype == torch::kFloat16) { + meta_type = MetaDataType::kFLOAT16; + } else if (dtype == torch::kFloat) { + meta_type = MetaDataType::kFLOAT32; + } else if (dtype == torch::kDouble) { + meta_type = MetaDataType::kFLOAT64; + } else { + meta_type = MetaDataType::kUNKNOWN; + } + return meta_type; + } + + static MetaShape ToMetaShape(const torch::Tensor& tensor) { + std::vector shape_data; + for (size_t idx = 0; idx < tensor.dim(); idx++) { + shape_data.push_back(tensor.size(idx)); + } + return MetaShape(shape_data); + } + + static MetaTensor ToMetaTensor(const torch::Tensor& tensor, + const MetaLayout& layout = MetaLayout()) { + return MetaTensor(ToMetaShape(tensor), ToMetaType(tensor.scalar_type()), layout); + } + + template + static DataTensor ToDataTensor(const torch::Tensor& tensor, const MetaTensor& meta, + bool read_only) { + if (read_only) { + return 
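Further down, `MallocTorchTensor` is the torch-side allocation path: the meta dtype and device strings are translated, then a zero-filled tensor is created. The equivalent call sequence in Python uses only real torch API; the dtype-table keys here are the plain dtype names rather than the C++ enum, which is an assumption for illustration:

```python
import torch

_TORCH_TYPES = {
    "int8": torch.int8, "int32": torch.int32, "int64": torch.int64,
    "float16": torch.float16, "float32": torch.float32, "float64": torch.float64,
}

def malloc_torch_tensor(shape, dtype: str, device: str = "cpu") -> torch.Tensor:
    # Unknown device strings fall back to CPU, like ToTorchDevice above.
    dev = torch.device(device if device in ("cpu", "cuda") else "cpu")
    return torch.zeros(shape, dtype=_TORCH_TYPES[dtype], device=dev)

buf = malloc_torch_tensor([1, 3, 224, 224], "float32")
assert buf.shape == (1, 3, 224, 224) and buf.dtype == torch.float32
```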
DataTensor(meta.shape(), meta.data_type(), meta.layout(), + (const T*)(tensor.data_ptr())); + } else { + return DataTensor(meta.shape(), meta.data_type(), meta.layout(), (T*)(tensor.data_ptr())); + } + } + + static torch::ScalarType ToTorchType(const MetaDataType& dtype) { + torch::ScalarType torch_type; + if (dtype == MetaDataType::kINT8) { + torch_type = torch::kChar; + } else if (dtype == MetaDataType::kINT32) { + torch_type = torch::kInt; + } else if (dtype == MetaDataType::kINT64) { + torch_type = torch::kInt64; + } else if (dtype == MetaDataType::kFLOAT16) { + torch_type = torch::kFloat16; + } else if (dtype == MetaDataType::kFLOAT32) { + torch_type = torch::kFloat; + } else if (dtype == MetaDataType::kFLOAT64) { + torch_type = torch::kDouble; + } else { + throw std::runtime_error("Unsupported type"); + } + return torch_type; + } + + static torch::ScalarType ToTorchType(const std::string& dtype) { + return ToTorchType(DataUtils::ToMetaType(dtype)); + } + + static torch::Device ToTorchDevice(const std::string& device) { + if (device == "cpu") { + return torch::Device(torch::kCPU); + } + if (device == "cuda") { + return torch::Device(torch::kCUDA); + } + return torch::Device(torch::kCPU); + } + + static torch::Tensor MallocTorchTensor(const MetaTensor& tensor, const torch::Device& device) { + auto t_type = ToTorchType(tensor.data_type()); + auto opt = torch::TensorOptions().dtype(t_type).device(device); + return torch::zeros(tensor.meta_shape(), opt); + } +}; +#endif // PLUGIN_SUPPORT_TORCH + +#ifdef PLUGIN_SUPPORT_TENSORRT +using namespace nvinfer1; + +#ifndef TRT_VERSION_GE +#define TRT_VERSION_GE(major, minor, patch) \ + ((TRT_MAJOR > major) || (TRT_MAJOR == major && TRT_MINOR > minor) || \ + (TRT_MAJOR == major && TRT_MINOR == minor && TRT_PATCH >= patch)) +#endif + +class TRTUtils { + public: + template + static void ValToBuffer(char*& buffer, const T& val) { + *reinterpret_cast(buffer) = val; + buffer += sizeof(T); + } + + static void ValToBuffer(char*& buffer, const std::string& val) { + *reinterpret_cast(buffer) = val.size(); + buffer += sizeof(size_t); + val.copy(buffer, val.size()); + buffer += sizeof(char) * val.size(); + } + + template + static void ValToBuffer(char*& buffer, const std::vector& val) { + ValToBuffer(buffer, val.size()); + for (auto e : val) { + ValToBuffer(buffer, e); + } + } + + template + static void ValFromBuffer(const char*& buffer, T& val) { + val = *reinterpret_cast(buffer); + buffer += sizeof(T); + } + + static void ValFromBuffer(const char*& buffer, std::string& val) { + auto size = *reinterpret_cast(buffer); + buffer += sizeof(size_t); + val = std::string(reinterpret_cast(buffer), size); + buffer += sizeof(char) * size; + } + + template + static void ValFromBuffer(const char*& buffer, std::vector& val) { + size_t size; + ValFromBuffer(buffer, size); + val.resize(size); + for (size_t i = 0; i < size; i++) { + ValFromBuffer(buffer, val[i]); + } + } + + static PluginFieldType ToFieldType(const std::string& dtype) { + PluginFieldType field_type; + if (dtype == "char" || dtype == "uint8" || dtype == "string") { + field_type = PluginFieldType::kCHAR; + } else if (dtype == "int8") { + field_type = PluginFieldType::kINT8; + } else if (dtype == "int16") { + field_type = PluginFieldType::kINT16; + } else if (dtype == "int" || dtype == "int32") { + field_type = PluginFieldType::kINT32; + } else if (dtype == "float16" || dtype == "half") { + field_type = PluginFieldType::kFLOAT16; + } else if (dtype == "float32" || dtype == "float") { + field_type = 
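The `ValToBuffer`/`ValFromBuffer` pair above defines a tiny length-prefixed binary format for TensorRT plugin (de)serialization: PODs are copied raw, strings and vectors are written as a size_t count followed by the payload. A hedged Python equivalent with `struct`, assuming a 64-bit `size_t` and native byte order (which is what the raw `reinterpret_cast` copies imply):

```python
import struct

def pack_string(s: str) -> bytes:
    data = s.encode()
    return struct.pack("=Q", len(data)) + data    # size_t count, then raw chars

def pack_float_vector(vals) -> bytes:
    out = struct.pack("=Q", len(vals))            # vector length first
    for v in vals:
        out += struct.pack("=f", v)               # then each POD element raw
    return out

def unpack_string(buf: bytes, off: int = 0):
    (n,) = struct.unpack_from("=Q", buf, off)
    off += 8
    return buf[off:off + n].decode(), off + n

blob = pack_string("MyRelu") + pack_float_vector([0.5])
name, off = unpack_string(blob)
assert name == "MyRelu"
```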
PluginFieldType::kFLOAT32; + } else if (dtype == "float64" || dtype == "double") { + field_type = PluginFieldType::kFLOAT64; + } else { + field_type = PluginFieldType::kUNKNOWN; + } + return field_type; + } + + static const PluginField ToField(const std::string& name, const std::string& dtype) { + const auto& ele_type = DataUtils::GetEleType(dtype); + if (ele_type.size() == 0) { + return PluginField(name.c_str(), nullptr, ToFieldType(dtype), 1); + } + return PluginField(name.c_str(), nullptr, ToFieldType(ele_type), 11); + } + + static void FromField(const PluginField& field, std::string& val) { + assert(field.type == PluginFieldType::kCHAR); + const char* data = static_cast(field.data); + val = data; + } + + static void FromField(const PluginField& field, bool& val) { + assert(field.type == PluginFieldType::kINT32); + int int_val = *(static_cast(field.data)); + val = int_val == 0 ? false : true; + } + + static void FromField(const PluginField& field, int& val) { + assert(field.type == PluginFieldType::kINT32); + val = *(static_cast(field.data)); + } + + static void FromField(const PluginField& field, size_t& val) { + assert(field.type == PluginFieldType::kINT32); + val = *(static_cast(field.data)); + } + + static void FromField(const PluginField& field, long& val) { + assert(field.type == PluginFieldType::kINT32); + val = *(static_cast(field.data)); + } + + static void FromField(const PluginField& field, float& val) { + assert(field.type == PluginFieldType::kFLOAT32); + val = *(static_cast(field.data)); + } + + static void FromField(const PluginField& field, double& val) { + assert(field.type == PluginFieldType::kFLOAT64); + val = *(static_cast(field.data)); + } + + static MetaDataType ToMetaType(const DataType& dtype) { + MetaDataType meta_type; + if (dtype == DataType::kINT8) { + meta_type = MetaDataType::kINT8; + } else if (dtype == DataType::kINT32) { + meta_type = MetaDataType::kINT32; + } else if (dtype == DataType::kHALF) { + meta_type = MetaDataType::kFLOAT16; + } else if (dtype == DataType::kFLOAT) { + meta_type = MetaDataType::kFLOAT32; + } else { + meta_type = MetaDataType::kUNKNOWN; + } + return meta_type; + } + + static MetaShape ToMetaShape(const Dims& trt_dims, bool dynamic = false) { + std::vector dims; + if (!dynamic) { + dims.push_back(1); + } + for (size_t idx = 0; idx < trt_dims.nbDims; idx++) { + dims.push_back(trt_dims.d[idx]); + } + return MetaShape(dims); + } + + static MetaShape ToMetaShape(const DimsExprs& trt_dims) { + std::vector dims; + for (size_t idx = 0; idx < trt_dims.nbDims; idx++) { + assert(trt_dims.d[idx]->isConstant()); + dims.push_back(trt_dims.d[idx]->getConstantValue()); + } + return MetaShape(dims); + } + + static MetaShape ToMetaShape(const PluginTensorDesc& desc) { + return ToMetaShape(desc.dims, true); + } + + static MetaShape ToMetaShape(const DynamicPluginTensorDesc& desc) { + return ToMetaShape(desc.desc); + } + + static MetaTensor ToMetaTensor(const Dims& dims, const DataType& dtype, const std::string& layout, + bool dynamic = false) { + return MetaTensor(ToMetaShape(dims, dynamic), ToMetaType(dtype), MetaLayout(layout)); + } + + static MetaTensor ToMetaTensor(const DimsExprs& dims, const DataType& dtype, + const std::string& layout) { + return MetaTensor(ToMetaShape(dims), ToMetaType(dtype), MetaLayout(layout)); + } + + static MetaTensor ToMetaTensor(const PluginTensorDesc& desc, const std::string& layout) { + return ToMetaTensor(desc.dims, desc.type, layout, true); + } + + static MetaTensor ToMetaTensor(const DynamicPluginTensorDesc& desc, 
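One subtlety in `ToMetaShape(const Dims&, ...)` above: in TensorRT's implicit-batch mode the `Dims` omit the batch dimension, so a leading 1 is re-inserted unless the shape is marked dynamic. The rule in two lines of Python:

```python
def trt_dims_to_meta_shape(trt_dims, dynamic=False):
    # Implicit-batch Dims carry no batch axis; prepend a 1 unless dynamic.
    return list(trt_dims) if dynamic else [1, *trt_dims]

assert trt_dims_to_meta_shape([3, 224, 224]) == [1, 3, 224, 224]
assert trt_dims_to_meta_shape([1, 3, 224, 224], dynamic=True) == [1, 3, 224, 224]
```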
const std::string& layout) { + return ToMetaTensor(desc.desc, layout); + } + + static DataType ToDataType(const MetaDataType& dtype) { + DataType data_type; + if (dtype == MetaDataType::kINT8) { + data_type = DataType::kINT8; + } else if (dtype == MetaDataType::kINT32) { + data_type = DataType::kINT32; + } else if (dtype == MetaDataType::kFLOAT16) { + data_type = DataType::kHALF; + } else if (dtype == MetaDataType::kFLOAT32) { + data_type = DataType::kFLOAT; + } else { + data_type = DataType::kFLOAT; + } + return data_type; + } + + static DataType ToDataType(const std::string& dtype) { + return ToDataType(DataUtils::ToMetaType(dtype)); + } + + static Dims ToDims(const MetaShape& meta_shape, bool dynamic = false) { + std::vector int_dims; + if (dynamic) { + int_dims.push_back(meta_shape.DimAt(0)); + } + for (size_t i = 1; i < meta_shape.ndim(); i++) { + int_dims.push_back(meta_shape.DimAt(i)); + } + Dims dims{int(int_dims.size())}; + for (size_t i = 0; i < int_dims.size(); i++) { + dims.d[i] = int_dims[i]; + } + return dims; + } + + static DimsExprs ToDimsExprs(const MetaShape& meta_shape, IExprBuilder& builder) { + std::vector int_dims; + for (size_t i = 0; i < meta_shape.ndim(); i++) { + int_dims.push_back(meta_shape.DimAt(i)); + } + DimsExprs dims{int(int_dims.size())}; + for (size_t i = 0; i < int_dims.size(); i++) { + dims.d[i] = builder.constant(int_dims[i]); + } + return dims; + } + + static const MetaShape SetBatch(const MetaTensor& tensor, int batch_size) { + MetaShape shape = tensor.shape(); + int batch = tensor.AxisOf("N"); + if (batch < 0) { + batch = 0; + } + shape.SetDim(batch, batch_size); + return shape; + } + + template + static DataTensor ToDataTensor(const MetaTensor& tensor, int batch_size, const void* data) { + const auto& shape = SetBatch(tensor, batch_size); + return DataTensor(shape, tensor.data_type(), tensor.layout(), (const T*)(data)); + } + + template + static DataTensor ToDataTensor(const MetaTensor& tensor, int batch_size, void* data) { + const auto& shape = SetBatch(tensor, batch_size); + return DataTensor(shape, tensor.data_type(), tensor.layout(), (const T*)(data)); + } + + template + static DataTensor ToDataTensor(const MetaTensor& tensor, const PluginTensorDesc& desc, + const void* data) { + return DataTensor(ToMetaShape(desc), ToMetaType(desc.type), tensor.layout(), + (const T*)(data)); + } + + template + static DataTensor ToDataTensor(const MetaTensor& tensor, const PluginTensorDesc& desc, + void* data) { + return DataTensor(ToMetaShape(desc), ToMetaType(desc.type), tensor.layout(), (T*)(data)); + } +}; +#endif // PLUGIN_SUPPORT_TENSORRT + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_UTILS_PLUGIN_UTILS_H_ diff --git a/doc/tutorials/msc/plugin/set_env.py b/doc/tutorials/msc/plugin/set_env.py new file mode 100644 index 00000000..798095c7 --- /dev/null +++ b/doc/tutorials/msc/plugin/set_env.py @@ -0,0 +1,7 @@ +import sys +from pathlib import Path +root_dir = Path(__file__).resolve().parents[4] +sys.path.extend([ + f"{root_dir}/tests" +]) +import env diff --git a/doc/tutorials/msc/plugin/test.ipynb b/doc/tutorials/msc/plugin/test.ipynb new file mode 100644 index 00000000..5cc30ecb --- /dev/null +++ b/doc/tutorials/msc/plugin/test.ipynb @@ -0,0 +1,132 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MSC Plugin" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import set_env\n", + "from 
pathlib import Path\n", + "\n", + "temp_dir = Path(\".temp\")\n", + "temp_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "参考:[【我与TVM二三事 后篇(5)】MSC之PluginBuilder](https://zhuanlan.zhihu.com/p/681450076)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from utils import *\n", + "from utils import _build_plugin" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "AssertionError", + "evalue": "Failed to build plugin under /media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc/plugin/msc_plugin/install/source_torch/build, check codegen.log for detail", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m frameworks\u001b[38;5;241m.\u001b[39mappend(MSCFramework\u001b[38;5;241m.\u001b[39mTENSORRT)\n\u001b[1;32m 4\u001b[0m plugin_root \u001b[38;5;241m=\u001b[39m msc_utils\u001b[38;5;241m.\u001b[39mmsc_dir(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmsc_plugin\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m managers \u001b[38;5;241m=\u001b[39m _build_plugin(frameworks, plugin_root)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# test the plugin load\u001b[39;00m\n\u001b[1;32m 8\u001b[0m _test_tvm_plugin(managers[MSCFramework\u001b[38;5;241m.\u001b[39mTVM], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mllvm\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc/plugin/utils.py:231\u001b[0m, in \u001b[0;36m_build_plugin\u001b[0;34m(frameworks, plugin_root)\u001b[0m\n\u001b[1;32m 229\u001b[0m install_dir \u001b[38;5;241m=\u001b[39m plugin_root\u001b[38;5;241m.\u001b[39mcreate_dir(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minstall\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 230\u001b[0m plugin \u001b[38;5;241m=\u001b[39m _create_plugin(externs_dir)\n\u001b[0;32m--> 231\u001b[0m managers \u001b[38;5;241m=\u001b[39m build_plugins(plugin, frameworks, install_dir, externs_dir\u001b[38;5;241m=\u001b[39mexterns_dir)\n\u001b[1;32m 232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m managers\n", + "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/plugin/build.py:128\u001b[0m, in \u001b[0;36mbuild_plugins\u001b[0;34m(plugins, frameworks, workspace, codegen_config, cpp_print_config, py_print_config, externs_dir, on_debug)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbuild_plugins\u001b[39m(\n\u001b[1;32m 92\u001b[0m plugins: Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mdict\u001b[39m],\n\u001b[1;32m 93\u001b[0m frameworks: List[\u001b[38;5;28mstr\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 99\u001b[0m on_debug: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dict[\u001b[38;5;28mstr\u001b[39m, Any]:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Build the plugins and load plugin manager\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \n\u001b[1;32m 103\u001b[0m \u001b[38;5;124;03m Parameters\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;124;03m The plugin 
managers.\u001b[39;00m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 128\u001b[0m codegens \u001b[38;5;241m=\u001b[39m _build_plugins(\n\u001b[1;32m 129\u001b[0m plugins,\n\u001b[1;32m 130\u001b[0m frameworks,\n\u001b[1;32m 131\u001b[0m workspace,\n\u001b[1;32m 132\u001b[0m codegen_config\u001b[38;5;241m=\u001b[39mcodegen_config,\n\u001b[1;32m 133\u001b[0m cpp_print_config\u001b[38;5;241m=\u001b[39mcpp_print_config,\n\u001b[1;32m 134\u001b[0m py_print_config\u001b[38;5;241m=\u001b[39mpy_print_config,\n\u001b[1;32m 135\u001b[0m externs_dir\u001b[38;5;241m=\u001b[39mexterns_dir,\n\u001b[1;32m 136\u001b[0m on_debug\u001b[38;5;241m=\u001b[39mon_debug,\n\u001b[1;32m 137\u001b[0m )\n\u001b[1;32m 138\u001b[0m managers \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, codegen \u001b[38;5;129;01min\u001b[39;00m codegens\u001b[38;5;241m.\u001b[39mitems():\n", + "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/plugin/build.py:84\u001b[0m, in \u001b[0;36m_build_plugins\u001b[0;34m(plugins, frameworks, workspace, codegen_config, cpp_print_config, py_print_config, externs_dir, on_debug)\u001b[0m\n\u001b[1;32m 73\u001b[0m codegen \u001b[38;5;241m=\u001b[39m get_codegen(\n\u001b[1;32m 74\u001b[0m framework,\n\u001b[1;32m 75\u001b[0m workspace,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 81\u001b[0m on_debug\u001b[38;5;241m=\u001b[39mon_debug,\n\u001b[1;32m 82\u001b[0m )\n\u001b[1;32m 83\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m codegen\u001b[38;5;241m.\u001b[39mlibs_built():\n\u001b[0;32m---> 84\u001b[0m codegen\u001b[38;5;241m.\u001b[39mbuild_libs()\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m codegen\u001b[38;5;241m.\u001b[39mneed_manager \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m codegen\u001b[38;5;241m.\u001b[39mmanager_built():\n\u001b[1;32m 86\u001b[0m codegen\u001b[38;5;241m.\u001b[39mbuild_manager(ops_info)\n", + "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/plugin/codegen/codegen.py:127\u001b[0m, in \u001b[0;36mBasePluginCodeGen.build_libs\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 125\u001b[0m process \u001b[38;5;241m=\u001b[39m subprocess\u001b[38;5;241m.\u001b[39mPopen(command, stdout\u001b[38;5;241m=\u001b[39mlog_f, stderr\u001b[38;5;241m=\u001b[39mlog_f, shell\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 126\u001b[0m process\u001b[38;5;241m.\u001b[39mwait()\n\u001b[0;32m--> 127\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m (\n\u001b[1;32m 128\u001b[0m process\u001b[38;5;241m.\u001b[39mreturncode \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 129\u001b[0m ), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to build plugin under \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m, check codegen.log for detail\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 130\u001b[0m os\u001b[38;5;241m.\u001b[39mgetcwd()\n\u001b[1;32m 131\u001b[0m )\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_libs\u001b[38;5;241m.\u001b[39mextend([os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mbasename(l) \u001b[38;5;28;01mfor\u001b[39;00m l \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lib_folder\u001b[38;5;241m.\u001b[39mlistdir()])\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lib_folder\u001b[38;5;241m.\u001b[39mlistdir(as_abs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[0;31mAssertionError\u001b[0m: Failed to build plugin under /media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc/plugin/msc_plugin/install/source_torch/build, check codegen.log for detail" + ] + } + ], + "source": [ + "frameworks = [MSCFramework.TORCH, MSCFramework.TVM]\n", + "if tvm.get_global_func(\"relax.ext.tensorrt\", True) is not None:\n", + " frameworks.append(MSCFramework.TENSORRT)\n", + "plugin_root = msc_utils.msc_dir(\"msc_plugin\")\n", + "managers = _build_plugin(frameworks, plugin_root)\n", + "\n", + "# test the plugin load\n", + "_test_tvm_plugin(managers[MSCFramework.TVM], \"llvm\")\n", + "if tvm.cuda().exist:\n", + " _test_tvm_plugin(managers[MSCFramework.TVM], \"cuda\")\n", + "_test_torch_plugin(managers[MSCFramework.TORCH])\n", + "\n", + "# test the plugin with manager\n", + "model_info = {\n", + " \"inputs\": [\n", + " {\"name\": \"input_0\", \"shape\": [1, 3, 224, 224], \"dtype\": \"float32\", \"layout\": \"NCHW\"}\n", + " ],\n", + " \"outputs\": [\n", + " {\"name\": \"output\", \"shape\": [1, 6, 218, 218], \"dtype\": \"float32\", \"layout\": \"NCHW\"}\n", + " ],\n", + " \"nodes\": {\"total\": 4, \"input\": 1, \"msc.conv2d_bias\": 1, \"MyRelu\": 1, \"nn.max_pool2d\": 1},\n", + "}\n", + "_test_with_manager(managers, MSCFramework.TORCH, model_info)\n", + "_test_with_manager(managers, MSCFramework.TVM, model_info)\n", + "if tvm.get_global_func(\"relax.ext.tensorrt\", True) is not None:\n", + " byoc_info = {\n", + " \"inputs\": [\n", + " {\"name\": \"input_0\", \"shape\": [1, 3, 224, 224], \"dtype\": \"float32\", \"layout\": \"NCHW\"}\n", + " ],\n", + " \"outputs\": [\n", + " {\"name\": \"output\", \"shape\": [1, 6, 218, 218], \"dtype\": \"float32\", \"layout\": \"\"}\n", + " ],\n", + " \"nodes\": {\"total\": 2, \"input\": 1, \"msc_tensorrt\": 1},\n", + " }\n", + " _test_with_manager(managers, MSCFramework.TENSORRT, byoc_info)\n", + "\n", + "plugin_root.destory()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "xxx", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/tutorials/msc/plugin/utils.py b/doc/tutorials/msc/plugin/utils.py new file mode 100644 index 00000000..11484f00 --- /dev/null +++ b/doc/tutorials/msc/plugin/utils.py @@ -0,0 +1,301 @@ +import numpy as np + +import torch +from torch import nn + +import tvm.testing +from tvm import relax +from tvm.relax.transform import BindParams +from tvm.script import relax as R +from tvm.contrib.msc.pipeline import MSCManager +from tvm.contrib.msc.plugin import build_plugins +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + + +def _get_externs_header(): + """Get the header source for externs""" + + return """#ifndef EXTERNS_H_ +#define EXTERNS_H_ + +#include "plugin_base.h" + +#ifdef PLUGIN_ENABLE_CUDA +#include +#endif + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +template +std::vector 
my_relu_infer(const std::vector& inputs, const TAttr& attrs, + bool is_runtime) { + std::vector outputs; + outputs.push_back(MetaTensor(inputs[0].shape(), inputs[0].data_type(), inputs[0].layout())); + return outputs; +} + +template +void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, T max_val); + +template +void my_relu_cpu_compute(const DataTensor& input, DataTensor& output, const TAttr& attrs) { + my_relu_cpu_kernel(input, output, T(attrs.max_val)); +} + +#ifdef PLUGIN_ENABLE_CUDA +template +void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, T max_val, + const cudaStream_t& stream); + +template +void my_relu_cuda_compute(const DataTensor& input, DataTensor& output, const TAttr& attrs, + const cudaStream_t& stream) { + my_relu_cuda_kernel(input, output, T(attrs.max_val), stream); +} +#endif + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // EXTERNS_H_ +""" + + +def _get_externs_cc(): + """Get externs cc source""" + return """#include "externs.h" + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +template +void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, T max_val) { + const T* input_data = input.const_data(); + T* output_data = output.data(); + for (size_t i = 0; i < output.size(); i++) { + if (input_data[i] >= max_val) { + output_data[i] = max_val; + } else if (input_data[i] <= 0) { + output_data[i] = 0; + } else { + output_data[i] = input_data[i]; + } + } +} + +template void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, + float max_val); + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +""" + + +def _get_externs_cu(): + """Get externs cu source""" + + return """#include "externs.h" + +#define CU1DBLOCK 256 +#define KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) + +namespace tvm { +namespace contrib { +namespace msc { +namespace plugin { + +inline int n_blocks(int size, int block_size) { + return size / block_size + (size % block_size == 0 ? 
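The CPU loop and the CUDA kernel both implement the same clipped ReLU: clamp the input to [0, max_val]. As a quick oracle when checking the plugin output, the whole computation is one numpy call (a sketch of the mathematical spec only, not the generated binding):

```python
import numpy as np

def my_relu_reference(x: np.ndarray, max_val: float = 0.5) -> np.ndarray:
    # Same branches as the C++/CUDA kernels: <= 0 -> 0, >= max_val -> max_val.
    return np.clip(x, 0.0, max_val)

x = np.random.randn(1, 6, 218, 218).astype("float32")
y = my_relu_reference(x)
assert y.min() >= 0 and y.max() <= 0.5   # matches the notebook's own checks
```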
0 : 1); +} + +template +__global__ static void _my_relu(const T* src, T* dst, T max_val, int n) { + KERNEL_LOOP(i, n) { + if (src[i] >= max_val) { + dst[i] = max_val; + } else if (src[i] <= 0) { + dst[i] = 0; + } else { + dst[i] = src[i]; + } + } +} + +template +void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, T max_val, + const cudaStream_t& stream) { + const T* input_data = input.const_data(); + T* output_data = output.data(); + dim3 Bl(CU1DBLOCK); + dim3 Gr(n_blocks(output.size(), CU1DBLOCK)); + _my_relu<<>>(input_data, output_data, max_val, output.size()); +} + +template void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, + float max_val, const cudaStream_t& stream); + +} // namespace plugin +} // namespace msc +} // namespace contrib +} // namespace tvm +""" + +def _create_plugin(externs_dir): + """Create sources under source folder""" + with open(externs_dir.relpath("externs.h"), "w") as f: + f.write(_get_externs_header()) + with open(externs_dir.relpath("externs.cc"), "w") as f: + f.write(_get_externs_cc()) + with open(externs_dir.relpath("externs.cu"), "w") as f: + f.write(_get_externs_cu()) + return { + "MyRelu": { + "inputs": [{"name": "input", "dtype": "T"}], + "outputs": [{"name": "output", "dtype": "T"}], + "attrs": [{"name": "max_val", "type": "float"}], + "support_dtypes": {"T": ["float"]}, + "externs": { + "infer_output": {"name": "my_relu_infer", "header": "externs.h"}, + "cpu_compute": { + "name": "my_relu_cpu_compute", + "header": "externs.h", + "source": "externs.cc", + }, + "cuda_compute": { + "name": "my_relu_cuda_compute", + "header": "externs.h", + "source": "externs.cu", + }, + }, + } + } + + +def _get_torch_model(torch_manager): + """Build model with plugin""" + + class MyModel(nn.Module): + """Test model with plugin""" + + def __init__(self): + super(MyModel, self).__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + self.relu = torch_manager.MyRelu(max_val=0.5) + self.maxpool = nn.MaxPool2d(kernel_size=[1, 1]) + + def forward(self, data): + data = self.conv(data) + data = self.relu(data) + return self.maxpool(data) + + return MyModel() + + +def _get_tvm_model(tvm_manager): + """Build model with plugin""" + + block_builder = relax.BlockBuilder() + weights = np.random.rand(6, 3, 7, 7).astype("float32") + data = relax.Var("data", R.Tensor((1, 3, 224, 224), "float32")) + weight = relax.Var("weight", R.Tensor(weights.shape, weights.dtype.name)) + inputs = [data, weight] + with block_builder.function(name="main", params=inputs.copy()): + with block_builder.dataflow(): + data = relax.op.nn.conv2d(data, weight) + data = block_builder.emit(data, "conv2d") + data = tvm_manager.MyRelu(data, max_val=0.5) + data = block_builder.emit(data, "relu") + data = relax.op.nn.max_pool2d(data) + data = block_builder.emit(data, "max_pool2d") + data = block_builder.emit_output(data) + block_builder.emit_func_output(data) + mod = block_builder.finalize() + return BindParams("main", {"weight": tvm.nd.array(weights)})(mod) + + +def _build_plugin(frameworks, plugin_root): + externs_dir = plugin_root.create_dir("externs") + install_dir = plugin_root.create_dir("install") + plugin = _create_plugin(externs_dir) + managers = build_plugins(plugin, frameworks, install_dir, externs_dir=externs_dir) + return managers + + +def _run_relax(relax_mod, target_name, data): + target = tvm.target.Target(target_name) + relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod) + if target_name == "cuda": + with target: + relax_mod = 
tvm.tir.transform.DefaultGPUSchedule()(relax_mod) + device = tvm.cuda() + else: + device = tvm.cpu() + with tvm.transform.PassContext(opt_level=3): + relax_exec = tvm.relax.build(relax_mod, target) + runnable = tvm.relax.VirtualMachine(relax_exec, device) + data = tvm.nd.array(data, device) + return runnable["main"](data).asnumpy() + + +def _test_torch_plugin(manager): + """Test plugin in torch""" + + model = _get_torch_model(manager) + torch_data = torch.from_numpy(np.random.rand(1, 3, 224, 224).astype("float32")) + if torch.cuda.is_available(): + model = model.to(torch.device("cuda:0")) + torch_data = torch_data.to(torch.device("cuda:0")) + outputs = model(torch_data) + assert outputs.min() >= 0 and outputs.max() <= 0.5 + + +def _test_with_manager(plugins, compile_type, expected_info): + """Test the plugin with manager""" + + path = "test_plugin_" + compile_type + model = _get_torch_model(plugins[MSCFramework.TORCH]) + if torch.cuda.is_available(): + model = model.to(torch.device("cuda:0")) + config = { + "workspace": msc_utils.msc_dir(path), + "model_type": MSCFramework.TORCH, + "verbose": "critical", + "inputs": [["input_0", [1, 3, 224, 224], "float32"]], + "outputs": ["output"], + "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}}, + "prepare": {"profile": {"benchmark": {"repeat": 10}}}, + "baseline": { + "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}}, + }, + "compile": { + "run_type": compile_type, + "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}}, + }, + } + manager = MSCManager(model, config, plugins=plugins) + report = manager.run_pipe() + model_info = manager.get_runtime().model_info + manager.destory() + assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type) + assert msc_utils.dict_equal( + model_info, expected_info + ), "Model info {} mismatch with expected {}".format(model_info, expected_info) + +def _test_tvm_plugin(manager, target): + """Test plugin in tvm""" + + model = _get_tvm_model(manager) + data = np.random.rand(1, 3, 224, 224).astype("float32") + outputs = _run_relax(model, target, data) + assert outputs.min() >= 0 and outputs.max() <= 0.5 diff --git a/doc/tutorials/msc/set_env.py b/doc/tutorials/msc/set_env.py index 69d740ba..59faec36 100644 --- a/doc/tutorials/msc/set_env.py +++ b/doc/tutorials/msc/set_env.py @@ -1,6 +1,6 @@ import sys from pathlib import Path -root_dir = Path("__file__").resolve().parents[3] +root_dir = Path(__file__).resolve().parents[3] sys.path.extend([ f"{root_dir}/tests" ]) diff --git a/doc/tutorials/msc/tests/index.md b/doc/tutorials/msc/tests/index.md new file mode 100644 index 00000000..125d773c --- /dev/null +++ b/doc/tutorials/msc/tests/index.md @@ -0,0 +1,7 @@ +# Tests + +```{toctree} +:glob: + +* +``` diff --git a/doc/test/tutorials/msc/transform.ipynb b/doc/tutorials/msc/tests/transform.ipynb similarity index 97% rename from doc/test/tutorials/msc/transform.ipynb rename to doc/tutorials/msc/tests/transform.ipynb index a3681e28..3380ed37 100644 --- a/doc/test/tutorials/msc/transform.ipynb +++ b/doc/tutorials/msc/tests/transform.ipynb @@ -182,15 +182,6 @@ "mod = msc_transform.SetExprName()(mod)\n", "RelaxNameChecker().check(mod)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from torchvision.models import vgg" - ] - } ], "metadata": {
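The set_env.py fix above is easy to miss: `Path("__file__")` builds a path from the literal string `"__file__"` and resolves it against the current working directory, while `Path(__file__)` uses the module's real location, so `parents[3]` lands on the intended project root. A minimal demonstration of the difference (run it from any directory):

```python
from pathlib import Path

# Wrong: a relative path literally named "__file__", resolved against os.getcwd().
wrong = Path("__file__").resolve()

# Right: the actual location of the current script.
# (__file__ is only defined when this runs from a file, not in a bare REPL.)
right = Path(__file__).resolve()

print(wrong)  # e.g. /current/working/dir/__file__
print(right)  # e.g. /path/to/this/script.py
```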
diff --git a/doc/test/tutorials/msc/translate-relax.ipynb b/doc/tutorials/msc/tests/translate-relax.ipynb similarity index 100% rename from doc/test/tutorials/msc/translate-relax.ipynb rename to doc/tutorials/msc/tests/translate-relax.ipynb diff --git a/doc/test/tutorials/msc/translate-relay.ipynb b/doc/tutorials/msc/tests/translate-relay.ipynb similarity index 100% rename from doc/test/tutorials/msc/translate-relay.ipynb rename to doc/tutorials/msc/tests/translate-relay.ipynb diff --git a/doc/test/tutorials/msc/translate-tensorflow.ipynb b/doc/tutorials/msc/tests/translate-tensorflow.ipynb similarity index 100% rename from doc/test/tutorials/msc/translate-tensorflow.ipynb rename to doc/tutorials/msc/tests/translate-tensorflow.ipynb diff --git a/doc/test/tutorials/msc/translate-tensorrt.ipynb b/doc/tutorials/msc/tests/translate-tensorrt.ipynb similarity index 100% rename from doc/test/tutorials/msc/translate-tensorrt.ipynb rename to doc/tutorials/msc/tests/translate-tensorrt.ipynb diff --git a/doc/test/tutorials/msc/translate-torch.ipynb b/doc/tutorials/msc/tests/translate-torch.ipynb similarity index 100% rename from doc/test/tutorials/msc/translate-torch.ipynb rename to doc/tutorials/msc/tests/translate-torch.ipynb diff --git a/doc/test/tutorials/msc/translate.ipynb b/doc/tutorials/msc/tests/translate.ipynb similarity index 100% rename from doc/test/tutorials/msc/translate.ipynb rename to doc/tutorials/msc/tests/translate.ipynb diff --git a/doc/tutorials/msc/tools/index.md b/doc/tutorials/msc/tools/index.md index 7e858fc3..f850b81a 100644 --- a/doc/tutorials/msc/tools/index.md +++ b/doc/tutorials/msc/tools/index.md @@ -2,6 +2,8 @@ ```{toctree} intro -MSCProcessor +pruner +quantizer +tracker test ``` diff --git a/doc/tutorials/msc/tools/intro.md b/doc/tutorials/msc/tools/intro.md index d42a67c8..715b1766 100644 --- a/doc/tutorials/msc/tools/intro.md +++ b/doc/tutorials/msc/tools/intro.md @@ -2,23 +2,14 @@ Reference: [【我与TVM二三事 后篇(4)】MSC之MSCTools](https://zhuanlan.zhihu.com/p/680796444) -MSCTools is one of the core designs unique to the MSC architecture, and the most important point distinguishing MSC from traditional AI compilers and compression toolchains. Its design went through several iterations: from a toolchain initially coupled to torch, to feature modules attached to TensorRT, and finally to a fully independent system decoupled from training and deployment frameworks; all along the way, hard lessons kept pushing me to split and decouple the model-compression logic. +MSCTools is one of the core designs unique to the MSC architecture, and the most important point distinguishing MSC from traditional AI compilers and compression toolchains. Its design went through several iterations: from a toolchain initially coupled to torch, to feature modules attached to TensorRT, and finally to a fully independent system decoupled from training and deployment frameworks; all along the way, hard lessons kept pushing the splitting and decoupling of the model-compression logic. -On the one hand, a compression algorithm coupled too deeply to a training framework loses awareness of the hardware. For example, as the ecosystem around a piece of hardware iterates, its graph-optimization and operator-fusion strategies keep changing, which affects the quantization strategy; quantization developed inside the training framework then has to keep reverse-engineering hardware behavior, and as the hardware choices multiply, the maintenance cost of this approach rises quickly. +On the one hand, a compression algorithm coupled too deeply to a training framework loses awareness of the hardware. For example, as the ecosystem around a piece of hardware iterates, its computation-graph-optimization and operator-fusion strategies keep changing, which affects the quantization strategy; quantization developed inside the training framework then has to keep reverse-engineering hardware behavior, and as the hardware choices multiply, the maintenance cost of this approach rises quickly. -On the other hand, a compression tool coupled too deeply to a deployment framework gives up the ability to train: it is basically limited to the various PTQ approaches and has to forgo techniques such as sparsification, pruning, and distillation, which compress better but require training. Since the better compression techniques generally need training in the loop, the ceiling of a deployment-coupled compression tool is not high. +On the other hand, a compression tool coupled too deeply to a deployment framework gives up the ability to train: it is basically limited to the various PTQ approaches and has to forgo techniques such as sparsification, pruning, and distillation, which compress better but require training. Since the better compression techniques generally need training in the loop, the ceiling of a deployment-coupled compression tool is not high. After stepping into all of these pits, I chose to abstract the compression tools out when designing MSC. This 1. makes it easy to develop new compression tools and to integrate them into new training/deployment frameworks, and 2. manages and orchestrates all compression tools in one place, so different compression algorithms can cooperate, e.g. pruning + quantization + distillation. The corresponding cost of developing an MSCTool is that the basic data an algorithm works with becomes the MSCGraph abstraction, rather than a concrete framework graph such as torch.Module or tvm.IRModule. -## MSCTool workflow -MSCTool and MSCRunner work together to compress a model; the basic flow is as follows: -![](../images/tools.jpg) - -### make plan -If the plan file named in the MSCTool config does not exist, a runnable object for producing the plan is first built through MSCRunner's build method; during the build, MSCTool instruments the model according to the strategys config, e.g. with data-collection nodes (quantization) and weight-channel analysis nodes (pruning, sparsification). MSCRunner then loads the dataset and runs forward passes, and the nodes MSCTool inserted analyze the data and produce the plan. - -### apply plan -If MSCTool finds a plan file, or one was produced by the make-plan step above, the tensors are rewritten according to the plan while MSCRunner builds the runnable object, e.g. a tensor gets quantize/dequantize ops applied (quantization), or weights get trimmed (pruning). - -MSCTool splits compression into two phases mainly to cooperate with model-delivery systems: the make-plan phase is usually resource-hungry and may depend on time-consuming means such as distillation or training, so its result is saved to the plan at the end of the phase, which acts as a checkpoint. A delivery system can load the plan directly and skip the expensive make-plan phase. \ No newline at end of file
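The two-phase contract described above is visible directly in the tool configs used throughout these notebooks: `plan_file` is the checkpoint boundary. A hedged sketch of how a delivery system might exploit it (the file name and strategy layout follow the pruner example below; the skip logic itself is an illustration, not MSC API):

```python
from pathlib import Path

plan_file = Path("msc_pruner.json")
config = {
    "plan_file": plan_file.name,
    "strategys": [{"methods": {"weights": {"method_name": "per_channel", "density": 0.8}}}],
}

if plan_file.exists():
    # "apply plan": the expensive gather/analyze phase is skipped entirely;
    # the runner only rewrites tensors/weights according to the saved plan.
    print("reusing plan checkpoint:", plan_file)
else:
    # "make plan": instrumented build + dataset forward passes produce the plan,
    # which is saved so later runs (or a delivery system) can skip this step.
    print("no plan found, will run the costly make-plan stage")
```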
diff --git a/doc/tutorials/msc/tools/pruner.ipynb b/doc/tutorials/msc/tools/pruner.ipynb new file mode 100644 index 00000000..5144db93 --- /dev/null +++ b/doc/tutorials/msc/tools/pruner.ipynb @@ -0,0 +1,144 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pruning" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import set_env\n", + "from pathlib import Path\n", + "\n", + "temp_dir = Path(\".temp\")\n", + "temp_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from tvm.contrib.msc.core.tools import ToolType\n", + "config = {\n", + " \"plan_file\": \"msc_pruner.json\",\n", + " \"strategys\": [\n", + " {\n", + " \"methods\": {\n", + " \"weights\": {\"method_name\": \"per_channel\", \"density\": 0.8},\n", + " \"output\": {\"method_name\": \"per_channel\", \"density\": 0.8},\n", + " }\n", + " }\n", + " ],\n", + "}\n", + "tools = [{\"tool_type\": ToolType.PRUNER, \"tool_config\": config}]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n", + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n" + ] + } + ], + "source": [ + "from tvm.contrib.msc.core.utils.namespace import MSCFramework\n", + "from utils import get_model_info, _test_from_torch\n", + "\n", + "_test_from_torch(\n", + " MSCFramework.TVM, tools, \n", + " get_model_info(MSCFramework.TVM), \n", + " temp_dir,\n", + " training=False\n", + ")" + ] + },
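In the strategy above, `density: 0.8` asks the pruner to keep roughly 80% of the channels per tensor. One plausible realization of such a `per_channel` criterion, ranking output channels by L2 norm and dropping the weakest 20% (an illustration of the idea, not the MSC implementation):

```python
import numpy as np

def per_channel_keep_mask(weight: np.ndarray, density: float = 0.8) -> np.ndarray:
    """Weight laid out as OIHW: rank output channels by L2 norm, keep top `density`."""
    norms = np.sqrt((weight.reshape(weight.shape[0], -1) ** 2).sum(axis=1))
    n_keep = max(1, int(round(density * weight.shape[0])))
    keep = np.zeros(weight.shape[0], dtype=bool)
    keep[np.argsort(norms)[-n_keep:]] = True
    return keep

w = np.random.randn(64, 3, 7, 7).astype("float32")
mask = per_channel_keep_mask(w, density=0.8)
assert mask.sum() == 51  # ~80% of 64 output channels survive
```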
{ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Distillation" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "config = {\n", + " \"plan_file\": \"msc_distiller.json\",\n", + " \"strategys\": [\n", + " {\n", + " \"methods\": {\"mark\": \"loss_lp_norm\"},\n", + " \"marks\": [\"loss\"],\n", + " },\n", + " ],\n", + "}\n", + "tools.append({\"tool_type\": ToolType.DISTILLER, \"tool_config\": config})" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n" + ] + } + ], + "source": [ + "from utils import get_model_info, _test_from_torch\n", + "\n", + "_test_from_torch(\n", + " MSCFramework.TVM, tools, \n", + " get_model_info(MSCFramework.TVM), \n", + " temp_dir,\n", + " training=False\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "xxx", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
diff --git a/doc/tutorials/msc/tools/quantizer.ipynb b/doc/tutorials/msc/tools/quantizer.ipynb new file mode 100644 index 00000000..f1a5870d --- /dev/null +++ b/doc/tutorials/msc/tools/quantizer.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quantization" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import set_env\n", + "from pathlib import Path\n", + "\n", + "temp_dir = Path(\".temp\")\n", + "temp_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from tvm.contrib.msc.core.tools import ToolType\n", + "# pylint: disable=import-outside-toplevel\n", + "from tvm.contrib.msc.core.tools.quantize import QuantizeStage\n", + "from tvm.contrib.msc.core.utils.namespace import MSCFramework\n", + "\n", + "run_type = MSCFramework.MSC\n", + "if run_type == MSCFramework.TENSORRT:\n", + " config = {\"plan_file\": \"msc_quantizer.json\", \"strategys\": []}\n", + "else:\n", + " op_types = [\"nn.conv2d\", \"msc.conv2d_bias\", \"msc.linear\", \"msc.linear_bias\"]\n", + " config = {\n", + " \"plan_file\": \"msc_quantizer.json\",\n", + " \"strategys\": [\n", + " {\n", + " \"methods\": {\n", + " \"input\": \"gather_maxmin\",\n", + " \"output\": \"gather_maxmin\",\n", + " \"weights\": \"gather_max_per_channel\",\n", + " },\n", + " \"op_types\": op_types,\n", + " \"stages\": [QuantizeStage.GATHER],\n", + " },\n", + " {\n", + " \"methods\": {\"input\": \"calibrate_maxmin\", \"output\": \"calibrate_maxmin\"},\n", + " \"op_types\": op_types,\n", + " \"stages\": [QuantizeStage.CALIBRATE],\n", + " },\n", + " {\n", + " \"methods\": {\n", + " \"input\": \"quantize_normal\",\n", + " \"weights\": \"quantize_normal\",\n", + " \"output\": \"dequantize_normal\",\n", + " },\n", + " \"op_types\": op_types,\n", + " },\n", + " ],\n", + " }\n", + "tools = [{\"tool_type\": ToolType.QUANTIZER, \"tool_config\": config}]\n" + ] + },
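The three strategy groups above map onto the quantizer's stages: `gather` collects per-tensor statistics over calibration batches, `calibrate` turns them into ranges, and the unstaged group installs the actual quantize/dequantize ops. A compact sketch of the maxmin idea for a symmetric int8 scheme (illustrative math only; method names like `gather_maxmin` refer to MSC's built-ins):

```python
import numpy as np

class MaxMinObserver:
    """Gather stage: track the running min/max of a tensor over batches."""
    def __init__(self):
        self.lo, self.hi = np.inf, -np.inf

    def gather(self, x: np.ndarray):
        self.lo = min(self.lo, float(x.min()))
        self.hi = max(self.hi, float(x.max()))

    def calibrate(self) -> float:
        # Symmetric scale so that [-amax, amax] maps onto int8's [-127, 127].
        return max(abs(self.lo), abs(self.hi)) / 127.0

obs = MaxMinObserver()
for _ in range(5):                      # five calibration batches, as in the config
    obs.gather(np.random.randn(1, 3, 224, 224).astype("float32"))
scale = obs.calibrate()

x = np.random.randn(8).astype("float32")
q = np.clip(np.round(x / scale), -127, 127)      # quantize (illustrative)
deq = q * scale                                  # dequantize (illustrative)
assert np.allclose(deq, np.clip(x, -127 * scale, 127 * scale), atol=scale)
```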
{ + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n", + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n" + ] + } + ], + "source": [ + "from utils import get_model_info, _test_from_torch\n", + "\n", + "_test_from_torch(\n", + " MSCFramework.TVM, tools, \n", + " get_model_info(MSCFramework.TVM), \n", + " temp_dir,\n", + " training=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Distillation" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "config = {\n", + " \"plan_file\": \"msc_distiller.json\",\n", + " \"strategys\": [\n", + " {\n", + " \"methods\": {\"mark\": \"loss_lp_norm\"},\n", + " \"marks\": [\"loss\"],\n", + " },\n", + " ],\n", + "}\n", + "tools.append({\"tool_type\": ToolType.DISTILLER, \"tool_config\": config})" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n" + ] + } + ], + "source": [ + "from utils import get_model_info, _test_from_torch\n", + "\n", + "_test_from_torch(\n", + " MSCFramework.TVM, tools, \n", + " get_model_info(MSCFramework.TVM), \n", + " temp_dir,\n", + " training=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ai", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/tutorials/msc/tools/set_env.py b/doc/tutorials/msc/tools/set_env.py new file mode 100644 index 00000000..798095c7 --- /dev/null +++ b/doc/tutorials/msc/tools/set_env.py @@ -0,0 +1,7 @@ +import sys +from pathlib import Path +root_dir = Path(__file__).resolve().parents[4] +sys.path.extend([ + f"{root_dir}/tests" +]) +import env diff --git a/doc/tutorials/msc/tools/test.ipynb b/doc/tutorials/msc/tools/test.ipynb index d8eceed9..b520f16e 100644 --- a/doc/tutorials/msc/tools/test.ipynb +++ b/doc/tutorials/msc/tools/test.ipynb @@ -7,6 +7,53 @@ "# MSC 工具测试" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc\n" + ] + } + ], + "source": [ + "%cd ..\n", + "import set_env" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from tvm.contrib.msc.core.tools import ToolType" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'pruner'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ToolType.PRUNER" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -67,324 +114,6 @@ "```\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Licensed to the Apache Software Foundation (ASF) under one\n", - "# or more contributor license agreements. See the NOTICE file\n", - "# distributed with this work for additional information\n", - "# regarding copyright ownership. The ASF licenses this file\n", - "# to you under the Apache License, Version 2.0 (the\n", - "# \"License\"); you may not use this file except in compliance\n", - "# with the License. You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing,\n", - "# software distributed under the License is distributed on an\n", - "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", - "# KIND, either express or implied. See the License for the\n", - "# specific language governing permissions and limitations\n", - "# under the License.\n", - "\n", - "\"\"\" Test Tools in MSC. 
\"\"\"\n", - "\n", - "import json\n", - "import pytest\n", - "import torch\n", - "\n", - "import tvm.testing\n", - "from tvm.contrib.msc.pipeline import MSCManager\n", - "from tvm.contrib.msc.core.tools import ToolType\n", - "from tvm.contrib.msc.core.utils.namespace import MSCFramework\n", - "from tvm.contrib.msc.core import utils as msc_utils\n", - "\n", - "requires_tensorrt = pytest.mark.skipif(\n", - " tvm.get_global_func(\"relax.ext.tensorrt\", True) is None,\n", - " reason=\"TENSORRT is not enabled\",\n", - ")\n", - "\n", - "\n", - "def _get_config(\n", - " model_type,\n", - " compile_type,\n", - " tools,\n", - " inputs,\n", - " outputs,\n", - " atol=1e-2,\n", - " rtol=1e-2,\n", - " optimize_type=None,\n", - "):\n", - " \"\"\"Get msc config\"\"\"\n", - "\n", - " path = \"_\".join([\"test_tools\", model_type, compile_type] + [t[\"tool_type\"] for t in tools])\n", - " return {\n", - " \"workspace\": msc_utils.msc_dir(path),\n", - " \"verbose\": \"critical\",\n", - " \"model_type\": model_type,\n", - " \"inputs\": inputs,\n", - " \"outputs\": outputs,\n", - " \"dataset\": {\"prepare\": {\"loader\": \"from_random\", \"max_iter\": 5}},\n", - " \"tools\": tools,\n", - " \"prepare\": {\"profile\": {\"benchmark\": {\"repeat\": 10}}},\n", - " \"baseline\": {\n", - " \"run_type\": model_type,\n", - " \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n", - " },\n", - " \"optimize\": {\n", - " \"run_type\": optimize_type or model_type,\n", - " \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n", - " },\n", - " \"compile\": {\n", - " \"run_type\": compile_type,\n", - " \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n", - " },\n", - " }\n", - "\n", - "\n", - "def get_tools(tool_type, use_distill=False, run_type=MSCFramework.MSC):\n", - " \"\"\"Get config for the tool\"\"\"\n", - "\n", - " tools = []\n", - " if tool_type == ToolType.PRUNER:\n", - " config = {\n", - " \"plan_file\": \"msc_pruner.json\",\n", - " \"strategys\": [\n", - " {\n", - " \"methods\": {\n", - " \"weights\": {\"method_name\": \"per_channel\", \"density\": 0.8},\n", - " \"output\": {\"method_name\": \"per_channel\", \"density\": 0.8},\n", - " }\n", - " }\n", - " ],\n", - " }\n", - " tools.append({\"tool_type\": ToolType.PRUNER, \"tool_config\": config})\n", - " elif tool_type == ToolType.QUANTIZER:\n", - " # pylint: disable=import-outside-toplevel\n", - " from tvm.contrib.msc.core.tools.quantize import QuantizeStage\n", - "\n", - " if run_type == MSCFramework.TENSORRT:\n", - " config = {\"plan_file\": \"msc_quantizer.json\", \"strategys\": []}\n", - " else:\n", - " op_types = [\"nn.conv2d\", \"msc.conv2d_bias\", \"msc.linear\", \"msc.linear_bias\"]\n", - " config = {\n", - " \"plan_file\": \"msc_quantizer.json\",\n", - " \"strategys\": [\n", - " {\n", - " \"methods\": {\n", - " \"input\": \"gather_maxmin\",\n", - " \"output\": \"gather_maxmin\",\n", - " \"weights\": \"gather_max_per_channel\",\n", - " },\n", - " \"op_types\": op_types,\n", - " \"stages\": [QuantizeStage.GATHER],\n", - " },\n", - " {\n", - " \"methods\": {\"input\": \"calibrate_maxmin\", \"output\": \"calibrate_maxmin\"},\n", - " \"op_types\": op_types,\n", - " \"stages\": [QuantizeStage.CALIBRATE],\n", - " },\n", - " {\n", - " \"methods\": {\n", - " \"input\": \"quantize_normal\",\n", - " \"weights\": \"quantize_normal\",\n", - " \"output\": \"dequantize_normal\",\n", - " },\n", - " \"op_types\": op_types,\n", - " },\n", - " 
],\n", - " }\n", - " tools.append({\"tool_type\": ToolType.QUANTIZER, \"tool_config\": config})\n", - " elif tool_type == ToolType.TRACKER:\n", - " # pylint: disable=import-outside-toplevel\n", - " from tvm.contrib.msc.core.utils import MSCStage\n", - "\n", - " config = {\n", - " \"plan_file\": \"msc_tracker.json\",\n", - " \"strategys\": [\n", - " {\n", - " \"methods\": {\n", - " \"output\": {\n", - " \"method_name\": \"save_compared\",\n", - " \"compare_to\": {\n", - " MSCStage.OPTIMIZE: [MSCStage.BASELINE],\n", - " MSCStage.COMPILE: [MSCStage.OPTIMIZE, MSCStage.BASELINE],\n", - " },\n", - " }\n", - " },\n", - " \"op_types\": [\"nn.relu\"],\n", - " }\n", - " ],\n", - " }\n", - " tools.append({\"tool_type\": ToolType.TRACKER, \"tool_config\": config})\n", - " if use_distill:\n", - " config = {\n", - " \"plan_file\": \"msc_distiller.json\",\n", - " \"strategys\": [\n", - " {\n", - " \"methods\": {\"mark\": \"loss_lp_norm\"},\n", - " \"marks\": [\"loss\"],\n", - " },\n", - " ],\n", - " }\n", - " tools.append({\"tool_type\": ToolType.DISTILLER, \"tool_config\": config})\n", - " return tools\n", - "\n", - "\n", - "def _get_torch_model(name, training=False):\n", - " \"\"\"Get model from torch vision\"\"\"\n", - "\n", - " # pylint: disable=import-outside-toplevel\n", - " try:\n", - " import torchvision\n", - "\n", - " model = getattr(torchvision.models, name)()\n", - " if training:\n", - " model = model.train()\n", - " else:\n", - " model = model.eval()\n", - " return model\n", - " except: # pylint: disable=bare-except\n", - " print(\"please install torchvision package\")\n", - " return None\n", - "\n", - "\n", - "def _check_manager(manager, expected_info):\n", - " \"\"\"Check the manager results\"\"\"\n", - "\n", - " model_info = manager.get_runtime().model_info\n", - " passed, err = True, \"\"\n", - " if not manager.report[\"success\"]:\n", - " passed = False\n", - " err = \"Failed to run pipe for {} -> {}\".format(manager.model_type, manager.compile_type)\n", - " if not msc_utils.dict_equal(model_info, expected_info):\n", - " passed = False\n", - " err = \"Model info {} mismatch with expected {}\".format(model_info, expected_info)\n", - " manager.destory()\n", - " if not passed:\n", - " raise Exception(\"{}\\nReport:{}\".format(err, json.dumps(manager.report, indent=2)))\n", - "\n", - "\n", - "def _test_from_torch(\n", - " compile_type,\n", - " tools,\n", - " expected_info,\n", - " training=False,\n", - " atol=1e-1,\n", - " rtol=1e-1,\n", - " optimize_type=None,\n", - "):\n", - " torch_model = _get_torch_model(\"resnet50\", training)\n", - " if torch_model:\n", - " if torch.cuda.is_available():\n", - " torch_model = torch_model.to(torch.device(\"cuda:0\"))\n", - " config = _get_config(\n", - " MSCFramework.TORCH,\n", - " compile_type,\n", - " tools,\n", - " inputs=[[\"input_0\", [1, 3, 224, 224], \"float32\"]],\n", - " outputs=[\"output\"],\n", - " atol=atol,\n", - " rtol=rtol,\n", - " optimize_type=optimize_type,\n", - " )\n", - " manager = MSCManager(torch_model, config)\n", - " manager.run_pipe()\n", - " _check_manager(manager, expected_info)\n", - "\n", - "\n", - "def get_model_info(compile_type):\n", - " \"\"\"Get the model info\"\"\"\n", - "\n", - " if compile_type == MSCFramework.TVM:\n", - " return {\n", - " \"inputs\": [\n", - " {\"name\": \"input_0\", \"shape\": [1, 3, 224, 224], \"dtype\": \"float32\", \"layout\": \"NCHW\"}\n", - " ],\n", - " \"outputs\": [{\"name\": \"output\", \"shape\": [1, 1000], \"dtype\": \"float32\", \"layout\": \"NC\"}],\n", - " \"nodes\": {\n", - " 
\"total\": 229,\n", - " \"input\": 1,\n", - " \"nn.conv2d\": 53,\n", - " \"nn.batch_norm\": 53,\n", - " \"get_item\": 53,\n", - " \"nn.relu\": 49,\n", - " \"nn.max_pool2d\": 1,\n", - " \"add\": 16,\n", - " \"nn.adaptive_avg_pool2d\": 1,\n", - " \"reshape\": 1,\n", - " \"msc.linear_bias\": 1,\n", - " },\n", - " }\n", - " if compile_type == MSCFramework.TENSORRT:\n", - " return {\n", - " \"inputs\": [\n", - " {\"name\": \"input_0\", \"shape\": [1, 3, 224, 224], \"dtype\": \"float32\", \"layout\": \"NCHW\"}\n", - " ],\n", - " \"outputs\": [{\"name\": \"output\", \"shape\": [1, 1000], \"dtype\": \"float32\", \"layout\": \"\"}],\n", - " \"nodes\": {\"total\": 2, \"input\": 1, \"msc_tensorrt\": 1},\n", - " }\n", - " raise TypeError(\"Unexpected compile_type \" + str(compile_type))\n", - "\n", - "\n", - "@pytest.mark.parametrize(\"tool_type\", [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER])\n", - "def test_tvm_tool(tool_type):\n", - " \"\"\"Test tools for tvm\"\"\"\n", - "\n", - " tools = get_tools(tool_type)\n", - " _test_from_torch(MSCFramework.TVM, tools, get_model_info(MSCFramework.TVM), training=False)\n", - "\n", - "\n", - "@pytest.mark.parametrize(\"tool_type\", [ToolType.PRUNER, ToolType.QUANTIZER])\n", - "def test_tvm_distill(tool_type):\n", - " \"\"\"Test tools for tvm with distiller\"\"\"\n", - "\n", - " tools = get_tools(tool_type, use_distill=True)\n", - " _test_from_torch(MSCFramework.TVM, tools, get_model_info(MSCFramework.TVM), training=False)\n", - "\n", - "\n", - "@requires_tensorrt\n", - "@pytest.mark.parametrize(\n", - " \"tool_type\",\n", - " [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER],\n", - ")\n", - "def test_tensorrt_tool(tool_type):\n", - " \"\"\"Test tools for tensorrt\"\"\"\n", - "\n", - " tools = get_tools(tool_type, run_type=MSCFramework.TENSORRT)\n", - " if tool_type == ToolType.QUANTIZER:\n", - " optimize_type = MSCFramework.TENSORRT\n", - " else:\n", - " optimize_type = None\n", - " _test_from_torch(\n", - " MSCFramework.TENSORRT,\n", - " tools,\n", - " get_model_info(MSCFramework.TENSORRT),\n", - " training=False,\n", - " atol=1e-1,\n", - " rtol=1e-1,\n", - " optimize_type=optimize_type,\n", - " )\n", - "\n", - "\n", - "@requires_tensorrt\n", - "@pytest.mark.parametrize(\"tool_type\", [ToolType.PRUNER])\n", - "def test_tensorrt_distill(tool_type):\n", - " \"\"\"Test tools for tensorrt with distiller\"\"\"\n", - "\n", - " tools = get_tools(tool_type, use_distill=True)\n", - " _test_from_torch(\n", - " MSCFramework.TENSORRT, tools, get_model_info(MSCFramework.TENSORRT), training=False\n", - " )\n", - "\n", - "\n", - "if __name__ == \"__main__\":\n", - " tvm.testing.main()\n" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/doc/tutorials/msc/tools/tracker.ipynb b/doc/tutorials/msc/tools/tracker.ipynb new file mode 100644 index 00000000..7bc36267 --- /dev/null +++ b/doc/tutorials/msc/tools/tracker.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 追踪" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import set_env\n", + "from pathlib import Path\n", + "\n", + "temp_dir = Path(\".temp\")\n", + "temp_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from tvm.contrib.msc.core.tools import ToolType\n", + "from tvm.contrib.msc.core.utils import MSCStage\n", + "\n", + "config = {\n", + " \"plan_file\": \"msc_tracker.json\",\n", 
+ " \"strategys\": [\n", + " {\n", + " \"methods\": {\n", + " \"output\": {\n", + " \"method_name\": \"save_compared\",\n", + " \"compare_to\": {\n", + " MSCStage.OPTIMIZE: [MSCStage.BASELINE],\n", + " MSCStage.COMPILE: [MSCStage.OPTIMIZE, MSCStage.BASELINE],\n", + " },\n", + " }\n", + " },\n", + " \"op_types\": [\"nn.relu\"],\n", + " }\n", + " ],\n", + "}\n", + "tools = [{\"tool_type\": ToolType.TRACKER, \"tool_config\": config}]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n", + "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
+      "  state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n"
+     ]
+    }
+   ],
+   "source": [
+    "from tvm.contrib.msc.core.utils.namespace import MSCFramework\n",
+    "from utils import get_model_info, _test_from_torch\n",
+    "\n",
+    "_test_from_torch(\n",
+    "    MSCFramework.TVM, tools,\n",
+    "    get_model_info(MSCFramework.TVM),\n",
+    "    temp_dir,\n",
+    "    training=False\n",
+    ")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "xxx",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/doc/tutorials/msc/tools/utils.py b/doc/tutorials/msc/tools/utils.py
new file mode 100644
index 00000000..73379662
--- /dev/null
+++ b/doc/tutorials/msc/tools/utils.py
@@ -0,0 +1,138 @@
+import json
+import torch
+
+# import tvm.testing
+from tvm.contrib.msc.pipeline import MSCManager
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+
+
+def _get_config(
+    model_type,
+    compile_type,
+    tools,
+    inputs,
+    outputs,
+    temp_dir,
+    atol=1e-2,
+    rtol=1e-2,
+    optimize_type=None,
+):
+    """Get msc config"""
+
+    path = "_".join(["test_tools", model_type, compile_type] + [t["tool_type"] for t in tools])
+    return {
+        "workspace": msc_utils.msc_dir(temp_dir / path, keep_history=False),
+        "verbose": "critical",
+        "model_type": model_type,
+        "inputs": inputs,
+        "outputs": outputs,
+        "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}},
+        "tools": tools,
+        "prepare": {"profile": {"benchmark": {"repeat": 10}}},
+        "baseline": {
+            "run_type": model_type,
+            "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}},
+        },
+        "optimize": {
+            "run_type": optimize_type or model_type,
+            "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}},
+        },
+        "compile": {
+            "run_type": compile_type,
+            "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}},
+        },
+    }
+
+
+def _get_torch_model(name, training=False):
+    """Get model from torch vision"""
+
+    # pylint: disable=import-outside-toplevel
+    try:
+        import torchvision
+
+        model = getattr(torchvision.models, name)()
+        if training:
+            model = model.train()
+        else:
+            model = model.eval()
+        return model
+    except:  # pylint: disable=bare-except
+        print("please install torchvision package")
+        return None
+
+
+def _check_manager(manager, expected_info):
+    """Check the manager results"""
+
+    model_info = manager.get_runtime().model_info
+    passed, err = True, ""
+    if not manager.report["success"]:
+        passed = False
+        err = "Failed to run pipe for {} -> {}".format(manager.model_type, manager.compile_type)
+    if not msc_utils.dict_equal(model_info, expected_info):
+        passed = False
+        err = "Model info {} mismatch with expected {}".format(model_info, expected_info)
+    # "destory" (sic) is the method name as spelled in TVM's MSC pipeline API
+    manager.destory()
+    if not passed:
+        raise Exception("{}\nReport:{}".format(err, json.dumps(manager.report, indent=2)))
+
+
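+# _test_from_torch ties the helpers together: build a torchvision ResNet-50,
+# assemble the MSC config (prepare -> baseline -> optimize -> compile, each
+# stage benchmarked and checked against atol/rtol), run the pipeline through
+# MSCManager, then verify the reported model info.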
+def _test_from_torch(
+    compile_type,
+    tools,
+    expected_info,
+    temp_dir,
+    training=False,
+    atol=1e-1,
+    rtol=1e-1,
+    optimize_type=None,
+):
+    torch_model = _get_torch_model("resnet50", training)
+    if torch_model:
+        if torch.cuda.is_available():
+            torch_model = torch_model.to(torch.device("cuda:0"))
+        config = _get_config(
+            MSCFramework.TORCH,
+            compile_type,
+            tools,
+            inputs=[["input_0", [1, 3, 224, 224], "float32"]],
+            outputs=["output"],
+            temp_dir=temp_dir,
+            atol=atol,
+            rtol=rtol,
+            optimize_type=optimize_type,
+        )
+        manager = MSCManager(torch_model, config)
+        manager.run_pipe()
+        _check_manager(manager, expected_info)
+
+
+def get_model_info(compile_type):
+    """Get the model info"""
+
+    if compile_type == MSCFramework.TVM:
+        return {
+            "inputs": [
+                {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"}
+            ],
+            "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}],
+            "nodes": {
+                "total": 229,
+                "input": 1,
+                "nn.conv2d": 53,
+                "nn.batch_norm": 53,
+                "get_item": 53,
+                "nn.relu": 49,
+                "nn.max_pool2d": 1,
+                "add": 16,
+                "nn.adaptive_avg_pool2d": 1,
+                "reshape": 1,
+                "msc.linear_bias": 1,
+            },
+        }
+    if compile_type == MSCFramework.TENSORRT:
+        return {
+            "inputs": [
+                {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"}
+            ],
+            "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": ""}],
+            "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1},
+        }
+    raise TypeError("Unexpected compile_type " + str(compile_type))
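+
+
+# Hypothetical usage sketch (an assumption, not part of the original helpers):
+# run the TVM pipeline without extra compression tools. Assumes this file's
+# directory is the working directory and ".temp" is a writable workspace root.
+if __name__ == "__main__":
+    from pathlib import Path
+
+    temp_dir = Path(".temp")
+    temp_dir.mkdir(exist_ok=True)
+    _test_from_torch(
+        MSCFramework.TVM,
+        tools=[],
+        expected_info=get_model_info(MSCFramework.TVM),
+        temp_dir=temp_dir,
+        training=False,
+    )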