
Commit

fix
AlexanderDokuchaev committed Apr 9, 2024
1 parent b2f9080 commit 2c95b0f
Showing 2 changed files with 9 additions and 0 deletions.
4 changes: 4 additions & 0 deletions nncf/torch/graph/operator_metatypes.py
@@ -487,6 +487,7 @@ class PTModuleLayerNormMetatype(PTModuleOperatorSubtype):
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["layer_norm"]}
     hw_config_names = [HWConfigOpName.MVN]
     num_expected_input_edges = 1
+    weight_port_ids = [2]


 @PT_OPERATOR_METATYPES.register()
@@ -496,13 +497,15 @@ class PTLayerNormMetatype(PTOperatorMetatype):
     hw_config_names = [HWConfigOpName.MVN]
     subtypes = [PTModuleLayerNormMetatype]
     num_expected_input_edges = 1
+    weight_port_ids = [2]


 @PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleGroupNormMetatype(PTModuleOperatorSubtype):
     name = "GroupNormOp"
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["group_norm"]}
     hw_config_names = [HWConfigOpName.MVN]
+    weight_port_ids = [2]


 @PT_OPERATOR_METATYPES.register()
@@ -511,6 +514,7 @@ class PTGroupNormMetatype(PTOperatorMetatype):
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["group_norm"]}
     hw_config_names = [HWConfigOpName.MVN]
     subtypes = [PTModuleGroupNormMetatype]
+    weight_port_ids = [2]


 @PT_OPERATOR_METATYPES.register()
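The added weight_port_ids = [2] entries record which input port of the traced functional call carries the learnable weight. For both torch.nn.functional.layer_norm and torch.nn.functional.group_norm the weight is the third positional argument, i.e. index 2. A minimal standalone sketch (plain PyTorch, not NNCF code) that verifies this mapping:

# Sketch: confirm that "weight" sits at positional index 2 for the functional
# ops covered by these metatypes (layer_norm and group_norm).
import inspect

import torch.nn.functional as F

for fn in (F.layer_norm, F.group_norm):
    params = list(inspect.signature(fn).parameters)
    # Expected output: layer_norm 2 / group_norm 2, matching weight_port_ids = [2].
    print(fn.__name__, params.index("weight"))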
5 changes: 5 additions & 0 deletions tests/torch/test_layer_attributes.py
@@ -25,6 +25,7 @@
 from nncf.common.graph.layer_attributes import PermuteLayerAttributes
 from nncf.common.graph.layer_attributes import ReshapeLayerAttributes
 from nncf.common.graph.layer_attributes import TransposeLayerAttributes
+from nncf.common.graph.layer_attributes import WeightedLayerAttributes
 from nncf.common.graph.operator_metatypes import OperatorMetatype
 from nncf.torch import wrap_model
 from nncf.torch.dynamic_graph.graph_tracer import create_dummy_forward_fn
@@ -552,6 +553,10 @@ def test_can_set_valid_layer_attributes_wrap_model(desc: LayerAttributesTestDesc
     ]
     assert ref_values == actual_values

+    if isinstance(desc.layer_attributes, WeightedLayerAttributes):
+        assert hasattr(desc.metatype_cls, "weight_port_ids")
+        assert len(desc.metatype_cls.weight_port_ids) > 0
+

 @pytest.mark.parametrize(
     "signature, args, kwargs",
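The new assertion ties the two files together: whenever a test descriptor's reference layer attributes are WeightedLayerAttributes, its metatype class must declare at least one weight port. A quick standalone check of the metatypes touched above (a sketch, assuming an NNCF build that includes this commit):

# Sketch: the normalization metatypes changed in this commit should now expose
# the port of their weight input; class names and module path come from the diff.
from nncf.torch.graph.operator_metatypes import PTGroupNormMetatype
from nncf.torch.graph.operator_metatypes import PTLayerNormMetatype

for metatype_cls in (PTLayerNormMetatype, PTGroupNormMetatype):
    assert metatype_cls.weight_port_ids == [2]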
