Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT] detach tensor in get_const_data #3199

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import find_const_node_in_constant_subgraph
from nncf.torch.model_graph_manager import get_const_data
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_module_by_name
from nncf.torch.model_graph_manager import split_const_name
Expand Down Expand Up @@ -173,10 +174,8 @@ def get_weight(
) -> Tensor:
weight_node = get_const_node(node_with_weight, weight_port_id, graph)
weight_name = weight_node.layer_attributes.name
module_name, weight_attr_name = split_const_name(weight_name)
module = get_module_by_name(module_name, model)
weight = getattr(module, weight_attr_name)
if weight is None or not isinstance(weight, torch.nn.Parameter):
weight = get_const_data(weight_node, model)
if weight is None:
raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.")

return Tensor(weight)
Expand Down Expand Up @@ -222,10 +221,8 @@ def transform_model(

weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph)
weight_name = weight_node.layer_attributes.name
module_name, weight_attr_name = split_const_name(weight_name)
module = get_module_by_name(module_name, model)
weight = getattr(module, weight_attr_name)
if weight is None or not isinstance(weight, torch.nn.Parameter):
weight = get_const_data(weight_node, model)
if weight is None:
raise nncf.InternalError(f"Could not find a torch.nn.Parameter in the model by name {weight_name}.")

# calculates compressed weights and decompression parameters
Expand Down Expand Up @@ -264,7 +261,14 @@ def transform_model(
packed_tensor = decompressor.pack_weight(compressed_weight.tensor.data)

# sets compressed tensor
# TODO:(AlexanderDokuchaev): update set_const_data
compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
module_name, weight_attr_name = split_const_name(weight_name)
module = get_module_by_name(module_name, model)
weight = getattr(module, weight_attr_name)
if not isinstance(weight, torch.nn.Parameter):
raise nncf.InternalError(f"Weight is not a torch.nn.Parameter in the model by name {weight_name}.")

setattr(module, weight_attr_name, compressed_parameter)

consumer_nodes = graph.get_next_nodes(weight_node)
Expand Down
6 changes: 3 additions & 3 deletions nncf/torch/model_graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Mod

def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor:
"""
Retrieves a constant tensor associated with a given node.
Retrieves a detached constant tensor associated with a given node.

:param const_node: The node associated with const data.
:param model: The NNCFNetwork object.
Expand All @@ -128,8 +128,8 @@ def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor:
module = get_module_by_name(module_name, model)
data = getattr(module, const_attr_name)
if isinstance(data, torch.nn.Parameter):
return data.data
return data
return data.data.detach()
return data.detach()


def get_const_data_on_port(node: NNCFNode, port_id: int, model: NNCFNetwork) -> torch.Tensor:
Expand Down
5 changes: 5 additions & 0 deletions tests/torch/test_model_graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,15 @@ def test_get_set_const_data():
graph = model.nncf.get_graph()
const_node = graph.get_node_by_name("conv.bias")

assert model.conv.bias.requires_grad

data = get_const_data(const_node, model)
assert not data.requires_grad
assert torch.all(model.conv.bias.data == data)

set_const_data(torch.ones_like(data), const_node, model)
assert torch.all(model.conv.bias.data == torch.ones_like(data))
assert model.conv.bias.requires_grad


@pytest.mark.parametrize(
Expand Down