From 903f43f394aa053f643bf72cead0d53d69ec3af4 Mon Sep 17 00:00:00 2001
From: Mason Ma
Date: Fri, 3 Feb 2023 14:31:50 +0800
Subject: [PATCH] feat: InstanceNormalization

---
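Notes: InstanceNormalization is converted along two paths. When scale and
bias are graph initializers, the node maps to a stock
nn.InstanceNorm1d/2d/3d module (chosen by the input's spatial rank) with
the initializers baked in as affine parameters; when they arrive as
runtime inputs instead, a thin OnnxInstanceNorm wrapper forwards them to
F.instance_norm. A minimal usage sketch of the converted result — the
model file name and input shape below are illustrative assumptions, not
part of this patch:

    import onnx
    import torch
    from onnx2torch import convert

    # Hypothetical model containing an InstanceNormalization node.
    onnx_model = onnx.load('instance_norm_model.onnx')
    torch_model = convert(onnx_model)

    x = torch.randn(2, 3, 16, 16)  # (N, C, H, W)
    with torch.no_grad():
        y = torch_model(x)
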
 onnx2torch/node_converters/__init__.py      |  1 +
 onnx2torch/node_converters/instance_norm.py | 88 ++++++++++++++++++++++
 operators.md                                |  2 +-
 tests/node_converters/instance_norm_test.py | 47 ++++++++++++
 4 files changed, 137 insertions(+), 1 deletion(-)
 create mode 100644 onnx2torch/node_converters/instance_norm.py
 create mode 100644 tests/node_converters/instance_norm_test.py

diff --git a/onnx2torch/node_converters/__init__.py b/onnx2torch/node_converters/__init__.py
index 2abcda92..3520357c 100644
--- a/onnx2torch/node_converters/__init__.py
+++ b/onnx2torch/node_converters/__init__.py
@@ -19,6 +19,7 @@
 from onnx2torch.node_converters.gemm import *
 from onnx2torch.node_converters.global_average_pool import *
 from onnx2torch.node_converters.identity import *
+from onnx2torch.node_converters.instance_norm import *
 from onnx2torch.node_converters.logical import *
 from onnx2torch.node_converters.lrn import *
 from onnx2torch.node_converters.matmul import *
diff --git a/onnx2torch/node_converters/instance_norm.py b/onnx2torch/node_converters/instance_norm.py
new file mode 100644
index 00000000..e0acb996
--- /dev/null
+++ b/onnx2torch/node_converters/instance_norm.py
@@ -0,0 +1,88 @@
+__all__ = [
+    'OnnxInstanceNorm',
+]
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from onnx2torch.node_converters.registry import add_converter
+from onnx2torch.onnx_graph import OnnxGraph
+from onnx2torch.onnx_node import OnnxNode
+from onnx2torch.utils.common import OnnxMapping
+from onnx2torch.utils.common import OnnxToTorchModule
+from onnx2torch.utils.common import OperationConverterResult
+from onnx2torch.utils.common import get_shape_from_value_info
+from onnx2torch.utils.common import onnx_mapping_from_node
+
+_IN_CLASS_FROM_SPATIAL_RANK = {
+    0: nn.InstanceNorm1d,
+    1: nn.InstanceNorm1d,
+    2: nn.InstanceNorm2d,
+    3: nn.InstanceNorm3d,
+}
+
+
+class OnnxInstanceNorm(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-docstring
+    def __init__(self, momentum: float, epsilon: float):
+        super().__init__()
+        self.momentum = momentum
+        self.epsilon = epsilon
+
+    def forward(  # pylint: disable=missing-function-docstring
+        self,
+        input_data: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+    ) -> torch.Tensor:
+        return F.instance_norm(
+            input=input_data,
+            running_mean=None,
+            running_var=None,
+            weight=weight,
+            bias=bias,
+            use_input_stats=True,
+            momentum=self.momentum,
+            eps=self.epsilon,
+        )
+
+
+@add_converter(operation_type='InstanceNormalization', version=1)
+@add_converter(operation_type='InstanceNormalization', version=6)
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    node_attributes = node.attributes
+    epsilon = node_attributes.get('epsilon', 1e-5)
+    momentum = 0.1
+
+    if all(value_name in graph.initializers for value_name in node.input_values[1:]):
+        input_value_info = graph.value_info[node.input_values[0]]
+        input_shape = get_shape_from_value_info(input_value_info)
+        spatial_rank = len(input_shape) - 2
+        try:
+            in_class = _IN_CLASS_FROM_SPATIAL_RANK[spatial_rank]
+        except KeyError as exc:
+            raise NotImplementedError(
+                f'InstanceNorm operation with spatial rank == {spatial_rank} is not implemented'
+            ) from exc
+
+        scale_value_name = node.input_values[1]
+        bias_value_name = node.input_values[2]
+
+        scale = graph.initializers[scale_value_name].to_torch()
+        torch_module = in_class(
+            num_features=scale.size()[0],
+            eps=epsilon,
+            momentum=momentum,
+            affine=True,
+            track_running_stats=False,
+        )
+        with torch.no_grad():
+            torch_module.weight.data = scale
+            torch_module.bias.data = graph.initializers[bias_value_name].to_torch()
+
+        onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values)
+    else:
+        torch_module = OnnxInstanceNorm(momentum=momentum, epsilon=epsilon)
+        onnx_mapping = onnx_mapping_from_node(node)
+
+    return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping)
diff --git a/operators.md b/operators.md
index 6bda700f..e12082e7 100644
--- a/operators.md
+++ b/operators.md
@@ -60,7 +60,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
 | Hardmax               | N |  |
 | Identity              | Y |  |
 | If                    | N |  |
-| InstanceNormalization | N |  |
+| InstanceNormalization | Y |  |
 | IsInf                 | N |  |
 | IsNaN                 | N |  |
 | LRN                   | Y |  |
diff --git a/tests/node_converters/instance_norm_test.py b/tests/node_converters/instance_norm_test.py
new file mode 100644
index 00000000..71bd6fc4
--- /dev/null
+++ b/tests/node_converters/instance_norm_test.py
@@ -0,0 +1,47 @@
+from typing import List
+
+import numpy as np
+import onnx
+import pytest
+
+from tests.utils.common import check_onnx_model
+from tests.utils.common import make_model_from_nodes
+
+
+@pytest.mark.parametrize('parameters_as_inputs', (True, False))
+@pytest.mark.parametrize(
+    'input_shape',
+    (
+        # 1d
+        [2, 3, 16],
+        [2, 1, 7],
+        # 2d
+        [2, 3, 16, 16],
+        [2, 1, 7, 16],
+        # 3d
+        [2, 3, 16, 16, 16],
+        [2, 1, 16, 7, 16],
+    ),
+)
+def test_instance_norm(  # pylint: disable=missing-function-docstring
+    input_shape: List[int],
+    parameters_as_inputs: bool,
+) -> None:
+    num_features = input_shape[1]
+    x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32)
+    scale = np.random.uniform(low=0.0, high=1.0, size=num_features).astype(np.float32)
+    bias = np.random.uniform(low=-1.0, high=1.0, size=num_features).astype(np.float32)
+
+    inputs = {'input': x}
+    parameters = {'scale': scale, 'bias': bias}
+    initializers = {}
+
+    if parameters_as_inputs:
+        inputs.update(parameters)
+    else:
+        initializers.update(parameters)
+
+    node = onnx.helper.make_node(op_type='InstanceNormalization', inputs=['input', 'scale', 'bias'], outputs=['y'])
+
+    model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs)
+    check_onnx_model(onnx_model=model, onnx_inputs=inputs, atol_onnx_torch=1e-6, atol_torch_cpu_cuda=1e-6)
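
Reviewer note: the `parameters_as_inputs=True` cases exercise the eager
OnnxInstanceNorm path. A minimal sketch of the numerical equivalence those
cases rely on, per ONNX InstanceNormalization semantics — the shapes and
the 1e-5 epsilon below are illustrative, not taken from the patch:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 16, 16)
    weight = torch.rand(3)   # ONNX 'scale' input
    bias = torch.randn(3)    # ONNX 'B' input

    # Normalize every (sample, channel) slice over its spatial dims
    # using the biased variance, then apply the affine parameters.
    mean = x.mean(dim=(2, 3), keepdim=True)
    var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
    expected = (x - mean) / torch.sqrt(var + 1e-5)
    expected = expected * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)

    actual = F.instance_norm(x, weight=weight, bias=bias, use_input_stats=True, eps=1e-5)
    assert torch.allclose(actual, expected, atol=1e-6)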