diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index e57152358b6..31e5c366f46 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -38,6 +38,7 @@ from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.model_graph_manager import find_const_node_in_constant_subgraph +from nncf.torch.model_graph_manager import get_const_data from nncf.torch.model_graph_manager import get_const_node from nncf.torch.model_graph_manager import get_module_by_name from nncf.torch.model_graph_manager import split_const_name @@ -173,10 +174,8 @@ def get_weight( ) -> Tensor: weight_node = get_const_node(node_with_weight, weight_port_id, graph) weight_name = weight_node.layer_attributes.name - module_name, weight_attr_name = split_const_name(weight_name) - module = get_module_by_name(module_name, model) - weight = getattr(module, weight_attr_name) - if weight is None or not isinstance(weight, torch.nn.Parameter): + weight = get_const_data(weight_node, model) + if weight is None: raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.") return Tensor(weight) @@ -222,10 +221,8 @@ def transform_model( weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph) weight_name = weight_node.layer_attributes.name - module_name, weight_attr_name = split_const_name(weight_name) - module = get_module_by_name(module_name, model) - weight = getattr(module, weight_attr_name) - if weight is None or not isinstance(weight, torch.nn.Parameter): + weight = get_const_data(weight_node, model) + if weight is None: raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.") # calculates compressed weights and decompression parameters @@ -264,7 +261,14 @@ def transform_model( packed_tensor = decompressor.pack_weight(compressed_weight.tensor.data) # sets compressed tensor + # TODO:(AlexanderDokuchaev): update set_const_data compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False) + module_name, weight_attr_name = split_const_name(weight_name) + module = get_module_by_name(module_name, model) + weight = getattr(module, weight_attr_name) + if not isinstance(weight, torch.nn.Parameter): + raise nncf.InternalError(f"Weight is not a torch.nn.Parameter in the model by name {weight_name}.") + setattr(module, weight_attr_name, compressed_parameter) consumer_nodes = graph.get_next_nodes(weight_node) diff --git a/nncf/torch/model_graph_manager.py b/nncf/torch/model_graph_manager.py index f4973d442e6..12023346206 100644 --- a/nncf/torch/model_graph_manager.py +++ b/nncf/torch/model_graph_manager.py @@ -117,7 +117,7 @@ def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Mod def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor: """ - Retrieves a constant tensor associated with a given node. + Retrieves a detached constant tensor associated with a given node. :param const_node: The node associated with const data. :param model: The NNCFNetwork object. @@ -128,8 +128,8 @@ def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor: module = get_module_by_name(module_name, model) data = getattr(module, const_attr_name) if isinstance(data, torch.nn.Parameter): - return data.data - return data + return data.data.detach() + return data.detach() def get_const_data_on_port(node: NNCFNode, port_id: int, model: NNCFNetwork) -> torch.Tensor: diff --git a/tests/torch/test_model_graph_manager.py b/tests/torch/test_model_graph_manager.py index 27ecf241597..e841fc24e40 100644 --- a/tests/torch/test_model_graph_manager.py +++ b/tests/torch/test_model_graph_manager.py @@ -239,10 +239,15 @@ def test_get_set_const_data(): graph = model.nncf.get_graph() const_node = graph.get_node_by_name("conv.bias") + assert model.conv.bias.requires_grad + data = get_const_data(const_node, model) + assert not data.requires_grad assert torch.all(model.conv.bias.data == data) + set_const_data(torch.ones_like(data), const_node, model) assert torch.all(model.conv.bias.data == torch.ones_like(data)) + assert model.conv.bias.requires_grad @pytest.mark.parametrize(