-
Notifications
You must be signed in to change notification settings - Fork 118
/
Copy pathcustom.py
94 lines (86 loc) · 3.29 KB
/
custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from .generated_ops import CustomOperator
from ...schemas.tflite import schema_generated as tflite
# FlexBuffer support (``flatbuffers.flexbuffers``) only ships with
# flatbuffers>=2. Record whether it is importable so that operators whose
# custom options are FlexBuffer-encoded can report a clear error instead of
# failing later on an unbound ``flexbuffers`` name.
try:
    from flatbuffers import flexbuffers
except ImportError:
    HAS_FLEXBUFFER = False
else:
    HAS_FLEXBUFFER = True
class TFLiteDetectionPostprocessOperator(CustomOperator):
    """Custom operator emitted with custom code ``TFLITE_DETECTION_POSTPROCESS``.

    The operator's attributes are not stored as builtin options. Instead,
    :meth:`build` serializes them into a FlexBuffer map and stores the
    resulting bytes in ``custom_options`` before delegating to
    ``CustomOperator.build``. FlexBuffer support requires flatbuffers>=2.

    Args:
        inputs: Forwarded unchanged to ``CustomOperator.__init__``.
        outputs: Forwarded unchanged to ``CustomOperator.__init__``.
        max_detections: Written as a FlexBuffer Int under the same key.
        max_classes_per_detection: Written as a FlexBuffer Int under the same key.
        nms_score_threshold: Written as a FlexBuffer Float under the same key.
        nms_iou_threshold: Written as a FlexBuffer Float under the same key.
        num_classes: Written as a FlexBuffer Int under the same key.
        y_scale: Written as a FlexBuffer Float under the same key.
        x_scale: Written as a FlexBuffer Float under the same key.
        h_scale: Written as a FlexBuffer Float under the same key.
        w_scale: Written as a FlexBuffer Float under the same key.

    Raises:
        ImportError: If ``flatbuffers.flexbuffers`` could not be imported
            (i.e. flatbuffers<2 is installed).
    """

    def __init__(
        self,
        inputs,
        outputs,
        max_detections: int,
        max_classes_per_detection: int,
        nms_score_threshold: float,
        nms_iou_threshold: float,
        num_classes: int,
        y_scale: float,
        x_scale: float,
        h_scale: float,
        w_scale: float,
    ) -> None:
        # An explicit raise instead of ``assert``: assertions are stripped
        # under ``python -O``, which would defer the failure to a confusing
        # NameError inside ``build``. Checking before the base constructor
        # also avoids creating a half-configured operator.
        if not HAS_FLEXBUFFER:
            raise ImportError(
                "TFLITE_DETECTION_POSTPROCESS relies on FlexBuffer, which requires flatbuffers>=2"
            )
        super().__init__(inputs, outputs)
        self.op.custom_code = "TFLITE_DETECTION_POSTPROCESS"
        self.max_detections = max_detections
        self.max_classes_per_detection = max_classes_per_detection
        self.nms_score_threshold = nms_score_threshold
        self.nms_iou_threshold = nms_iou_threshold
        self.num_classes = num_classes
        self.y_scale = y_scale
        self.x_scale = x_scale
        self.h_scale = h_scale
        self.w_scale = w_scale

    def build(self, builder):
        """Encode the attributes as a FlexBuffer map, then build the operator.

        The encoded bytes are stored in ``self.custom_options`` and the result
        of ``CustomOperator.build(builder)`` is returned unchanged.
        """
        fbb = flexbuffers.Builder()
        # NOTE(review): the key names presumably must match what the runtime
        # kernel parses from custom options; verify before renaming any key.
        with fbb.Map():
            fbb.Int('max_detections', self.max_detections)
            fbb.Int('max_classes_per_detection', self.max_classes_per_detection)
            fbb.Float('nms_score_threshold', self.nms_score_threshold)
            fbb.Float('nms_iou_threshold', self.nms_iou_threshold)
            fbb.Int('num_classes', self.num_classes)
            fbb.Float('y_scale', self.y_scale)
            fbb.Float('x_scale', self.x_scale)
            fbb.Float('h_scale', self.h_scale)
            fbb.Float('w_scale', self.w_scale)
        self.custom_options = fbb.Finish()
        return super().build(builder)
class MTKTransposeConvOperator(CustomOperator):
    """Custom operator emitted with custom code ``MTK_TRANSPOSE_CONV``.

    :meth:`build` serializes the attributes into a FlexBuffer map stored in
    ``custom_options``, so ``flatbuffers.flexbuffers`` (flatbuffers>=2) must
    be importable.

    Args:
        inputs: Forwarded unchanged to ``CustomOperator.__init__``.
        outputs: Forwarded unchanged to ``CustomOperator.__init__``.
        activation: Written under key ``'activation'``; defaults to
            ``tflite.ActivationFunctionType.NONE``.
        depth_multiplier: Written under the same key; defaults to 0.
        dilation_height_factor: Written under the same key; defaults to 0.
        dilation_width_factor: Written under the same key; defaults to 0.
        padding_type: Written under key ``'PaddingType'`` (note the different
            spelling from the attribute name); defaults to
            ``tflite.Padding.SAME``.
        stride_height: Written under the same key; defaults to 0.
        stride_width: Written under the same key; defaults to 0.

    Raises:
        ImportError: If ``flatbuffers.flexbuffers`` could not be imported
            (i.e. flatbuffers<2 is installed).
    """

    def __init__(
        self,
        inputs,
        outputs,
        activation: int = tflite.ActivationFunctionType.NONE,
        depth_multiplier: int = 0,
        dilation_height_factor: int = 0,
        dilation_width_factor: int = 0,
        padding_type: int = tflite.Padding.SAME,
        stride_height: int = 0,
        stride_width: int = 0,
    ) -> None:
        # ``build`` needs ``flexbuffers``. Without this check a missing
        # dependency only surfaces later as a NameError during serialization.
        if not HAS_FLEXBUFFER:
            raise ImportError("MTK_TRANSPOSE_CONV relies on FlexBuffer, which requires flatbuffers>=2")
        super().__init__(inputs, outputs)
        self.op.custom_code = "MTK_TRANSPOSE_CONV"
        self.activation = activation
        self.depth_multiplier = depth_multiplier
        self.dilation_height_factor = dilation_height_factor
        self.dilation_width_factor = dilation_width_factor
        self.padding_type = padding_type
        self.stride_height = stride_height
        self.stride_width = stride_width

    def build(self, builder):
        """Encode the attributes as a FlexBuffer map, then build the operator.

        The encoded bytes are stored in ``self.custom_options`` and the result
        of ``CustomOperator.build(builder)`` is returned unchanged.
        """
        fbb = flexbuffers.Builder()
        # NOTE(review): 'PaddingType' intentionally differs in style from the
        # other keys; presumably this is what the MTK runtime expects — verify
        # before changing.
        fbb.MapFromElements(
            {
                'activation': self.activation,
                'depth_multiplier': self.depth_multiplier,
                'dilation_height_factor': self.dilation_height_factor,
                'dilation_width_factor': self.dilation_width_factor,
                'PaddingType': self.padding_type,
                'stride_height': self.stride_height,
                'stride_width': self.stride_width,
            }
        )
        self.custom_options = fbb.Finish()
        return super().build(builder)