-
Notifications
You must be signed in to change notification settings - Fork 118
/
Copy pathhalf_quantizer.py
84 lines (67 loc) · 2.77 KB
/
half_quantizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import functools
import typing
import igraph as ig
from . import tflite as tfl
from .base import ExtendedOperator
from .graph import CommonGraph
from tinynn.util.util import get_logger
# Module-level logger for this file, obtained from tinynn's logging helper.
log = get_logger(__name__)
class HalfQuantizer(object):
    """Store float32 constants of a ``CommonGraph`` as float16 and dequantize them at runtime.

    For every vertex accepted by ``is_quantizable_node``, the pass:

    1. creates a buffer-backed attribute tensor holding the constant's data cast to float16;
    2. adds a ``tfl.DequantizeOperator`` reading that tensor and producing a new transform
       tensor built from the original float32 data;
    3. rewires every consumer input that referenced the original constant (matched by tensor
       name) to read the dequantized tensor instead.

    The original float32 constant vertex itself is not removed here; presumably a later graph
    cleanup drops it once it has no consumers -- NOTE(review): confirm against CommonGraph.
    """

    graph: CommonGraph

    def __init__(self, graph) -> None:
        super().__init__()
        self.graph = graph
        # Per-kind counters used to generate unique default names for tensors created by this pass.
        self.fuse_tensor_count = 0
        self.fuse_attr_count = 0

    @staticmethod
    def _auto_name(prefix: str, index: int) -> str:
        """Return the auto-generated name for the ``index``-th tensor with ``prefix``.

        The first tensor gets the bare prefix; later ones get ``_<index>`` appended.
        """
        return prefix if index == 0 else f'{prefix}_{index}'

    def create_attr_tensor(
        self,
        tensor: tfl.Tensor,
        name: typing.Optional[str] = None,
        quantization: typing.Optional[tfl.QuantizationParameters] = None,
    ):
        """Wrap ``tensor`` data in a ``tfl.Tensor`` created with ``has_buffer=True``.

        Args:
            tensor: The data to wrap (the pass passes a float16-cast array here).
            name: Explicit tensor name. When omitted, ``half_attr``, ``half_attr_1``, ... are used
                and the attribute counter is advanced.
            quantization: Optional quantization parameters forwarded to ``tfl.Tensor``.

        Returns:
            The newly created ``tfl.Tensor``.
        """
        if name is None:
            name = self._auto_name('half_attr', self.fuse_attr_count)
            self.fuse_attr_count += 1
        return tfl.Tensor(tensor, name, has_buffer=True, quantization=quantization)

    def create_transform_tensor(
        self,
        tensor: tfl.Tensor,
        name: typing.Optional[str] = None,
        quantization: typing.Optional[tfl.QuantizationParameters] = None,
    ):
        """Wrap ``tensor`` data in a ``tfl.Tensor`` created with ``has_buffer=False``.

        Args:
            tensor: The data describing the tensor (the pass passes the original float32 array).
            name: Explicit tensor name. When omitted, ``half_transform``, ``half_transform_1``, ...
                are used and the transform counter is advanced.
            quantization: Optional quantization parameters forwarded to ``tfl.Tensor``.

        Returns:
            The newly created ``tfl.Tensor``.
        """
        if name is None:
            name = self._auto_name('half_transform', self.fuse_tensor_count)
            self.fuse_tensor_count += 1
        return tfl.Tensor(tensor, name, has_buffer=False, quantization=quantization)

    def quantize(self):
        """Run the half-precision rewrite on ``self.graph`` (currently a single pass)."""
        self.quantize_pass()

    def quantize_pass(self):
        """Insert fp16 -> fp32 Dequantize ops for every quantizable constant and rewire consumers.

        Input replacements are collected first and applied after the traversal, so that the graph
        is not re-wired while its vertices are still being visited.
        """
        filtered_nodes = self.graph.graph.vs.select(functools.partial(is_quantizable_node, graph_converter=self.graph))

        # (consumer vertex, input index, replacement tensor) triples applied after the loop.
        pending_replacements = []
        for node in filtered_nodes:
            tensor_name = node['name']
            orig_tensor = self.graph.tensor_map[tensor_name]

            half_attr = self.create_attr_tensor(orig_tensor.tensor.astype('float16'))
            dequantized = self.create_transform_tensor(orig_tensor.tensor)
            self.graph.add_operator(tfl.DequantizeOperator([half_attr], [dequantized]))

            # Map each distinct consumer op to the first vertex seen for it. A dict keeps
            # out-edge order, so the rewrite order is deterministic (a set of ops is not).
            consumers = {}
            for out_edge in node.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                consumers.setdefault(next_node['op'], next_node)

            for next_op, next_node in consumers.items():
                # An op may consume the same constant at several input slots; rewire each one.
                for i, inp in enumerate(next_op.inputs):
                    if inp.name == tensor_name:
                        pending_replacements.append((next_node, i, dequantized))

        for next_node, i, dequantized in pending_replacements:
            self.graph.replace_operator_input(next_node, i, dequantized)
def is_quantizable_node(vertex: ig.Vertex, graph_converter: CommonGraph):
    """Return whether ``vertex`` is a float32 constant eligible for half-precision storage.

    A vertex qualifies when its ``node_type`` is ``ExtendedOperator.CONSTANT_NODE`` and the tensor
    registered under its name in ``graph_converter.tensor_map`` has a dtype whose string form is
    ``'float32'``. The tensor map is only consulted for constant vertices.
    """
    if vertex['node_type'] != ExtendedOperator.CONSTANT_NODE:
        return False
    constant = graph_converter.tensor_map[vertex['name']]
    return str(constant.dtype) == 'float32'