generated from xinetzone/sphinx-demo
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
liuxinwei
committed
May 23, 2024
1 parent
a0d12ab
commit 378e3f0
Showing
16 changed files
with
3,059 additions
and
223 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,5 @@ | |
tvm-docker<https://github.com/xinetzone/tvm-docker> | ||
gcc | ||
caffe | ||
cuda | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,5 @@ | |
```{toctree} | ||
AttrConverter | ||
tag-span | ||
faqs | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,5 +2,6 @@ | |
|
||
```{toctree} | ||
primitive | ||
forward | ||
quant | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,309 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# 融合模式" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import sys\n", | ||
"from pathlib import Path\n", | ||
"ROOT = Path(\".\").resolve().parents[3]\n", | ||
"sys.path.extend([f\"{ROOT}/tests\", f\"{ROOT}/src\"])\n", | ||
"# # from tools.tag_span import _create_span, _set_span, _verify_structural_equal_with_span\n", | ||
"from tools.torch_utils import verify_model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"import tvm\n", | ||
"from tvm import relay\n", | ||
"from tvm.relay.build_module import bind_params_by_name\n", | ||
"from tvm.relay.dataflow_pattern import *\n", | ||
"from tvm.relay.testing import run_opt_pass\n", | ||
"\n", | ||
"# NB: 1 corresponds to the C++ enum that specicfies this\n", | ||
"# we loose the type safety due to the Python/C++ calling\n", | ||
"# convention.\n", | ||
"K_ELEMWISE = 0\n", | ||
"K_BROADCAST = 1" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## not_fuse_multi_diamond" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Pattern\n", | ||
"is_conv2d = is_op(\"nn.conv2d\")(wildcard(), wildcard())\n", | ||
"path1 = is_op(\"nn.relu\")(is_conv2d)\n", | ||
"path2 = is_op(\"nn.leaky_relu\")(is_conv2d)\n", | ||
"diamond = is_op(\"add\")(path1, path2)\n", | ||
"\n", | ||
"# Expr\n", | ||
"inp = relay.var(\"input\")\n", | ||
"weight = relay.var(\"weight\")\n", | ||
"conv2d = relay.op.nn.conv2d(inp, weight)\n", | ||
"relu = relay.op.nn.relu(conv2d)\n", | ||
"leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)\n", | ||
"out = relu + leaky_relu\n", | ||
"out = out + conv2d\n", | ||
"# Check\n", | ||
"assert not diamond.match(out)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## BN 融合" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class BatchnormCallback(DFPatternCallback):\n", | ||
" def __init__(self):\n", | ||
" super(BatchnormCallback, self).__init__()\n", | ||
" self.x = wildcard()\n", | ||
" self.var = wildcard()\n", | ||
" self.mean = wildcard()\n", | ||
" self.beta = wildcard()\n", | ||
" self.gamma = wildcard()\n", | ||
" self.eps = is_constant()\n", | ||
"\n", | ||
" self.pattern = (\n", | ||
" self.gamma * (self.x - self.mean) / is_op(\"sqrt\")(self.var + self.eps) + self.beta\n", | ||
" )\n", | ||
"\n", | ||
" def callback(self, pre, post, node_map):\n", | ||
" x = node_map[self.x][0]\n", | ||
" var = node_map[self.var][0]\n", | ||
" mean = node_map[self.mean][0]\n", | ||
" beta = node_map[self.beta][0]\n", | ||
" gamma = node_map[self.gamma][0]\n", | ||
" eps = node_map[self.eps][0]\n", | ||
" return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=eps.data.numpy().item())[0]\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def test_fuse_batchnorm():\n", | ||
" x = relay.var(\"x\")\n", | ||
" var = relay.var(\"var\")\n", | ||
" mean = relay.var(\"mean\")\n", | ||
" beta = relay.var(\"beta\")\n", | ||
" gamma = relay.var(\"gamma\")\n", | ||
"\n", | ||
" BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n", | ||
"\n", | ||
" out = rewrite(BatchnormCallback(), BN)\n", | ||
" assert tvm.ir.structural_equal(\n", | ||
" out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n", | ||
" )\n", | ||
"\n", | ||
"\n", | ||
"def test_no_fuse_batchnorm():\n", | ||
" x = relay.var(\"x\")\n", | ||
" var = relay.var(\"var\")\n", | ||
" mean = relay.var(\"mean\")\n", | ||
" beta = relay.var(\"beta\")\n", | ||
" gamma = relay.var(\"gamma\")\n", | ||
"\n", | ||
" fake_BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta\n", | ||
"\n", | ||
" out = rewrite(BatchnormCallback(), fake_BN)\n", | ||
" assert tvm.ir.structural_equal(out, fake_BN)\n", | ||
"\n", | ||
"\n", | ||
"def test_fuse_double_batchnorm():\n", | ||
" x = relay.var(\"x\")\n", | ||
" var = relay.var(\"var\")\n", | ||
" mean = relay.var(\"mean\")\n", | ||
" beta = relay.var(\"beta\")\n", | ||
" gamma = relay.var(\"gamma\")\n", | ||
"\n", | ||
" BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n", | ||
" BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n", | ||
"\n", | ||
" out = rewrite(BatchnormCallback(), BN2)\n", | ||
"\n", | ||
" bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n", | ||
" bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon=1e-5)[0]\n", | ||
"\n", | ||
" assert tvm.ir.structural_equal(out, bn2)\n", | ||
"\n", | ||
"\n", | ||
"def test_partial_fuse_double_batchnorm():\n", | ||
" x = relay.var(\"x\")\n", | ||
" var = relay.var(\"var\")\n", | ||
" mean = relay.var(\"mean\")\n", | ||
" beta = relay.var(\"beta\")\n", | ||
" gamma = relay.var(\"gamma\")\n", | ||
"\n", | ||
" BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta\n", | ||
" BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n", | ||
"\n", | ||
" out = rewrite(BatchnormCallback(), BN2)\n", | ||
"\n", | ||
" bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon=1e-5)[0]\n", | ||
"\n", | ||
" assert tvm.ir.structural_equal(out, bn2)\n", | ||
"\n", | ||
"\n", | ||
"def test_fuse_batchnorm_commutation():\n", | ||
" x = relay.var(\"x\")\n", | ||
" var = relay.var(\"var\")\n", | ||
" mean = relay.var(\"mean\")\n", | ||
" beta = relay.var(\"beta\")\n", | ||
" gamma = relay.var(\"gamma\")\n", | ||
"\n", | ||
" # commute add\n", | ||
" BN = beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5))\n", | ||
" out = rewrite(BatchnormCallback(), BN)\n", | ||
" assert tvm.ir.structural_equal(\n", | ||
" out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n", | ||
" )\n", | ||
"\n", | ||
" # associate divide/multiply\n", | ||
" BN = (gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta\n", | ||
" out = rewrite(BatchnormCallback(), BN)\n", | ||
" assert tvm.ir.structural_equal(\n", | ||
" out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n", | ||
" )\n", | ||
"\n", | ||
" # associate multiply/divide\n", | ||
" BN = gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta\n", | ||
" out = rewrite(BatchnormCallback(), BN)\n", | ||
" assert tvm.ir.structural_equal(\n", | ||
" out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]\n", | ||
" )\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## quadruple_rewrite_dominator" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class DominatorRemovalCallback(DFPatternCallback):\n", | ||
" def __init__(self):\n", | ||
" super(DominatorRemovalCallback, self).__init__()\n", | ||
" self.inp = wildcard()\n", | ||
" self.weight = wildcard()\n", | ||
" is_conv2d = is_op(\"nn.conv2d\")(self.inp, self.weight)\n", | ||
" is_unary_elemwise = (wildcard().has_attr({\"TOpPattern\": K_ELEMWISE}))(\n", | ||
" wildcard()\n", | ||
" ) | is_op(\"add\")(wildcard(), wildcard())\n", | ||
" reduction = is_op(\"add\")(wildcard(), wildcard())\n", | ||
" self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)\n", | ||
"\n", | ||
" def callback(self, pre, post, node_map):\n", | ||
" inp = node_map[self.inp][0]\n", | ||
" weight = node_map[self.weight][0]\n", | ||
" return relay.op.nn.conv2d(inp, weight)\n", | ||
"\n", | ||
"inp = relay.var(\"input\")\n", | ||
"weight = relay.var(\"weight\")\n", | ||
"# Classic Diamond\n", | ||
"conv2d = relay.op.nn.conv2d(inp, weight)\n", | ||
"relu = relay.op.nn.relu(conv2d)\n", | ||
"relu = relay.op.nn.relu(relu)\n", | ||
"leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)\n", | ||
"out = relu + leaky_relu\n", | ||
"\n", | ||
"# Deeper Branch\n", | ||
"conv2d = relay.op.nn.conv2d(out, weight)\n", | ||
"relu = relay.op.nn.relu(conv2d)\n", | ||
"relu = relay.op.nn.relu(relu)\n", | ||
"relu = relay.op.tanh(relu)\n", | ||
"leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)\n", | ||
"out = relu + leaky_relu\n", | ||
"\n", | ||
"# Single Branch\n", | ||
"conv2d = relay.op.nn.conv2d(out, weight)\n", | ||
"relu = relay.op.nn.relu(conv2d)\n", | ||
"relu = relay.op.nn.relu(relu)\n", | ||
"tanh = relay.op.tanh(relu)\n", | ||
"out = relu + tanh\n", | ||
"\n", | ||
"# Fuzzy path/nested Diamond\n", | ||
"conv2d = relay.op.nn.conv2d(out, weight)\n", | ||
"relu = relay.op.nn.relu(conv2d)\n", | ||
"relu = relu + relu\n", | ||
"tanh = relay.op.tanh(relu)\n", | ||
"leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)\n", | ||
"out = tanh + leaky_relu\n", | ||
"one = relay.op.nn.conv2d(inp, weight)\n", | ||
"two = relay.op.nn.conv2d(one, weight)\n", | ||
"three = relay.op.nn.conv2d(two, weight)\n", | ||
"four = relay.op.nn.conv2d(three, weight)\n", | ||
"\n", | ||
"assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), four)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "py312x", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# 数据流模式 | ||
|
||
```{toctree} | ||
node | ||
match | ||
rewrite | ||
fuse | ||
simplify | ||
partition | ||
``` |
Oops, something went wrong.