Skip to content

Commit

Permalink
Add is_subtype argument for metatype register (#2611)
Browse files Browse the repository at this point in the history
### Changes

Add `is_subtype` argument for metatype register

### Reason for changes

Supports registering a nonlinear hierarchy of sub-metatypes, where a subtype shares op names with its parent metatype without triggering the duplicate-registration error.
  • Loading branch information
AlexanderDokuchaev authored Apr 5, 2024
1 parent ec497ce commit 7105e8c
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 34 deletions.
21 changes: 11 additions & 10 deletions nncf/common/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ def __init__(self, name: str):
super().__init__(name)
self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}

def register(self, name: Optional[str] = None) -> Callable[..., Type[OperatorMetatype]]:
def register(self, name: Optional[str] = None, is_subtype: bool = False) -> Callable[..., Type[OperatorMetatype]]:
"""
Decorator for registering operator metatypes.
:param name: The registration name.
:param is_subtype: Whether the decorated metatype is a subtype of another registered operator.
:return: The inner function for registering operator metatypes.
"""
name_ = name
Expand All @@ -100,15 +101,15 @@ def wrap(obj: Type[OperatorMetatype]) -> Type[OperatorMetatype]:
if cls_name is None:
cls_name = obj.__name__
super_register(obj, cls_name)
op_names = obj.get_all_aliases()
for name in op_names:
if name in self._op_name_to_op_meta_dict and not obj.subtype_check(self._op_name_to_op_meta_dict[name]):
raise nncf.InternalError(
"Inconsistent operator metatype registry - single patched "
"op name maps to multiple metatypes!"
)

self._op_name_to_op_meta_dict[name] = obj
if not is_subtype:
op_names = obj.get_all_aliases()
for name in op_names:
if name in self._op_name_to_op_meta_dict:
raise nncf.InternalError(
"Inconsistent operator metatype registry - single patched "
f"op name `{name}` maps to multiple metatypes!"
)
self._op_name_to_op_meta_dict[name] = obj
return obj

return wrap
Expand Down
6 changes: 3 additions & 3 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
bias_port_id: Optional[int] = None


@ONNX_OPERATION_METATYPES.register()
@ONNX_OPERATION_METATYPES.register(is_subtype=True)
class ONNXDepthwiseConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "DepthwiseConvOp"
op_names = ["Conv"]
Expand All @@ -86,7 +86,7 @@ def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
return _is_depthwise_conv(model, node)


@ONNX_OPERATION_METATYPES.register()
@ONNX_OPERATION_METATYPES.register(is_subtype=True)
class ONNXGroupConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "GroupConvOp"
op_names = ["Conv"]
Expand Down Expand Up @@ -420,7 +420,7 @@ class ONNXReciprocalMetatype(ONNXOpMetatype):
hw_config_names = [HWConfigOpName.POWER]


@ONNX_OPERATION_METATYPES.register()
@ONNX_OPERATION_METATYPES.register(is_subtype=True)
class ONNXEmbeddingMetatype(ONNXOpMetatype):
name = "EmbeddingOp"
hw_config_names = [HWConfigOpName.EMBEDDING]
Expand Down
4 changes: 2 additions & 2 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class OVConvolutionBackpropDataMetatype(OVOpMetatype):
output_channel_axis = 1


@OV_OPERATOR_METATYPES.register()
@OV_OPERATOR_METATYPES.register(is_subtype=True)
class OVDepthwiseConvolutionMetatype(OVOpMetatype):
name = "DepthwiseConvolutionOp"
op_names = ["GroupConvolution"]
Expand Down Expand Up @@ -410,7 +410,7 @@ class OVLogicalXorMetatype(OVOpMetatype):
hw_config_names = [HWConfigOpName.LOGICALXOR]


@OV_OPERATOR_METATYPES.register()
@OV_OPERATOR_METATYPES.register(is_subtype=True)
class OVEmbeddingMetatype(OVOpMetatype):
name = "EmbeddingOp"
hw_config_names = [HWConfigOpName.EMBEDDING]
Expand Down
6 changes: 3 additions & 3 deletions nncf/tensorflow/graph/metatypes/keras_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def get_all_aliases(cls) -> List[str]:
return [cls.name]


@KERAS_LAYER_METATYPES.register()
@KERAS_LAYER_METATYPES.register(is_subtype=True)
class TFDepthwiseConv1DSubLayerMetatype(TFLayerWithWeightsMetatype):
name = "DepthwiseConv1D(Conv1DKerasLayer)"
keras_layer_names = ["Conv1D", "Convolution1D"]
Expand All @@ -112,7 +112,7 @@ class TFConv1DLayerMetatype(TFLayerWithWeightsMetatype):
bias_attr_name = "bias"


