-
Notifications
You must be signed in to change notification settings - Fork 118
/
Copy pathtracer_test.py
89 lines (69 loc) · 2.73 KB
/
tracer_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import re
import unittest
import gc
from tinynn.graph.tracer import trace, model_tracer
from common_utils import collect_custom_models, collect_torchvision_models, prepare_inputs, IS_CI
# Patterns for models that are always skipped. Each entry is applied with
# ``re.match`` against the model class's ``__name__``.
BLACKLIST = ('swin.*', 'vit.*')

# Patterns for models that are skipped only when ``IS_CI`` is true; matched the
# same way as ``BLACKLIST`` (presumably too heavy for CI runners — confirm).
CI_BLACKLIST = ('regnet_y_128gf',)
class TestModelMeta(type):
    """Metaclass that generates tracer tests for every collected model.

    The generated test functions are inserted into the namespace returned by
    ``__prepare__``. They therefore already exist when the body of a class
    using this metaclass executes, and ``unittest`` discovers them like
    ordinary ``test_*`` methods.

    Each model class gets two tests: one with ``eliminate_dead_graph=False``
    (suffix ``_simple``) and one with ``eliminate_dead_graph=True`` (suffix
    ``_simple_edg``).
    """

    @classmethod
    def __prepare__(mcls, name, bases, **kwds):
        """Return a class namespace pre-populated with the generated tests.

        Args:
            name: Name of the class being created (unused).
            bases: Base classes of the class being created (unused).
            **kwds: Class keyword arguments. Accepted so that the signature
                matches ``type.__prepare__``; otherwise ignored.

        Returns:
            dict: Mapping from test method name to test function.
        """
        namespace = {}
        mcls._add_model_tests(namespace, 'test_torchvision_model_', collect_torchvision_models())
        mcls._add_model_tests(namespace, 'test_custom_model_', collect_custom_models())
        return namespace

    @staticmethod
    def _variant_suffix(eliminate_dead_graph):
        """Return the suffix shared by test method names and output file names."""
        return '_simple_edg' if eliminate_dead_graph else '_simple'

    @classmethod
    def _add_model_tests(mcls, namespace, prefix, model_classes):
        """Register one test per model class and per ``eliminate_dead_graph`` value.

        Args:
            namespace: Dict that receives the generated test functions.
            prefix: Test-name prefix, e.g. ``'test_custom_model_'``.
            model_classes: Iterable of model classes. Each class's ``__name__``
                is appended to ``prefix`` to form the test name.
        """
        for model_class in model_classes:
            base_name = prefix + model_class.__name__
            for eliminate_dead_graph in (False, True):
                test_name = base_name + mcls._variant_suffix(eliminate_dead_graph)
                namespace[test_name] = mcls.build_model_test(model_class, eliminate_dead_graph)

    @classmethod
    def build_model_test(cls, model_class, eliminate_dead_graph):
        """Build a test method that traces ``model_class`` and generates code for it.

        The returned test is skipped in three cases:
        - the model name matches a ``BLACKLIST`` pattern;
        - ``IS_CI`` is set and the name matches a ``CI_BLACKLIST`` pattern;
        - ``out/<model_file>.py`` already exists.

        Otherwise the test instantiates the model inside ``model_tracer()`` and
        traces it. It then asserts that ``graph.generate_code(..., check=True)``
        returns a truthy value and deletes the generated ``.pth`` weights file.

        Args:
            model_class: Model class, instantiated with no arguments.
            eliminate_dead_graph: Forwarded to ``trace``; also selects the
                ``_edg`` name suffix.

        Returns:
            A function taking a ``unittest.TestCase`` instance as ``self``.
        """
        model_name = model_class.__name__
        model_file = model_name + cls._variant_suffix(eliminate_dead_graph)
        code_path = f'out/{model_file}.py'
        weights_path = f'out/{model_file}.pth'

        def f(self):
            # Deliberately no docstring: unittest would show it as the test's
            # short description in verbose output.
            if any(re.match(pattern, model_name) for pattern in BLACKLIST):
                raise unittest.SkipTest('IN BLACKLIST')
            if IS_CI and any(re.match(pattern, model_name) for pattern in CI_BLACKLIST):
                raise unittest.SkipTest('IN CI BLACKLIST')
            # Existing generated code is treated as "already tested", so reruns
            # only process models without output.
            # NOTE(review): if generate_code writes the .py file before its
            # check fails, later runs will skip this model as 'TESTED' — confirm
            # whether failed outputs should be removed.
            if os.path.exists(code_path):
                raise unittest.SkipTest('TESTED')

            with model_tracer():
                m = model_class()
                m.eval()
                inputs = prepare_inputs(m)
                graph = trace(m, inputs, eliminate_dead_graph=eliminate_dead_graph)
                self.assertTrue(graph.generate_code(code_path, weights_path, model_name, check=True))
                # Remove the weights file to save space
                os.unlink(weights_path)

            if IS_CI:
                # Drop references to the model, graph and inputs, then collect,
                # so memory is reclaimed before the next generated test runs.
                del m
                del graph
                del inputs
                gc.collect()

        return f
class TestModel(unittest.TestCase, metaclass=TestModelMeta):
    """Test case whose ``test_*`` methods are all generated by ``TestModelMeta``."""


if __name__ == '__main__':
    # Run the generated tests when this file is executed directly.
    unittest.main()