Skip to content

Commit

Permalink
新增模板匹配文档
Browse files Browse the repository at this point in the history
  • Loading branch information
liuxinwei committed May 23, 2024
1 parent a0d12ab commit 378e3f0
Show file tree
Hide file tree
Showing 16 changed files with 3,059 additions and 223 deletions.
4 changes: 2 additions & 2 deletions doc/faqs/gcc.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ OSError: /media/pc/data/tmp/cache/conda/envs/py311/bin/../lib/libstdc++.so.6: ve
pip uninstall scipy
pip install scipy -i https://pypi.tuna.tsinghua.edu.cn/simple
```
- 解决办法二:
- 解决办法二[libstdcxx-ng](https://libraries.io/conda/libstdcxx-ng)
```bash
conda install -c anaconda libstdcxx-ng
conda install -c conda-forge libstdcxx-ng
```
- 解决办法三:
- 检查是否存在:
Expand Down
1 change: 1 addition & 0 deletions doc/faqs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
tvm-docker<https://github.com/xinetzone/tvm-docker>
gcc
caffe
cuda
```
1 change: 1 addition & 0 deletions doc/read/relay/frontend/common/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
```{toctree}
AttrConverter
tag-span
faqs
```
1 change: 1 addition & 0 deletions doc/tutorials/frontend/pytorch-tvm/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

