Skip to content

Commit

Permalink
Fix: deduplicate update_fused_bias into model_graph_manager and correct get_fused_bias_value for fused nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Apr 10, 2024
1 parent 4f6ea87 commit a09e21b
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 34 deletions.
41 changes: 39 additions & 2 deletions nncf/torch/model_graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,47 @@ def get_fused_bias_value(node: NNCFNode, model: NNCFNetwork) -> Optional[torch.T
:return: The bias value that is applied to the output tensor of the node's operation.
"""
nncf_graph = model.nncf.get_graph()

fused_node = get_potential_fused_node(node.node_name, nncf_graph)
target_node_name = fused_node.node_name if fused_node else node.node_name

bias = get_const_data_on_port(node, node.metatype.bias_port_id, model)

if fused_node is None:
return bias

fused_bias = get_const_data_on_port(fused_node, fused_node.metatype.bias_port_id, model)
fused_weight = get_const_data_on_port(fused_node, fused_node.metatype.weight_port_ids[0], model)
if bias is None:
return fused_bias

return bias * fused_weight + fused_bias


def update_fused_bias(target_node_name: str, new_bias: torch.Tensor, model: NNCFNetwork) -> None:
    """
    Update bias for target module or potential fused module.

    If the target node is fused with a following node (e.g. Conv + BatchNorm),
    the new bias is written into the fused node's bias so that the *effective*
    bias of the pair equals `new_bias`; otherwise it is written directly to the
    target node.

    :param target_node_name: The target node name.
    :param new_bias: New bias value.
    :param model: The model.
    """
    nncf_graph = model.nncf.get_graph()
    target_node = nncf_graph.get_node_by_name(target_node_name)
    # BUGFIX: a stray early `return get_const_data_on_port(...)` previously sat here,
    # turning this mutator into a getter and making everything below unreachable.
    fused_node = get_potential_fused_node(target_node_name, nncf_graph)
    if fused_node is None:
        # No fusing partner: the bias lives directly on the target operation.
        set_const_data_to_port_id(new_bias, target_node, target_node.metatype.bias_port_id, model)
        return

    target_bias_node = get_const_node(target_node, target_node.metatype.bias_port_id, nncf_graph)
    fused_bias_node = get_const_node(fused_node, fused_node.metatype.bias_port_id, nncf_graph)
    fused_weight_node = get_const_node(fused_node, fused_node.metatype.weight_port_ids[0], nncf_graph)

    if target_bias_node is None:
        # Target op has no own bias: the fused op's bias is the effective bias.
        set_const_data(new_bias, fused_bias_node, model)
        return

    # Effective bias is target_bias * fused_weight + fused_bias; solve for the
    # fused bias so that the combined result equals `new_bias`.
    new_bias = new_bias - get_const_data(target_bias_node, model) * get_const_data(fused_weight_node, model)
    set_const_data(new_bias, fused_bias_node, model)


def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[int]:
Expand Down
33 changes: 1 addition & 32 deletions nncf/torch/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.model_graph_manager import get_const_data
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_potential_fused_node
from nncf.torch.model_graph_manager import set_const_data
from nncf.torch.model_graph_manager import set_const_data_to_port_id
from nncf.torch.model_graph_manager import update_fused_bias
from nncf.torch.module_operations import UpdateWeight
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.nncf_network import PTInsertionPoint
Expand Down Expand Up @@ -245,33 +241,6 @@ def _apply_weights_update_transformations(
return model


def update_fused_bias(target_node_name: str, new_bias: Tensor, model: NNCFNetwork) -> None:
    """
    Set a new bias for the target module, or for the module fused with it.

    :param target_node_name: Name of the target node in the NNCF graph.
    :param new_bias: New bias value.
    :param model: The model.
    """
    graph = model.nncf.get_graph()
    node = graph.get_node_by_name(target_node_name)
    fused = get_potential_fused_node(target_node_name, graph)

    if fused is None:
        # No fusing partner: write the bias straight onto the target op.
        set_const_data_to_port_id(new_bias, node, node.metatype.bias_port_id, model)
        return

    own_bias_node = get_const_node(node, node.metatype.bias_port_id, graph)
    fused_bias_node = get_const_node(fused, fused.metatype.bias_port_id, graph)
    fused_weight_node = get_const_node(fused, fused.metatype.weight_port_ids[0], graph)

    if own_bias_node is not None:
        # The fused op scales the target's own bias by its weight; compensate so
        # the combined (target_bias * fused_weight + fused_bias) equals new_bias.
        new_bias = new_bias - get_const_data(own_bias_node, model) * get_const_data(fused_weight_node, model)
    set_const_data(new_bias, fused_bias_node, model)


def update_parameter(target_node_name: str, parameter_name: str, new_value: Tensor, model: NNCFNetwork) -> None:
"""
Update parameter for target module.
Expand Down
49 changes: 49 additions & 0 deletions tests/torch/test_model_graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@
from nncf.torch.model_graph_manager import get_const_data_on_port
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_fake_quantizer
from nncf.torch.model_graph_manager import get_fused_bias_value
from nncf.torch.model_graph_manager import get_module_by_name
from nncf.torch.model_graph_manager import get_potential_fused_node
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
from nncf.torch.model_graph_manager import is_node_with_fused_bias
from nncf.torch.model_graph_manager import is_quantized_weights
from nncf.torch.model_graph_manager import set_const_data
from nncf.torch.model_graph_manager import split_const_name
from nncf.torch.model_graph_manager import update_fused_bias
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.model_transformer import PTTransformationLayout
from nncf.torch.nncf_network import NNCFNetwork
Expand Down Expand Up @@ -313,3 +315,50 @@ def test_is_quantized_weights():
q_graph = q_model.nncf.get_graph()
q_node = q_graph.get_node_by_name(node_name)
assert is_quantized_weights(q_node, q_graph)


@pytest.mark.parametrize(
    "model_cls, ref",
    (
        (helpers.ConvTestModel, [0.1000, 1.0000]),  # conv.bias
        (helpers.ConvBNTestModel, [0.1000, 1.0000]),  # bn.bias
        (helpers.ConvBiasBNTestModel, [0.1600, 3.6000]),  # conv.bias*bn.weight + bn.bias
    ),
)
def test_get_fused_bias_value(model_cls, ref):
    """The fused bias reported for the conv node matches the per-model reference."""
    wrapped = wrap_model(model_cls(), torch.ones(model_cls.INPUT_SIZE), trace_parameters=True)
    conv_node = wrapped.nncf.get_graph().get_nodes_by_types("conv2d")[0]

    assert torch.allclose(get_fused_bias_value(conv_node, wrapped), torch.tensor(ref))


@pytest.mark.parametrize(
    "model_cls",
    (
        (helpers.ConvTestModel),  # conv.bias
        (helpers.ConvBNTestModel),  # bn.bias
        (helpers.ConvBiasBNTestModel),  # conv.bias*bn.weight + bn.bias
    ),
)
def test_update_fused_bias(model_cls):
    """update_fused_bias sets the effective bias and stores it on the correct module."""
    wrapped = wrap_model(model_cls(), torch.ones(model_cls.INPUT_SIZE), trace_parameters=True)
    new_bias = torch.tensor([-1.0, -1.0])
    conv_node = wrapped.nncf.get_graph().get_nodes_by_types("conv2d")[0]

    update_fused_bias(conv_node.node_name, new_bias, wrapped)
    assert torch.all(torch.isclose(get_fused_bias_value(conv_node, wrapped), new_bias))

    if model_cls is helpers.ConvTestModel:
        # Plain conv: the bias is written directly to the conv module.
        assert torch.all(torch.isclose(wrapped.conv.bias, new_bias))
    elif model_cls is helpers.ConvBNTestModel:
        # Conv without bias: the BN bias carries the whole value.
        assert wrapped.conv.bias is None
        assert torch.all(torch.isclose(wrapped.bn.bias, new_bias))
    elif model_cls is helpers.ConvBiasBNTestModel:
        # Conv with bias + BN: conv.bias stays put, BN bias is compensated so the
        # effective bias (conv.bias * bn.weight + bn.bias) equals new_bias.
        assert torch.all(torch.isclose(wrapped.conv.bias, torch.tensor([0.3000, 1.3000])))
        assert torch.all(torch.isclose(wrapped.bn.bias, torch.tensor([-1.0600, -3.6000])))
        assert torch.all(torch.isclose(wrapped.conv.bias * wrapped.bn.weight + wrapped.bn.bias, new_bias))

0 comments on commit a09e21b

Please sign in to comment.