From dd2430a0e4eb4c9724d001792fba1aec7769c8f1 Mon Sep 17 00:00:00 2001
From: 11happy
Date: Sun, 5 Jan 2025 00:42:53 +0530
Subject: [PATCH 1/5] implement replicate{1,2,3} pad

Signed-off-by: 11happy
---
 src/frontends/pytorch/src/op/pad.cpp        |  8 ++++
 src/frontends/pytorch/src/op_table.cpp      |  4 ++
 tests/layer_tests/pytorch_tests/test_pad.py | 45 ++++++++++++++++++++-
 3 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/src/frontends/pytorch/src/op/pad.cpp b/src/frontends/pytorch/src/op/pad.cpp
index c8f35a0afab71a..28fb882379ffe0 100644
--- a/src/frontends/pytorch/src/op/pad.cpp
+++ b/src/frontends/pytorch/src/op/pad.cpp
@@ -134,6 +134,14 @@ OutputVector translate_reflection_pad_nd_fx(const NodeContext& context) {
     return translate_pad_common(context, data, paddings, pad_value, "reflect");
 }
 
+OutputVector translate_replicate_pad_nd_fx{const NodeContext & context} {
+    num_inputs_check(context, 2, 2);
+    auto data = context.get_input(0);
+    auto paddings = context.const_input<std::vector<int64_t>>(1);
+    Output<Node> pad_value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
+    return translate_pad_common(context, data, paddings, pad_value, "replicate");
+}
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index a73c13814d7663..6be1780c405264 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -297,6 +297,7 @@ OP_CONVERTER(translate_new_zeros_fx);
 OP_CONVERTER(translate_ones_fx);
 OP_CONVERTER(translate_ones_like_fx);
 OP_CONVERTER(translate_reflection_pad_nd_fx);
+OP_CONVERTER(translate_replicate_pad_nd_fx);
 OP_CONVERTER(translate_reshape_fx);
 OP_CONVERTER(translate_rsub_fx);
 OP_CONVERTER(translate_scalar_tensor_fx);
@@ -930,6 +931,9 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
     {"aten.reflection_pad1d.default", op::translate_reflection_pad_nd_fx},
     {"aten.reflection_pad2d.default", op::translate_reflection_pad_nd_fx},
     {"aten.reflection_pad3d.default", op::translate_reflection_pad_nd_fx},
+    {"aten.replicate_pad1d.default", op::translate_replicate_pad_nd_fx},
+    {"aten.replicate_pad2d.default", op::translate_replicate_pad_nd_fx},
+    {"aten.replicate_pad3d.default", op::translate_replicate_pad_nd_fx},
     {"aten.relu.default", op::translate_1to1_match_1_inputs<v0::Relu>},
     {"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<v0::Relu>>},
     {"aten.repeat.default", op::translate_1to1_match_2_inputs<v0::Tile>},
diff --git a/tests/layer_tests/pytorch_tests/test_pad.py b/tests/layer_tests/pytorch_tests/test_pad.py
index 92bf397f65999b..bf42658ba53eed 100644
--- a/tests/layer_tests/pytorch_tests/test_pad.py
+++ b/tests/layer_tests/pytorch_tests/test_pad.py
@@ -219,9 +219,9 @@ def __init__(self, pads):
         if ndim == 1:
             self.pad = torch.nn.ReflectionPad1d(pads)
         elif ndim == 2:
-            self.pad = torch.nn.ReflectionPad1d(pads)
+            self.pad = torch.nn.ReflectionPad2d(pads)
         elif ndim == 3:
-            self.pad = torch.nn.ReflectionPad1d(pads)
+            self.pad = torch.nn.ReflectionPad3d(pads)
         else:
             raise Exception("Unsupported pads")
@@ -244,3 +244,44 @@ def test_reflection_padnd(self, pads, dtype, ie_device, precision, ir_version):
         print(ndim)
         self._test(*self.create_model(pads), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})
+
+class TestReplicatePad(PytorchLayerTest):
+    def _prepare_input(self, ndim=4, dtype="float32"):
+        import numpy as np
+        input_5d_shape = [5,9,1,1,2,4]
+        return (np.random.randn(*input_5d_shape[:ndim]).astype(dtype),)
+
+    def create_model(self, pads):
+        import torch
+        import torch.nn.functional as F
+
+        class aten_pad(torch.nn.Module):
+            def __init__(self, pads):
+                super().__init__()
+                ndim = len(pads) / 2
+                if ndim == 1:
+                    self.pad = torch.nn.ReplicationPad1d(pads)
+                elif ndim == 2:
+                    self.pad = torch.nn.ReplicationPad2d(pads)
+                elif ndim == 3:
+                    self.pad = torch.nn.ReplicationPad3d(pads)
+                else:
+                    raise Exception("Unsupported pads")
+
+            def forward(self, x):
+                return self.pad(x)
+
+        return aten_pad(pads), None, "aten::pad"
+
+    @pytest.mark.parametrize("dtype", ["float32", "float64", "int32"])
+    @pytest.mark.parametrize("pads", [
+        (1, 2),
+        (1, 2, 3, 4),
+        (1, 2, 3, 4, 3, 2),
+    ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit_torch_export
+    def test_replicate_padnd(self, pads, dtype, ie_device, precision, ir_version):
+        ndim = len(pads) // 2 + 2
+        self._test(*self.create_model(pads), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})
\ No newline at end of file

From 4859af0158e01842a12dd28b41a171edd539f66a Mon Sep 17 00:00:00 2001
From: 11happy
Date: Tue, 14 Jan 2025 01:06:24 +0530
Subject: [PATCH 2/5] refactor: add suggested changes

Signed-off-by: 11happy
---
 src/frontends/pytorch/src/op_table.cpp      |   3 +
 tests/layer_tests/pytorch_tests/test_pad.py | 105 +++++++++++++++++---
 2 files changed, 94 insertions(+), 14 deletions(-)

diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 6be1780c405264..b498e7d75bef66 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -615,6 +615,9 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
     {"aten::remainder", op::translate_remainder},
     {"aten::repeat", op::translate_1to1_match_2_inputs<v0::Tile>},
     {"aten::repeat_interleave", op::translate_repeat_interleave},
+    {"aten::replicate_pad1d",op::translate_replicate_pad_nd_fx},
+    {"aten::replicate_pad2d",op::translate_replicate_pad_nd_fx},
+    {"aten::replicate_pad3d",op::translate_replicate_pad_nd_fx},
     {"aten::reshape", op::translate_reshape},
     {"aten::reshape_as", op::translate_reshape_as},
     // TO DO: enable behaviour for resolve_conj and resolve_neg complex tensors,
diff --git a/tests/layer_tests/pytorch_tests/test_pad.py b/tests/layer_tests/pytorch_tests/test_pad.py
index bf42658ba53eed..c194dad0fd26a6 100644
--- a/tests/layer_tests/pytorch_tests/test_pad.py
+++ b/tests/layer_tests/pytorch_tests/test_pad.py
@@ -245,7 +245,7 @@ def test_reflection_padnd(self, pads, dtype, ie_device, precision, ir_version):
         self._test(*self.create_model(pads), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})
 
-class TestReplicatePad(PytorchLayerTest):
+class TestReplicatePad1D(PytorchLayerTest):
     def _prepare_input(self, ndim=4, dtype="float32"):
         import numpy as np
         input_5d_shape = [5,9,1,1,2,4]
@@ -258,15 +258,7 @@ def create_model(self, pads):
         class aten_pad(torch.nn.Module):
             def __init__(self, pads):
                 super().__init__()
-                ndim = len(pads) / 2
-                if ndim == 1:
-                    self.pad = torch.nn.ReplicationPad1d(pads)
-                elif ndim == 2:
-                    self.pad = torch.nn.ReplicationPad2d(pads)
-                elif ndim == 3:
-                    self.pad = torch.nn.ReplicationPad3d(pads)
-                else:
-                    raise Exception("Unsupported pads")
+                self.pad = torch.nn.ReplicationPad1d(pads)
 
             def forward(self, x):
                 return self.pad(x)
@@ -275,13 +267,98 @@ def forward(self, x):
 
     @pytest.mark.parametrize("dtype", ["float32", "float64", "int32"])
     @pytest.mark.parametrize("pads", [
+        1,
+        2,
+        3,
         (1, 2),
-        (1, 2, 3, 4),
-        (1, 2, 3, 4, 3, 2),
+        (2, 1),
+        (2, 3),
+        (3, 4),
     ])
+
     @pytest.mark.nightly
+    @pytest.mark.precommit
     @pytest.mark.precommit_torch_export
     def test_replicate_padnd(self, pads, dtype, ie_device, precision, ir_version):
-        ndim = len(pads) // 2 + 2
+        ndim = 3
         self._test(*self.create_model(pads), ie_device, precision, ir_version,
-                   kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})
\ No newline at end of file
+                   kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})
+
+class TestReplicatePad2D(PytorchLayerTest):
+    def _prepare_input(self, ndim=4, dtype="float32"):
+        import numpy as np
+        input_5d_shape = [5,9,1,1,2,4]
+        return (np.random.randn(*input_5d_shape[:ndim]).astype(dtype),)
+
+    def create_model(self, pads):
+        import torch
+        import torch.nn.functional as F
+
+        class aten_pad(torch.nn.Module):
+            def __init__(self, pads):
+                super().__init__()
+                self.pad = torch.nn.ReplicationPad2d(pads)
+
+            def forward(self, x):
+                return self.pad(x)
+
+        return aten_pad(pads), None, "aten::pad"
+
+    @pytest.mark.parametrize("dtype", ["float32", "float64", "int32"])
+    @pytest.mark.parametrize("pads", [
+        1,
+        2,
+        3,
+        (1, 2, 2, 1),
+        (2, 1, 3, 4),
+        (2, 3, 1, 2),
+        (3, 4, 5, 6),
+    ])
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.precommit_torch_export
+    def test_replicate_padnd(self, pads, dtype, ie_device, precision, ir_version):
+        ndim = 4
+        self._test(*self.create_model(pads), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})
+
+class TestReplicatePad3D(PytorchLayerTest):
+    def _prepare_input(self, ndim=4, dtype="float32"):
+        import numpy as np
+        input_5d_shape = [5,9,1,1,2,4]
+        return (np.random.randn(*input_5d_shape[:ndim]).astype(dtype),)
+
+    def create_model(self, pads):
+        import torch
+        import torch.nn.functional as F
+
+        class aten_pad(torch.nn.Module):
+            def __init__(self, pads):
+                super().__init__()
+                self.pad = torch.nn.ReplicationPad3d(pads)
+
+            def forward(self, x):
+                return self.pad(x)
+
+        return aten_pad(pads), None, "aten::pad"
+
+    @pytest.mark.parametrize("dtype", ["float32", "float64", "int32"])
+    @pytest.mark.parametrize("pads", [
+        1,
+        2,
+        3,
+        (1, 2, 2, 1, 3, 4),
+        (2, 1, 3, 4, 2, 1),
+        (2, 3, 1, 2, 2, 1),
+        (3, 4, 5, 6, 1, 2),
+    ])
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.precommit_torch_export
+    def test_replicate_padnd(self, pads, dtype, ie_device, precision, ir_version):
+        ndim = 5
+        self._test(*self.create_model(pads), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})
+

From 388964f4a8d12555d10a8c311be0a8d059de0488 Mon Sep 17 00:00:00 2001
From: 11happy
Date: Sat, 25 Jan 2025 21:03:57 +0530
Subject: [PATCH 3/5] clean code style: run clang format

Signed-off-by: 11happy
---
 src/frontends/pytorch/src/op_table.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index bb7afdcff13e57..e0239a62ae5992 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -631,9 +631,9 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
     {"aten::remainder", op::translate_remainder},
     {"aten::repeat", op::translate_1to1_match_2_inputs<v0::Tile>},
     {"aten::repeat_interleave", op::translate_repeat_interleave},
-    {"aten::replicate_pad1d",op::translate_replicate_pad_nd_fx},
-    {"aten::replicate_pad2d",op::translate_replicate_pad_nd_fx},
-    {"aten::replicate_pad3d",op::translate_replicate_pad_nd_fx},
+    {"aten::replicate_pad1d", op::translate_replicate_pad_nd_fx},
+    {"aten::replicate_pad2d", op::translate_replicate_pad_nd_fx},
+    {"aten::replicate_pad3d", op::translate_replicate_pad_nd_fx},
     {"aten::reshape", op::translate_reshape},
     {"aten::reshape_as", op::translate_reshape_as},
     // TO DO: enable behaviour for resolve_conj and resolve_neg complex tensors,

From a759cc787dbc03fdd584cccc72d45b10df87cb86 Mon Sep 17 00:00:00 2001
From: 11happy
Date: Sun, 2 Feb 2025 14:00:44 +0530
Subject: [PATCH 4/5] fix: fixed failing build CI

Signed-off-by: 11happy
---
 src/frontends/pytorch/src/op/pad.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/frontends/pytorch/src/op/pad.cpp b/src/frontends/pytorch/src/op/pad.cpp
index d5f2d20e573b01..6cdfd17e6725f5 100644
--- a/src/frontends/pytorch/src/op/pad.cpp
+++ b/src/frontends/pytorch/src/op/pad.cpp
@@ -134,7 +134,7 @@ OutputVector translate_reflection_pad_nd_fx(const NodeContext& context) {
     return translate_pad_common(context, data, paddings, pad_value, "reflect");
 }
 
-OutputVector translate_replicate_pad_nd_fx{const NodeContext & context} {
+OutputVector translate_replicate_pad_nd_fx(const NodeContext & context) {
     num_inputs_check(context, 2, 2);
     auto data = context.get_input(0);
     auto paddings = context.const_input<std::vector<int64_t>>(1);

From ac6579aa046e4da573c790a329f1cd0ff0b41ce7 Mon Sep 17 00:00:00 2001
From: 11happy
Date: Mon, 3 Feb 2025 06:20:17 +0530
Subject: [PATCH 5/5] fix: git artifact

Signed-off-by: 11happy
---
 src/frontends/pytorch/src/op_table.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index e0239a62ae5992..b69b4e2b1d61e9 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -305,12 +305,9 @@ OP_CONVERTER(translate_new_zeros_fx);
 OP_CONVERTER(translate_ones_fx);
 OP_CONVERTER(translate_ones_like_fx);
 OP_CONVERTER(translate_reflection_pad_nd_fx);
-<<<<<<< HEAD
 OP_CONVERTER(translate_replicate_pad_nd_fx);
 OP_CONVERTER(translate_reshape_fx);
-=======
 OP_CONVERTER(translate_repeat_fx);
->>>>>>> origin
 OP_CONVERTER(translate_rsub_fx);
 OP_CONVERTER(translate_scalar_tensor_fx);
 OP_CONVERTER(translate_scaled_dot_product_attention_fx);
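For reviewers, a minimal end-to-end sketch of the behavior this series adds (illustrative only, mirroring the layer tests above; it assumes an OpenVINO build that includes these patches, and depending on the trace/export path PyTorch may lower the module to aten::pad rather than a dedicated replication-pad op):

    import torch
    import openvino as ov

    # Replication ("edge") padding repeats border values outward.
    # ReplicationPad2d takes pads as (left, right, top, bottom).
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pad = torch.nn.ReplicationPad2d((1, 2, 3, 4))

        def forward(self, x):
            return self.pad(x)

    x = torch.randn(5, 9, 1, 1)  # same shape the tests feed for ndim=4
    ov_model = ov.convert_model(Model(), example_input=x)
    result = ov.compile_model(ov_model, "CPU")([x.numpy()])[0]
    # Height grows by top+bottom, width by left+right.
    assert result.shape == (5, 9, 1 + 3 + 4, 1 + 1 + 2)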