diff --git a/nncf/common/graph/operator_metatypes.py b/nncf/common/graph/operator_metatypes.py
index 27fc60213a4..100f510428a 100644
--- a/nncf/common/graph/operator_metatypes.py
+++ b/nncf/common/graph/operator_metatypes.py
@@ -79,11 +79,12 @@ def __init__(self, name: str):
         super().__init__(name)
         self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}
 
-    def register(self, name: Optional[str] = None) -> Callable[..., Type[OperatorMetatype]]:
+    def register(self, name: Optional[str] = None, is_subtype: bool = False) -> Callable[..., Type[OperatorMetatype]]:
         """
         Decorator for registering operator metatypes.
 
         :param name: The registration name.
+        :param is_subtype: Whether the decorated metatype is a subtype of another registered operator.
         :return: The inner function for registering operator metatypes.
         """
         name_ = name
@@ -100,15 +101,15 @@ def wrap(obj: Type[OperatorMetatype]) -> Type[OperatorMetatype]:
             if cls_name is None:
                 cls_name = obj.__name__
             super_register(obj, cls_name)
-            op_names = obj.get_all_aliases()
-            for name in op_names:
-                if name in self._op_name_to_op_meta_dict and not obj.subtype_check(self._op_name_to_op_meta_dict[name]):
-                    raise nncf.InternalError(
-                        "Inconsistent operator metatype registry - single patched "
-                        "op name maps to multiple metatypes!"
-                    )
-
-                self._op_name_to_op_meta_dict[name] = obj
+            if not is_subtype:
+                op_names = obj.get_all_aliases()
+                for name in op_names:
+                    if name in self._op_name_to_op_meta_dict:
+                        raise nncf.InternalError(
+                            "Inconsistent operator metatype registry - single patched "
+                            f"op name `{name}` maps to multiple metatypes!"
+                        )
+                    self._op_name_to_op_meta_dict[name] = obj
             return obj
 
         return wrap
diff --git a/nncf/onnx/graph/metatypes/onnx_metatypes.py b/nncf/onnx/graph/metatypes/onnx_metatypes.py
index d7e984b8d44..2105f31a216 100644
--- a/nncf/onnx/graph/metatypes/onnx_metatypes.py
+++ b/nncf/onnx/graph/metatypes/onnx_metatypes.py
@@ -71,7 +71,7 @@ class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
     bias_port_id: Optional[int] = None
 
 
-@ONNX_OPERATION_METATYPES.register()
+@ONNX_OPERATION_METATYPES.register(is_subtype=True)
 class ONNXDepthwiseConvolutionMetatype(ONNXOpWithWeightsMetatype):
     name = "DepthwiseConvOp"
     op_names = ["Conv"]
@@ -86,7 +86,7 @@ def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
         return _is_depthwise_conv(model, node)
 
 
-@ONNX_OPERATION_METATYPES.register()
+@ONNX_OPERATION_METATYPES.register(is_subtype=True)
 class ONNXGroupConvolutionMetatype(ONNXOpWithWeightsMetatype):
     name = "GroupConvOp"
     op_names = ["Conv"]
@@ -420,7 +420,7 @@ class ONNXReciprocalMetatype(ONNXOpMetatype):
     hw_config_names = [HWConfigOpName.POWER]
 
 
-@ONNX_OPERATION_METATYPES.register()
+@ONNX_OPERATION_METATYPES.register(is_subtype=True)
 class ONNXEmbeddingMetatype(ONNXOpMetatype):
     name = "EmbeddingOp"
     hw_config_names = [HWConfigOpName.EMBEDDING]
diff --git a/nncf/openvino/graph/metatypes/openvino_metatypes.py b/nncf/openvino/graph/metatypes/openvino_metatypes.py
index 3600ef7e71b..eb806ddfffa 100644
--- a/nncf/openvino/graph/metatypes/openvino_metatypes.py
+++ b/nncf/openvino/graph/metatypes/openvino_metatypes.py
@@ -70,7 +70,7 @@ class OVConvolutionBackpropDataMetatype(OVOpMetatype):
     output_channel_axis = 1
 
 
-@OV_OPERATOR_METATYPES.register()
+@OV_OPERATOR_METATYPES.register(is_subtype=True)
 class OVDepthwiseConvolutionMetatype(OVOpMetatype):
     name = "DepthwiseConvolutionOp"
     op_names = ["GroupConvolution"]
@@ -410,7 +410,7 @@ class OVLogicalXorMetatype(OVOpMetatype):
     hw_config_names = [HWConfigOpName.LOGICALXOR]
 
 
-@OV_OPERATOR_METATYPES.register()
+@OV_OPERATOR_METATYPES.register(is_subtype=True)
 class OVEmbeddingMetatype(OVOpMetatype):
     name = "EmbeddingOp"
     hw_config_names = [HWConfigOpName.EMBEDDING]
diff --git a/nncf/tensorflow/graph/metatypes/keras_layers.py b/nncf/tensorflow/graph/metatypes/keras_layers.py
index ce65fc63298..8f6636783bb 100644
--- a/nncf/tensorflow/graph/metatypes/keras_layers.py
+++ b/nncf/tensorflow/graph/metatypes/keras_layers.py
@@ -87,7 +87,7 @@ def get_all_aliases(cls) -> List[str]:
         return [cls.name]
 
 
-@KERAS_LAYER_METATYPES.register()
+@KERAS_LAYER_METATYPES.register(is_subtype=True)
 class TFDepthwiseConv1DSubLayerMetatype(TFLayerWithWeightsMetatype):
     name = "DepthwiseConv1D(Conv1DKerasLayer)"
     keras_layer_names = ["Conv1D", "Convolution1D"]
@@ -112,7 +112,7 @@ class TFConv1DLayerMetatype(TFLayerWithWeightsMetatype):
     bias_attr_name = "bias"
 
 
-@KERAS_LAYER_METATYPES.register()
+@KERAS_LAYER_METATYPES.register(is_subtype=True)
 class TFDepthwiseConv2DSubLayerMetatype(TFLayerWithWeightsMetatype):
     name = "DepthwiseConv2D(Conv2DKerasLayer)"
     keras_layer_names = ["Conv2D", "Convolution2D"]
@@ -137,7 +137,7 @@ class TFConv2DLayerMetatype(TFLayerWithWeightsMetatype):
     bias_attr_name = "bias"
 
 
-@KERAS_LAYER_METATYPES.register()
+@KERAS_LAYER_METATYPES.register(is_subtype=True)
 class TFDepthwiseConv3DSubLayerMetatype(TFLayerWithWeightsMetatype):
     name = "DepthwiseConv3D(Conv3DKerasLayer)"
     keras_layer_names = ["Conv3D", "Convolution3D"]
diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py
index 0f2b9f9fdb1..b99bab5ebab 100644
--- a/nncf/torch/graph/operator_metatypes.py
+++ b/nncf/torch/graph/operator_metatypes.py
@@ -167,7 +167,7 @@ class PTNoopMetatype(PTOperatorMetatype):
     }
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTDepthwiseConv1dSubtype(PTDepthwiseConvOperatorSubtype):
     name = "Conv1DOp"
     hw_config_name = [HWConfigOpName.DEPTHWISECONVOLUTION]
@@ -178,7 +178,7 @@ class PTDepthwiseConv1dSubtype(PTDepthwiseConvOperatorSubtype):
     bias_port_id = 2
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleConv1dMetatype(PTModuleOperatorSubtype):
     name = "Conv1DOp"
     hw_config_names = [HWConfigOpName.CONVOLUTION]
@@ -202,7 +202,7 @@ class PTConv1dMetatype(PTOperatorMetatype):
     bias_port_id = 2
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype):
     name = "Conv2DOp"
     hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION]
@@ -213,7 +213,7 @@ class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype):
     bias_port_id = 2
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleConv2dMetatype(PTModuleOperatorSubtype):
     name = "Conv2DOp"
     hw_config_names = [HWConfigOpName.CONVOLUTION]
@@ -237,7 +237,7 @@ class PTConv2dMetatype(PTOperatorMetatype):
     bias_port_id = 2
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTDepthwiseConv3dSubtype(PTDepthwiseConvOperatorSubtype):
     name = "Conv3DOp"
     hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION]
@@ -248,7 +248,7 @@ class PTDepthwiseConv3dSubtype(PTDepthwiseConvOperatorSubtype):
     bias_port_id = 2
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleConv3dMetatype(PTModuleOperatorSubtype):
     name = "Conv3DOp"
     hw_config_names = [HWConfigOpName.CONVOLUTION]
@@ -272,7 +272,7 @@ class PTConv3dMetatype(PTOperatorMetatype):
     bias_port_id = 2
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleConvTranspose1dMetatype(PTModuleOperatorSubtype):
     name = "ConvTranspose1DOp"
    hw_config_names = [HWConfigOpName.CONVOLUTION]
@@ -295,7 +295,7 @@ class PTConvTranspose1dMetatype(PTOperatorMetatype):
     bias_port_id = 2
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleConvTranspose2dMetatype(PTModuleOperatorSubtype):
     name = "ConvTranspose2DOp"
     hw_config_names = [HWConfigOpName.CONVOLUTION]
@@ -318,7 +318,7 @@ class PTConvTranspose2dMetatype(PTOperatorMetatype):
     bias_port_id = 2
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleConvTranspose3dMetatype(PTModuleOperatorSubtype):
     name = "ConvTranspose3DOp"
     hw_config_names = [HWConfigOpName.CONVOLUTION]
@@ -341,7 +341,7 @@ class PTConvTranspose3dMetatype(PTOperatorMetatype):
     bias_port_id = 2
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleDeformConv2dMetatype(PTModuleOperatorSubtype):
     name = "DeformConv2dOp"
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["deform_conv2d"]}
@@ -358,7 +358,7 @@ class PTDeformConv2dMetatype(PTOperatorMetatype):
     weight_port_ids = [2]
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleLinearMetatype(PTModuleOperatorSubtype):
     name = "LinearOp"
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["linear"]}
@@ -428,7 +428,7 @@ class PTLeakyRELUMetatype(PTOperatorMetatype):
     num_expected_input_edges = 1
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleLayerNormMetatype(PTModuleOperatorSubtype):
     name = "LayerNormOp"
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["layer_norm"]}
@@ -445,7 +445,7 @@ class PTLayerNormMetatype(PTOperatorMetatype):
     num_expected_input_edges = 1
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleGroupNormMetatype(PTModuleOperatorSubtype):
     name = "GroupNormOp"
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["group_norm"]}
@@ -630,7 +630,7 @@ class PTThresholdMetatype(PTOperatorMetatype):
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["threshold"]}
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleBatchNormMetatype(PTModuleOperatorSubtype):
     name = "BatchNormOp"
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]}
@@ -821,7 +821,7 @@ class PTExpandAsMetatype(PTOperatorMetatype):
     module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["expand_as"]}
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleEmbeddingMetatype(PTModuleOperatorSubtype):
     name = "EmbeddingOp"
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["embedding"]}
@@ -838,7 +838,7 @@ class PTEmbeddingMetatype(PTOperatorMetatype):
     weight_port_ids = [1]
 
 
-@PT_OPERATOR_METATYPES.register()
+@PT_OPERATOR_METATYPES.register(is_subtype=True)
 class PTModuleEmbeddingBagMetatype(PTModuleOperatorSubtype):
    name = "EmbeddingBagOp"
     module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["embedding_bag"]}
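
Illustrative note (not part of the patch): the sketch below is a minimal, self-contained approximation of the registration behavior this change introduces. ToyRegistry, ToyMetatype, ConvMetatype, and DepthwiseConvMetatype are hypothetical names, not NNCF's real classes, and RuntimeError stands in for nncf.InternalError. The point it demonstrates: metatypes registered with is_subtype=True no longer claim their op-name aliases, so only the non-subtype ("parent") metatype owns an alias such as "Conv", and the duplicate-alias check applies to non-subtype registrations only.

    from typing import Callable, Dict, List, Optional, Type


    class ToyMetatype:
        name: str = ""
        op_names: List[str] = []

        @classmethod
        def get_all_aliases(cls) -> List[str]:
            return list(cls.op_names)


    class ToyRegistry:
        """Mimics the patched register() logic: only non-subtype metatypes claim aliases."""

        def __init__(self) -> None:
            self._op_name_to_op_meta_dict: Dict[str, Type[ToyMetatype]] = {}

        def register(self, name: Optional[str] = None, is_subtype: bool = False) -> Callable[..., Type[ToyMetatype]]:
            def wrap(obj: Type[ToyMetatype]) -> Type[ToyMetatype]:
                if not is_subtype:
                    for alias in obj.get_all_aliases():
                        # A duplicate alias among non-subtype metatypes is a registry inconsistency.
                        if alias in self._op_name_to_op_meta_dict:
                            raise RuntimeError(f"op name `{alias}` maps to multiple metatypes!")
                        self._op_name_to_op_meta_dict[alias] = obj
                return obj

            return wrap


    REGISTRY = ToyRegistry()


    @REGISTRY.register()
    class ConvMetatype(ToyMetatype):
        name = "ConvOp"
        op_names = ["Conv"]


    # The subtype shares the "Conv" alias but does not touch the alias map,
    # so registering it raises no "multiple metatypes" error.
    @REGISTRY.register(is_subtype=True)
    class DepthwiseConvMetatype(ConvMetatype):
        name = "DepthwiseConvOp"
        op_names = ["Conv"]


    assert REGISTRY._op_name_to_op_meta_dict["Conv"] is ConvMetatype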