From 5a0d5464bf3ca7ac76870912d08f0111088631bf Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Fri, 31 May 2024 20:04:21 +0200
Subject: [PATCH] WIP

---
 .../torch_fx/nncf_graph_builder.py            | 38 ++++++++++++++++++-
 .../torch_fx/quantization/quantize_model.py   | 16 +++++---
 torch_compile_ex_release.py                   | 25 ++++++------
 3 files changed, 57 insertions(+), 22 deletions(-)

diff --git a/nncf/experimental/torch_fx/nncf_graph_builder.py b/nncf/experimental/torch_fx/nncf_graph_builder.py
index b8915a3e7a6..65bde1ffa26 100644
--- a/nncf/experimental/torch_fx/nncf_graph_builder.py
+++ b/nncf/experimental/torch_fx/nncf_graph_builder.py
@@ -77,7 +77,7 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe
         return node_type, node_metatype
 
     @staticmethod
-    def _separate_conv_and_bias(model: torch.fx.GraphModule):
+    def separate_conv_and_bias(model: torch.fx.GraphModule):
         """
         Separates one joined conv+bias node into two nodes: conv and bias.
         Needed as nncf does not expect joined conv
@@ -122,6 +122,40 @@ def _separate_conv_and_bias(model: torch.fx.GraphModule):
         model.graph.eliminate_dead_code()
         model.recompile()
 
+    @staticmethod
+    def merge_conv_and_bias(model: torch.fx.GraphModule):
+        """
+        Merges a conv node and the subsequent bias add node into one joined conv+bias node.
+        Inverse of separate_conv_and_bias; restores the joined op before compilation.
+        """
+        add_node_targets = (torch.ops.aten.add_.Tensor,)
+        for n in model.graph.nodes:
+            if not _is_conv(n):
+                continue
+            if len(n.args) > 2 and n.args[2] is not None:
+                continue
+            bias_node = next(iter(n.users))
+            if len(n.users) > 1 or bias_node.target not in add_node_targets:
+                continue
+            conv_node = n
+            const_node = None
+            for node in bias_node.all_input_nodes:
+                if node is not conv_node:
+                    const_node = node
+                    break
+            assert const_node is not None
+            bias_value = _get_tensor_constant_from_node(const_node, model).squeeze()
+            with model.graph.inserting_before(conv_node):
+                new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value)
+            args = list(conv_node.args)
+            args[2] = new_bias_node
+            conv_node.args = tuple(args)
+            for user in list(bias_node.users):
+                user.replace_input_with(bias_node, conv_node)
+
+        model.graph.eliminate_dead_code()
+        model.recompile()
+
     @staticmethod
     def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
         """
@@ -136,7 +170,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
         _fuse_conv_bn_(model)
         # BN fuses to conv bias, conv+bias joined op
         # needs to be split for nncf
-        GraphConverter._separate_conv_and_bias(model)
+        GraphConverter.separate_conv_and_bias(model)
 
         nncf_graph = PTNNCFGraph()
 
diff --git a/nncf/experimental/torch_fx/quantization/quantize_model.py b/nncf/experimental/torch_fx/quantization/quantize_model.py
index bbb638b10a7..abc0c9ddfaf 100644
--- a/nncf/experimental/torch_fx/quantization/quantize_model.py
+++ b/nncf/experimental/torch_fx/quantization/quantize_model.py
@@ -78,19 +78,23 @@ def quantize_impl(
     nncf_graph = NNCFGraphFactory.create(copied_model)
     quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)
 
+    from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter
+
+    GraphConverter.merge_conv_and_bias(quantized_model)
+
     # Magic. Without this call the compiled model
     # is not performant
-    model = GraphModule(model, model.graph)
+    quantized_model = GraphModule(quantized_model, quantized_model.graph)
 
-    model = _fold_conv_bn_qat(model)
+    quantized_model = _fold_conv_bn_qat(quantized_model)
     pm = PassManager([DuplicateDQPass()])
-    model = pm(model).graph_module
+    quantized_model = pm(quantized_model).graph_module
 
     pm = PassManager([PortNodeMetaForQDQ()])
-    model = pm(model).graph_module
+    quantized_model = pm(quantized_model).graph_module
 
-    model.meta.update(original_graph_meta)
-    model = _disallow_eval_train(model)
+    quantized_model.meta.update(original_graph_meta)
+    quantized_model = _disallow_eval_train(quantized_model)
 
     return quantized_model
 
diff --git a/torch_compile_ex_release.py b/torch_compile_ex_release.py
index 85f9ef74738..78d66f3b7fc 100644
--- a/torch_compile_ex_release.py
+++ b/torch_compile_ex_release.py
@@ -44,7 +44,7 @@ def get_exported_model_from_nn_module(module, example_inputs):
     return capture_pre_autograd_graph(module, example_inputs)
 
 
-NNCF_IMPL = True
+NNCF_IMPL = False
 
 
 def get_qsetup(exported_model, example_inputs):
@@ -79,8 +79,6 @@ def get_qsetup(exported_model, example_inputs):
 
 
 def quantize(model, example_inputs):
-    exported_model = get_exported_model_from_nn_module(model, example_inputs)
-
     if NNCF_IMPL:
         # Use NNCF here on exported model
         # to create a quantized model which is compatible with
@@ -97,19 +95,18 @@ def quantize(model, example_inputs):
         import nncf
 
         calibration_dataset = nncf.Dataset(example_inputs)
+        exported_model = get_exported_model_from_nn_module(model, example_inputs)
         quantized_model = nncf.quantize(exported_model, calibration_dataset)
         g = FxGraphDrawer(quantized_model, "resnet18_quantized_native_nncf")
         g.get_dot_graph().write_svg("resnet18_quantized_native_nncf.svg")
         return quantized_model
     else:
-
-        g = FxGraphDrawer(exported_model, "resnet18")
-        g.get_dot_graph().write_svg("resnet18_compiled.svg")
-        nncf_graph = GraphConverter.create_nncf_graph(exported_model)
-        del nncf_graph
+        # g = FxGraphDrawer(exported_model, "resnet18")
+        # g.get_dot_graph().write_svg("resnet18_compiled.svg")
 
         # MOCK NNCF QUANTIZATION
+        exported_model = get_exported_model_from_nn_module(model, example_inputs)
         qsetup = get_qsetup(exported_model, example_inputs)
         exported_model = get_exported_model_from_nn_module(model, example_inputs)
         exported_model = insert_qdq_to_model(exported_model, qsetup)
@@ -166,13 +163,13 @@ def main(model_name, num_iters):
 
     converted_model = quantize(copy.deepcopy(model), example_inputs)
 
-    print("original model execution time: ", measure_time(model, example_inputs, num_iters))
+    # print("original model execution time: ", measure_time(model, example_inputs, num_iters))
 
-    native_optimized_model_fp32 = torch.compile(model)
-    print(
-        "Torch Inductor FP32 model execution time: ",
-        measure_time(native_optimized_model_fp32, example_inputs, num_iters),
-    )
+    # native_optimized_model_fp32 = torch.compile(model)
+    # print(
+    #     "Torch Inductor FP32 model execution time: ",
+    #     measure_time(native_optimized_model_fp32, example_inputs, num_iters),
+    # )
 
     native_optimized_model_int8 = torch.compile(converted_model)
     print(
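
Reviewer note: a minimal end-to-end sketch of the flow this patch targets, assembled from torch_compile_ex_release.py above. The torchvision model, the input shape, and the torch._export import path for capture_pre_autograd_graph are illustrative assumptions, not part of the patch.

import copy

import torch
import torchvision.models as models
from torch._export import capture_pre_autograd_graph  # import path assumed; the patch only shows the call site

import nncf

model = models.resnet18(weights=None).eval()
example_inputs = (torch.ones(1, 3, 224, 224),)  # placeholder input shape

# Capture the FX graph right before quantization, as quantize() does above.
exported_model = capture_pre_autograd_graph(copy.deepcopy(model), example_inputs)

# Per this patch, nncf.quantize() splits conv+bias via
# GraphConverter.separate_conv_and_bias() while building the NNCFGraph, and
# quantize_impl() merges them back with GraphConverter.merge_conv_and_bias()
# before returning, so the graph handed to torch.compile keeps the joined
# conv+bias form that Inductor optimizes well.
calibration_dataset = nncf.Dataset(example_inputs)
quantized_model = nncf.quantize(exported_model, calibration_dataset)

optimized_int8 = torch.compile(quantized_model)
with torch.no_grad():
    print(optimized_int8(*example_inputs).shape)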