@KERAS_LAYER_METATYPES.register()
@KERAS_LAYER_METATYPES.register(is_subtype=True)
class TFDepthwiseConv2DSubLayerMetatype(TFLayerWithWeightsMetatype):
name = "DepthwiseConv2D(Conv2DKerasLayer)"
keras_layer_names = ["Conv2D", "Convolution2D"]
Expand All @@ -137,7 +137,7 @@ class TFConv2DLayerMetatype(TFLayerWithWeightsMetatype):
bias_attr_name = "bias"


@KERAS_LAYER_METATYPES.register()
@KERAS_LAYER_METATYPES.register(is_subtype=True)
class TFDepthwiseConv3DSubLayerMetatype(TFLayerWithWeightsMetatype):
name = "DepthwiseConv3D(Conv3DKerasLayer)"
keras_layer_names = ["Conv3D", "Convolution3D"]
Expand Down
32 changes: 16 additions & 16 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class PTNoopMetatype(PTOperatorMetatype):
}


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTDepthwiseConv1dSubtype(PTDepthwiseConvOperatorSubtype):
name = "Conv1DOp"
hw_config_name = [HWConfigOpName.DEPTHWISECONVOLUTION]
Expand All @@ -178,7 +178,7 @@ class PTDepthwiseConv1dSubtype(PTDepthwiseConvOperatorSubtype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleConv1dMetatype(PTModuleOperatorSubtype):
name = "Conv1DOp"
hw_config_names = [HWConfigOpName.CONVOLUTION]
Expand All @@ -202,7 +202,7 @@ class PTConv1dMetatype(PTOperatorMetatype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype):
name = "Conv2DOp"
hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION]
Expand All @@ -213,7 +213,7 @@ class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleConv2dMetatype(PTModuleOperatorSubtype):
name = "Conv2DOp"
hw_config_names = [HWConfigOpName.CONVOLUTION]
Expand All @@ -237,7 +237,7 @@ class PTConv2dMetatype(PTOperatorMetatype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTDepthwiseConv3dSubtype(PTDepthwiseConvOperatorSubtype):
name = "Conv3DOp"
hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION]
Expand All @@ -248,7 +248,7 @@ class PTDepthwiseConv3dSubtype(PTDepthwiseConvOperatorSubtype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleConv3dMetatype(PTModuleOperatorSubtype):
name = "Conv3DOp"
hw_config_names = [HWConfigOpName.CONVOLUTION]
Expand All @@ -272,7 +272,7 @@ class PTConv3dMetatype(PTOperatorMetatype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleConvTranspose1dMetatype(PTModuleOperatorSubtype):
name = "ConvTranspose1DOp"
hw_config_names = [HWConfigOpName.CONVOLUTION]
Expand All @@ -295,7 +295,7 @@ class PTConvTranspose1dMetatype(PTOperatorMetatype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleConvTranspose2dMetatype(PTModuleOperatorSubtype):
name = "ConvTranspose2DOp"
hw_config_names = [HWConfigOpName.CONVOLUTION]
Expand All @@ -318,7 +318,7 @@ class PTConvTranspose2dMetatype(PTOperatorMetatype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleConvTranspose3dMetatype(PTModuleOperatorSubtype):
name = "ConvTranspose3DOp"
hw_config_names = [HWConfigOpName.CONVOLUTION]
Expand All @@ -341,7 +341,7 @@ class PTConvTranspose3dMetatype(PTOperatorMetatype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleDeformConv2dMetatype(PTModuleOperatorSubtype):
name = "DeformConv2dOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["deform_conv2d"]}
Expand All @@ -358,7 +358,7 @@ class PTDeformConv2dMetatype(PTOperatorMetatype):
weight_port_ids = [2]


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleLinearMetatype(PTModuleOperatorSubtype):
name = "LinearOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["linear"]}
Expand Down Expand Up @@ -428,7 +428,7 @@ class PTLeakyRELUMetatype(PTOperatorMetatype):
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleLayerNormMetatype(PTModuleOperatorSubtype):
name = "LayerNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["layer_norm"]}
Expand All @@ -445,7 +445,7 @@ class PTLayerNormMetatype(PTOperatorMetatype):
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleGroupNormMetatype(PTModuleOperatorSubtype):
name = "GroupNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["group_norm"]}
Expand Down Expand Up @@ -630,7 +630,7 @@ class PTThresholdMetatype(PTOperatorMetatype):
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["threshold"]}


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleBatchNormMetatype(PTModuleOperatorSubtype):
name = "BatchNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]}
Expand Down Expand Up @@ -821,7 +821,7 @@ class PTExpandAsMetatype(PTOperatorMetatype):
module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["expand_as"]}


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleEmbeddingMetatype(PTModuleOperatorSubtype):
name = "EmbeddingOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["embedding"]}
Expand All @@ -838,7 +838,7 @@ class PTEmbeddingMetatype(PTOperatorMetatype):
weight_port_ids = [1]


@PT_OPERATOR_METATYPES.register()
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleEmbeddingBagMetatype(PTModuleOperatorSubtype):
name = "EmbeddingBagOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["embedding_bag"]}
Expand Down

0 comments on commit 7105e8c

Please sign in to comment.