```{toctree}
primitive
forward
quant
```
2 changes: 1 addition & 1 deletion doc/tutorials/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pass/index
roofline/index
deploy/index
auto-quantize/index
rewrite/index
pattern/index
partition/index
vta/index
msc/index
Expand Down
309 changes: 309 additions & 0 deletions doc/tutorials/pattern/dataflow/fuse.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,309 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 融合模式"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"ROOT = Path(\".\").resolve().parents[3]\n",
"sys.path.extend([f\"{ROOT}/tests\", f\"{ROOT}/src\"])\n",
"# # from tools.tag_span import _create_span, _set_span, _verify_structural_equal_with_span\n",
"from tools.torch_utils import verify_model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import tvm\n",
"from tvm import relay\n",
"from tvm.relay.build_module import bind_params_by_name\n",
"from tvm.relay.dataflow_pattern import *\n",
"from tvm.relay.testing import run_opt_pass\n",
"\n",
"# NB: 1 corresponds to the C++ enum that specifies this\n",
"# we lose the type safety due to the Python/C++ calling\n",
"# convention.\n",
"K_ELEMWISE = 0\n",
"K_BROADCAST = 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## not_fuse_multi_diamond"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Pattern\n",
"is_conv2d = is_op(\"nn.conv2d\")(wildcard(), wildcard())\n",
"path1 = is_op(\"nn.relu\")(is_conv2d)\n",
"path2 = is_op(\"nn.leaky_relu\")(is_conv2d)\n",
"diamond = is_op(\"add\")(path1, path2)\n",
"\n",
"# Expr\n",
"inp = relay.var(\"input\")\n",
"weight = relay.var(\"weight\")\n",
"conv2d = relay.op.nn.conv2d(inp, weight)\n",
"relu = relay.op.nn.relu(conv2d)\n",
"leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)\n",
"out = relu + leaky_relu\n",
"out = out + conv2d\n",
"# Check\n",
"assert not diamond.match(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BN 融合"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class BatchnormCallback(DFPatternCallback):\n",
" def __init__(self):\n",
" super(BatchnormCallback, self).__init__()\n",
" self.x = wildcard()\n",
" self.var = wildcard()\n",
" self.mean = wildcard()\n",
" self.beta = wildcard()\n",
" self.gamma = wildcard()\n",
" self.eps = is_constant()\n",
"\n",
" self.pattern = (\n",
" self.gamma * (self.x - self.mean) / is_op(\"sqrt\")(self.var + self.eps) + self.beta\n",
" )\n",
"\n",
" def callback(self, pre, post, node_map):\n",
" x = node_map[self.x][0]\n",
" var = node_map[self.var][0]\n",
" mean = node_map[self.mean][0]\n",
" beta = node_map[self.beta][0]\n",
" gamma = node_map[self.gamma][0]\n",
" eps = node_map[self.eps][0]\n",
" return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=eps.data.numpy().item())[0]\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def test_fuse_batchnorm():\n",
" x = relay.var(\"x\")\n",
" var = relay.var(\"var\")\n",
" mean = relay.var(\"mean\")\n",
" beta = relay.var(\"beta\")\n",
" gamma = relay.var(\"gamma\")\n",
"\n",
" BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n",
"\n",
" out = rewrite(BatchnormCallback(), BN)\n",
" assert tvm.ir.structural_equal(\n",
" out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n",
" )\n",
"\n",
"\n",
"def test_no_fuse_batchnorm():\n",
" x = relay.var(\"x\")\n",
" var = relay.var(\"var\")\n",
" mean = relay.var(\"mean\")\n",
" beta = relay.var(\"beta\")\n",
" gamma = relay.var(\"gamma\")\n",
"\n",
" fake_BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta\n",
"\n",
" out = rewrite(BatchnormCallback(), fake_BN)\n",
" assert tvm.ir.structural_equal(out, fake_BN)\n",
"\n",
"\n",
"def test_fuse_double_batchnorm():\n",
" x = relay.var(\"x\")\n",
" var = relay.var(\"var\")\n",
" mean = relay.var(\"mean\")\n",
" beta = relay.var(\"beta\")\n",
" gamma = relay.var(\"gamma\")\n",
"\n",
" BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n",
" BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n",
"\n",
" out = rewrite(BatchnormCallback(), BN2)\n",
"\n",
" bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n",
" bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon=1e-5)[0]\n",
"\n",
" assert tvm.ir.structural_equal(out, bn2)\n",
"\n",
"\n",
"def test_partial_fuse_double_batchnorm():\n",
" x = relay.var(\"x\")\n",
" var = relay.var(\"var\")\n",
" mean = relay.var(\"mean\")\n",
" beta = relay.var(\"beta\")\n",
" gamma = relay.var(\"gamma\")\n",
"\n",
" BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta\n",
" BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n",
"\n",
" out = rewrite(BatchnormCallback(), BN2)\n",
"\n",
" bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon=1e-5)[0]\n",
"\n",
" assert tvm.ir.structural_equal(out, bn2)\n",
"\n",
"\n",
"def test_fuse_batchnorm_commutation():\n",
" x = relay.var(\"x\")\n",
" var = relay.var(\"var\")\n",
" mean = relay.var(\"mean\")\n",
" beta = relay.var(\"beta\")\n",
" gamma = relay.var(\"gamma\")\n",
"\n",
" # commute add\n",
" BN = beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5))\n",
" out = rewrite(BatchnormCallback(), BN)\n",
" assert tvm.ir.structural_equal(\n",
" out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n",
" )\n",
"\n",
" # associate divide/multiply\n",
" BN = (gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n",
" out = rewrite(BatchnormCallback(), BN)\n",
" assert tvm.ir.structural_equal(\n",
" out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n",
" )\n",
"\n",
" # associate multiply/divide\n",
" BN = gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta\n",
" out = rewrite(BatchnormCallback(), BN)\n",
" assert tvm.ir.structural_equal(\n",
" out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n",
" )\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## quadruple_rewrite_dominator"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class DominatorRemovalCallback(DFPatternCallback):\n",
" def __init__(self):\n",
" super(DominatorRemovalCallback, self).__init__()\n",
" self.inp = wildcard()\n",
" self.weight = wildcard()\n",
" is_conv2d = is_op(\"nn.conv2d\")(self.inp, self.weight)\n",
" is_unary_elemwise = (wildcard().has_attr({\"TOpPattern\": K_ELEMWISE}))(\n",
" wildcard()\n",
" ) | is_op(\"add\")(wildcard(), wildcard())\n",
" reduction = is_op(\"add\")(wildcard(), wildcard())\n",
" self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)\n",
"\n",
" def callback(self, pre, post, node_map):\n",
" inp = node_map[self.inp][0]\n",
" weight = node_map[self.weight][0]\n",
" return relay.op.nn.conv2d(inp, weight)\n",
"\n",
"inp = relay.var(\"input\")\n",
"weight = relay.var(\"weight\")\n",
"# Classic Diamond\n",
"conv2d = relay.op.nn.conv2d(inp, weight)\n",
"relu = relay.op.nn.relu(conv2d)\n",
"relu = relay.op.nn.relu(relu)\n",
"leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)\n",
"out = relu + leaky_relu\n",
"\n",
"# Deeper Branch\n",
"conv2d = relay.op.nn.conv2d(out, weight)\n",
"relu = relay.op.nn.relu(conv2d)\n",
"relu = relay.op.nn.relu(relu)\n",
"relu = relay.op.tanh(relu)\n",
"leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)\n",
"out = relu + leaky_relu\n",
"\n",
"# Single Branch\n",
"conv2d = relay.op.nn.conv2d(out, weight)\n",
"relu = relay.op.nn.relu(conv2d)\n",
"relu = relay.op.nn.relu(relu)\n",
"tanh = relay.op.tanh(relu)\n",
"out = relu + tanh\n",
"\n",
"# Fuzzy path/nested Diamond\n",
"conv2d = relay.op.nn.conv2d(out, weight)\n",
"relu = relay.op.nn.relu(conv2d)\n",
"relu = relu + relu\n",
"tanh = relay.op.tanh(relu)\n",
"leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)\n",
"out = tanh + leaky_relu\n",
"one = relay.op.nn.conv2d(inp, weight)\n",
"two = relay.op.nn.conv2d(one, weight)\n",
"three = relay.op.nn.conv2d(two, weight)\n",
"four = relay.op.nn.conv2d(three, weight)\n",
"\n",
"assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), four)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py312x",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
10 changes: 10 additions & 0 deletions doc/tutorials/pattern/dataflow/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# 数据流模式

```{toctree}
node
match
rewrite
fuse
simplify
partition
```
Loading

0 comments on commit 378e3f0

Please sign in to comment.