[PyTorch]: implement support for replicated{1,2,3} pad #28271

Open
wants to merge 3 commits into base: master

Changes from 1 commit
8 changes: 8 additions & 0 deletions src/frontends/pytorch/src/op/pad.cpp
@@ -134,6 +134,14 @@ OutputVector translate_reflection_pad_nd_fx(const NodeContext& context) {
    return translate_pad_common(context, data, paddings, pad_value, "reflect");
}

OutputVector translate_replicate_pad_nd_fx(const NodeContext& context) {
    num_inputs_check(context, 2, 2);
    auto data = context.get_input(0);
    auto paddings = context.const_input<std::vector<int64_t>>(1);
    Output<Node> pad_value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
    return translate_pad_common(context, data, paddings, pad_value, "replicate");
Member

I am not sure that the logic from translate_pad_common is suitable for the replication operation.
That is because, in the tuple case, the format of paddings is <pad_begin_axis1>, <pad_end_axis1>, <pad_begin_axis2>, <pad_end_axis2>, ...

Maybe we need to do the following (see the sketch after this list):

  1. broadcast the padding to convert it to the tuple case
  2. create a pad vector of zeros with length equal to the rank of the input
  3. scatter elements from the broadcast padding into the pad vector at the known indices defined for the 1d, 2d and 3d cases
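
A minimal Python sketch of that idea (illustrative only, not code from this PR), assuming PyTorch's convention that the pad tuple lists the last dimensions first:

    # Hypothetical helper: expand a PyTorch-style pad tuple into per-axis
    # pads_begin / pads_end vectors of length `rank`, scattering the given
    # values into zero-filled vectors (steps 2 and 3 above).
    def expand_pads(pads, rank):
        pads_begin = [0] * rank
        pads_end = [0] * rank
        # pads = (begin_last, end_last, begin_second_last, end_second_last, ...)
        for i in range(len(pads) // 2):
            axis = rank - 1 - i  # known scatter index for the i-th padded axis
            pads_begin[axis] = pads[2 * i]
            pads_end[axis] = pads[2 * i + 1]
        return pads_begin, pads_end

    # Example: a 2d pad (1, 2, 3, 4) applied to a rank-4 NCHW tensor
    # -> pads_begin = [0, 0, 3, 1], pads_end = [0, 0, 4, 2]
    print(expand_pads((1, 2, 3, 4), 4))

Inside the translator the same scatter could be built from graph operations, as the steps above suggest.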

Contributor Author

I am currently not able to see why this logic may be incorrect. Could you please give an example to explain? That would be very helpful. Thank you.

Member

Hi @11happy,
I meant that the logic for preparing the pad vector can be simplified, but I understand it may be out of scope for this PR because it is used by other translators.
Let us see if the tests pass. Please resolve the conflicts and I will trigger CI for your PR.

Thanks.

Contributor Author

resolved the conflicts.
Thank you

}

} // namespace op
} // namespace pytorch
} // namespace frontend
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -297,6 +297,7 @@ OP_CONVERTER(translate_new_zeros_fx);
OP_CONVERTER(translate_ones_fx);
OP_CONVERTER(translate_ones_like_fx);
OP_CONVERTER(translate_reflection_pad_nd_fx);
OP_CONVERTER(translate_replicate_pad_nd_fx);
OP_CONVERTER(translate_reshape_fx);
OP_CONVERTER(translate_rsub_fx);
OP_CONVERTER(translate_scalar_tensor_fx);
@@ -930,6 +931,9 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.reflection_pad1d.default", op::translate_reflection_pad_nd_fx},
{"aten.reflection_pad2d.default", op::translate_reflection_pad_nd_fx},
{"aten.reflection_pad3d.default", op::translate_reflection_pad_nd_fx},
{"aten.replicate_pad1d.default", op::translate_replicate_pad_nd_fx},
{"aten.replicate_pad2d.default", op::translate_replicate_pad_nd_fx},
{"aten.replicate_pad3d.default", op::translate_replicate_pad_nd_fx},
{"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
{"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
{"aten.repeat.default", op::translate_1to1_match_2_inputs<opset10::Tile>},
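
As a side note, a quick way to check which aten op a ReplicationPad module lowers to under torch.export (an illustrative sketch, not part of this PR; the printed targets are the keys the get_supported_ops_fx() table has to match, typically aten.replication_pad2d.default on recent PyTorch versions):

    # Sketch: export a ReplicationPad2d module and print the aten ops in its FX graph.
    import torch

    model = torch.nn.ReplicationPad2d((1, 2, 3, 4))
    example = torch.randn(1, 3, 8, 8)
    exported = torch.export.export(model, (example,))
    for node in exported.graph.nodes:
        if node.op == "call_function":
            print(node.target)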
45 changes: 43 additions & 2 deletions tests/layer_tests/pytorch_tests/test_pad.py
@@ -219,9 +219,9 @@ def __init__(self, pads):
                if ndim == 1:
                    self.pad = torch.nn.ReflectionPad1d(pads)
                elif ndim == 2:
-                   self.pad = torch.nn.ReflectionPad1d(pads)
+                   self.pad = torch.nn.ReflectionPad2d(pads)
                elif ndim == 3:
-                   self.pad = torch.nn.ReflectionPad1d(pads)
+                   self.pad = torch.nn.ReflectionPad3d(pads)
                else:
                    raise Exception("Unsupported pads")

@@ -244,3 +244,44 @@ def test_reflection_padnd(self, pads, dtype, ie_device, precision, ir_version):
        print(ndim)
        self._test(*self.create_model(pads), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})

class TestReplicatePad(PytorchLayerTest):
    def _prepare_input(self, ndim=4, dtype="float32"):
        import numpy as np
        input_5d_shape = [5, 9, 1, 1, 2, 4]
        return (np.random.randn(*input_5d_shape[:ndim]).astype(dtype),)

    def create_model(self, pads):
        import torch
        import torch.nn.functional as F

        class aten_pad(torch.nn.Module):
            def __init__(self, pads):
                super().__init__()
                ndim = len(pads) // 2
                if ndim == 1:
                    self.pad = torch.nn.ReplicationPad1d(pads)
                elif ndim == 2:
                    self.pad = torch.nn.ReplicationPad2d(pads)
                elif ndim == 3:
                    self.pad = torch.nn.ReplicationPad3d(pads)
                else:
                    raise Exception("Unsupported pads")

            def forward(self, x):
                return self.pad(x)

        return aten_pad(pads), None, "aten::pad"

    @pytest.mark.parametrize("dtype", ["float32", "float64", "int32"])
    @pytest.mark.parametrize("pads", [
        (1, 2),
        (1, 2, 3, 4),
        (1, 2, 3, 4, 3, 2),
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit_torch_export
    def test_replicate_padnd(self, pads, dtype, ie_device, precision, ir_version):
        ndim = len(pads) // 2 + 2
        self._test(*self.create_model(pads), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"ndim": ndim, "dtype": dtype})
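
For context on what these tests exercise, replication padding repeats the border values of the input; a minimal sketch (not part of the PR):

    # Sketch: replication padding repeats edge values along the padded dimension.
    import torch
    import torch.nn.functional as F

    x = torch.tensor([[[1.0, 2.0, 3.0]]])   # shape (1, 1, 3), as ReplicationPad1d expects
    y = F.pad(x, (2, 1), mode="replicate")  # pad the last dim: 2 on the left, 1 on the right
    print(y)  # tensor([[[1., 1., 1., 2., 3., 3.]]])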