#!/usr/bin/env python
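"""Capture TorchScript IR graphs from AOTAutograd-compiled FX graphs.

my_compiler translates each FX GraphModule produced by aot_module into a
torch.Graph, node by node, and publishes it through the GRAPH global. The
helpers below capture module and loss forward/backward graphs and stitch
them into a single gradient graph and a fused SGD training graph.
"""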
import operator
import pickle
import tempfile
from typing import Optional

import torch
import torch.utils._pytree as pytree
from functorch.compile import aot_module, make_boxed_func

# Side channel for the most recently captured TorchScript graph; my_compiler
# publishes its result here because compiler callbacks cannot return it.
GRAPH: Optional[torch.Graph] = None

def my_compiler(fx_module: torch.fx.GraphModule, example_inputs):
    torchscriptIR = torch.Graph()
    fxNode_to_Value = {}      # maps torch.fx.Node -> torch.Value
    scalar_to_Value = {}      # maps int constant -> torch.Value
    serialized_to_Value = {}  # maps pickled list of Value ids -> torch.Value
    the_none_Value = []       # one-element cache for the shared None constant
    for fxNode in fx_module.graph.nodes:
        assert len(fxNode.kwargs) == 0

        def convert_arg_to_Value(arg):
            if isinstance(arg, int):
                if arg in scalar_to_Value:
                    return scalar_to_Value[arg]
                else:
                    node = torchscriptIR.create("prim::Constant", 1)
                    node.i_("value", arg)
                    torchscriptIR.insertNode(node)
                    the_value = node.outputsAt(0)
                    scalar_to_Value[arg] = the_value
                    return the_value
            elif isinstance(arg, torch.fx.Node):
                return fxNode_to_Value[arg]
            elif isinstance(arg, list):
                # Deduplicate ListConstruct nodes by the identity of their
                # element Values.
                elem_values = [convert_arg_to_Value(elem) for elem in arg]
                elem_values_uuids = [value.unique() for value in elem_values]
                serialized = pickle.dumps(elem_values_uuids)
                if serialized in serialized_to_Value:
                    return serialized_to_Value[serialized]
                else:
                    node = torchscriptIR.create("prim::ListConstruct", elem_values, 1)
                    torchscriptIR.insertNode(node)
                    the_value = node.outputsAt(0)
                    serialized_to_Value[serialized] = the_value
                    return the_value
            elif arg is None:
                if len(the_none_Value) == 0:
                    node = torchscriptIR.create("prim::Constant", 1)
                    torchscriptIR.insertNode(node)
                    the_value = node.outputsAt(0)
                    the_none_Value.append(the_value)
                    return the_value
                else:
                    return the_none_Value[0]
            elif isinstance(arg, torch.dtype):
                # Script a tiny function that returns the dtype, then splice
                # its graph in to obtain a dtype constant Value.
                source_code = "import torch\ndef t():\n    return " + repr(arg)
                t_graph = None
                with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py") as tmp_file:
                    tmp_file.write(source_code)
                    tmp_file.flush()
                    obj = compile(source_code, tmp_file.name, "exec")
                    scope = {}
                    exec(obj, scope)
                    func = scope["t"]
                    func.__module__ = "t"
                    t_graph = torch.jit.script(func)
                output_values = torchscriptIR.insertGraph(t_graph.graph, [])
                return output_values[0]
            else:
                assert False, f"unsupported argument type: {type(arg)}"

        if fxNode.op == "call_function":
            input_values = [convert_arg_to_Value(arg) for arg in fxNode.args]
            if isinstance(fxNode.target, torch._ops.OpOverload):
                #node = torchscriptIR.create(fxNode.target._name, input_values, 1)
                node = torchscriptIR.create(fxNode.target.overloadpacket._qualified_op_name, input_values, 1)
            elif fxNode.target == operator.getitem:
                #node = torchscriptIR.create("aten::__getitem__", input_values, 1)
                node = torchscriptIR.create("prim::TupleIndex", input_values, 1)
            else:
                assert False, f"unsupported call_function target: {fxNode.target}"
            torchscriptIR.insertNode(node)
            the_value = node.outputsAt(0)
            the_value.setDebugName(fxNode.name)
            fxNode_to_Value[fxNode] = the_value
        elif fxNode.op == "output":
            if isinstance(fxNode.args[0], (list, tuple)):
                input_values = [convert_arg_to_Value(output_fxNode) for output_fxNode in fxNode.args[0]]
                for value in input_values:
                    torchscriptIR.registerOutput(value)
            else:
                assert False, "expected the output node to carry a list or tuple"
        elif fxNode.op == "placeholder":
            input_value = torchscriptIR.addInput(fxNode.name)
            fxNode_to_Value[fxNode] = input_value
        else:
            assert False, f"unsupported fx op: {fxNode.op}"
    global GRAPH
    GRAPH = torchscriptIR
    # Return the original forward so execution is unchanged; the TorchScript
    # graph is only captured as a side effect.
    return make_boxed_func(fx_module.forward)

def capture_state(module):
    named_params = dict(module.named_parameters(remove_duplicate=False))
    named_buffers = dict(module.named_buffers(remove_duplicate=False))
    state, _ = pytree.tree_flatten((named_params, named_buffers))
    return tuple(state)
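
# A minimal usage sketch for capture_state (the toy nn.Linear is hypothetical,
# not part of the capture pipeline): the tuple order follows the pytree
# flattening of (named_params, named_buffers).
def _example_capture_state():
    lin = torch.nn.Linear(4, 2)
    state = capture_state(lin)
    print(len(state))  # 2: weight and bias, in named_parameters order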

def capture_forward_and_backward(module, *inputs):
    global GRAPH
    aot_model = aot_module(module, fw_compiler=my_compiler)
    y = aot_model(*inputs)
    forward = GRAPH
    # bw_compiler defaults to fw_compiler, so my_compiler runs again during
    # the lazy backward compilation triggered here, overwriting GRAPH.
    y.sum().backward()
    backward = GRAPH
    return (forward, backward)
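
# A hedged sketch of capturing both graphs (the toy module and input shape
# are assumptions): forward is captured when aot_model first runs, backward
# when .backward() triggers the lazy backward compilation.
def _example_capture_forward_and_backward():
    model = torch.nn.Linear(4, 2)
    forward, backward = capture_forward_and_backward(model, torch.randn(3, 4))
    print(forward)
    print(backward)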

def capture_forward(model, *inputs):
    global GRAPH
    aot_model = aot_module(model, fw_compiler=my_compiler)
    y = aot_model(*inputs)
    return GRAPH

def get_grad_graph(module_forward, module_backward, loss_forward, loss_backward):
    grad_graph = torch.Graph()
    # Reuse the module's state input names; replace its single data input
    # with explicit "datas" and "targets" inputs.
    inputs_name = [input_value.debugName() for input_value in module_forward.inputs()]
    inputs_name.pop()
    inputs_name.append("datas")
    inputs_name.append("targets")
    inputs_value = [grad_graph.addInput(input_name) for input_name in inputs_name]
    states_value = inputs_value[0:-2]
    data_value = inputs_value[-2]
    target_value = inputs_value[-1]
    module_forward_outputs_value = grad_graph.insertGraph(module_forward, states_value + [data_value])
    pred_value = module_forward_outputs_value[0]
    pred_value.setDebugName("preds")
    module_backward_activations_value = module_forward_outputs_value[1:]
    loss_forward_outputs_value = grad_graph.insertGraph(loss_forward, [pred_value, target_value])
    loss_value = loss_forward_outputs_value[0]
    loss_value.setDebugName("loss")
    loss_backward_activations_value = loss_forward_outputs_value[1:]
    # Constant 1 is the tangent of the scalar loss.
    node = grad_graph.create("prim::Constant", 1)
    node.i_("value", 1)
    grad_graph.insertNode(node)
    const1_value = node.outputsAt(0)
    loss_backward_outputs_value = grad_graph.insertGraph(loss_backward, loss_backward_activations_value + [const1_value])
    pred_tangents_value = loss_backward_outputs_value[0]
    module_backward_outputs_value = grad_graph.insertGraph(module_backward, module_backward_activations_value + [pred_tangents_value])
    state_grads_value = module_backward_outputs_value[0:len(states_value)]
    assert len(states_value) == len(state_grads_value)
    for i in range(len(states_value)):
        state_grads_value[i].setDebugName("grad_" + states_value[i].debugName())
    # Outputs: one gradient per state tensor, then preds, then loss.
    outputs_value = state_grads_value + [pred_value] + [loss_value]
    for output_value in outputs_value:
        grad_graph.registerOutput(output_value)
    return grad_graph
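
# A hedged end-to-end sketch for get_grad_graph (the MSE wrapper, the shapes,
# and the requires_grad_ call on the detached preds are assumptions, needed so
# the loss backward graph can be captured):
def _example_grad_graph():
    class MSE(torch.nn.Module):
        def forward(self, pred, target):
            return ((pred - target) ** 2).mean()

    model = torch.nn.Linear(4, 2)
    datas, targets = torch.randn(3, 4), torch.randn(3, 2)
    module_fw, module_bw = capture_forward_and_backward(model, datas)
    preds = model(datas).detach().requires_grad_(True)
    loss_fw, loss_bw = capture_forward_and_backward(MSE(), preds, targets)
    return get_grad_graph(module_fw, module_bw, loss_fw, loss_bw)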

def get_sgd_training_graph(grad_graph, lr: float):
    new_params_graph = torch.Graph()
    new_params_graph_inputs_value = [new_params_graph.addInput(input_value.debugName()) for input_value in grad_graph.inputs()]
    grad_graph_outputs_value = new_params_graph.insertGraph(grad_graph, new_params_graph_inputs_value)
    node = new_params_graph.create("prim::Constant", 1)
    node.f_("value", lr)
    new_params_graph.insertNode(node)
    lr_value = node.outputsAt(0)
    lr_value.setDebugName("lr")
    # SGD update for every state tensor: new_param = param - lr * grad.
    # The last two grad_graph outputs are preds and loss, not gradients.
    for i in range(len(grad_graph_outputs_value) - 2):
        param_value = new_params_graph_inputs_value[i]
        grad_value = grad_graph_outputs_value[i]
        node = new_params_graph.create("aten::mul", [grad_value, lr_value], 1)
        new_params_graph.insertNode(node)
        delta_value = node.outputsAt(0)
        delta_value.setDebugName("delta_" + param_value.debugName())
        node = new_params_graph.create("aten::sub", [param_value, delta_value], 1)
        new_params_graph.insertNode(node)
        new_param_value = node.outputsAt(0)
        new_param_value.setDebugName("new_" + param_value.debugName())
        new_params_graph.registerOutput(new_param_value)
    # Pass preds and loss through unchanged.
    new_params_graph.registerOutput(grad_graph_outputs_value[-2])
    new_params_graph.registerOutput(grad_graph_outputs_value[-1])
    return new_params_graph
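
# Hedged usage sketch: reuse the example gradient graph above and bake
# lr = 0.01 into the training graph as a constant.
if __name__ == "__main__":
    print(get_sgd_training_graph(_example_grad_graph(), lr=0.01))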