Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Jan 23, 2025
1 parent f574a1f commit dd80542
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
20 changes: 12 additions & 8 deletions nncf/quantization/algorithms/weight_compression/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import find_const_node_in_constant_subgraph
from nncf.torch.model_graph_manager import get_const_data
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_module_by_name
from nncf.torch.model_graph_manager import split_const_name
Expand Down Expand Up @@ -173,10 +174,8 @@ def get_weight(
) -> Tensor:
weight_node = get_const_node(node_with_weight, weight_port_id, graph)
weight_name = weight_node.layer_attributes.name
module_name, weight_attr_name = split_const_name(weight_name)
module = get_module_by_name(module_name, model)
weight = getattr(module, weight_attr_name)
if weight is None or not isinstance(weight, torch.nn.Parameter):
weight = get_const_data(weight_node, model)
if weight is None:
raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.")

return Tensor(weight)
Expand Down Expand Up @@ -222,10 +221,8 @@ def transform_model(

weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph)
weight_name = weight_node.layer_attributes.name
module_name, weight_attr_name = split_const_name(weight_name)
module = get_module_by_name(module_name, model)
weight = getattr(module, weight_attr_name)
if weight is None or not isinstance(weight, torch.nn.Parameter):
weight = get_const_data(weight_node, model)
if weight is None:
raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.")

# calculates compressed weights and decompression parameters
Expand Down Expand Up @@ -264,7 +261,14 @@ def transform_model(
packed_tensor = decompressor.pack_weight(compressed_weight.tensor.data)

# sets compressed tensor
# TODO:(AlexanderDokuchaev): update set_const_data
compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
module_name, weight_attr_name = split_const_name(weight_name)
module = get_module_by_name(module_name, model)
weight = getattr(module, weight_attr_name)
if not isinstance(weight, torch.nn.Parameter):
raise nncf.InternalError(f"Weight is not a torch.nn.Parameter in the model by name {weight_name}.")

setattr(module, weight_attr_name, compressed_parameter)

consumer_nodes = graph.get_next_nodes(weight_node)
Expand Down
6 changes: 3 additions & 3 deletions nncf/torch/model_graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Mod

def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor:
"""
Retrieves a constant tensor associated with a given node.
Retrieves a detached constant tensor associated with a given node.
:param const_node: The node associated with const data.
:param model: The NNCFNetwork object.
Expand All @@ -128,8 +128,8 @@ def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor:
module = get_module_by_name(module_name, model)
data = getattr(module, const_attr_name)
if isinstance(data, torch.nn.Parameter):
return data.data
return data
return data.data.detach()
return data.detach()


def get_const_data_on_port(node: NNCFNode, port_id: int, model: NNCFNetwork) -> torch.Tensor:
Expand Down
5 changes: 5 additions & 0 deletions tests/torch/test_model_graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,15 @@ def test_get_set_const_data():
graph = model.nncf.get_graph()
const_node = graph.get_node_by_name("conv.bias")

assert model.conv.bias.requires_grad

data = get_const_data(const_node, model)
assert not data.requires_grad
assert torch.all(model.conv.bias.data == data)

set_const_data(torch.ones_like(data), const_node, model)
assert torch.all(model.conv.bias.data == torch.ones_like(data))
assert model.conv.bias.requires_grad


@pytest.mark.parametrize(
Expand Down

0 comments on commit dd80542

Please sign in to comment.