From c08e01d6d71490b5a2010caad9c78fb7b57c2044 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Fri, 10 Nov 2023 18:00:04 +0100
Subject: [PATCH] [PT FE] Optimize reverseprop in pytorch frontend (#20989)

* [PT FE] Optimize reverseprop in pytorch frontend

* Add transformation

* Improve readability

---------

Co-authored-by: Alina Kladieva
---
 src/frontends/pytorch/src/frontend.cpp        |   2 +
 .../pytorch/src/helper_ops/gather_assign.hpp  |  40 ++++++
 .../pytorch/src/helper_ops/internal_op.hpp    |   7 ++
 .../pytorch/src/helper_ops/slice_assign.hpp   |  64 ++++++++++
 src/frontends/pytorch/src/node_context.cpp    |  12 +-
 .../src/transforms/reverseprop_resolver.cpp   | 119 ++++++++++++++++++
 .../src/transforms/reverseprop_resolver.hpp   |  27 ++++
 .../pytorch/src/translate_session.cpp         | 101 +++++----------
 .../pytorch/src/translate_session.hpp         |   8 +-
 9 files changed, 297 insertions(+), 83 deletions(-)
 create mode 100644 src/frontends/pytorch/src/helper_ops/gather_assign.hpp
 create mode 100644 src/frontends/pytorch/src/helper_ops/slice_assign.hpp
 create mode 100644 src/frontends/pytorch/src/transforms/reverseprop_resolver.cpp
 create mode 100644 src/frontends/pytorch/src/transforms/reverseprop_resolver.hpp

diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index 36d4027dcc426f..1f021dfba441f5 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -38,6 +38,7 @@
 #include "transforms/prim_list_unpack_replacer.hpp"
 #include "transforms/prim_tuple_unpack_parameter_replacer.hpp"
 #include "transforms/quantized_node_remover.hpp"
+#include "transforms/reverseprop_resolver.hpp"
 #include "transforms/rfftn_complex_replacer.hpp"
 #include "transforms/softmax_reshape_elimination.hpp"
 #include "transforms/string_equality_replacer.hpp"
@@ -204,6 +205,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
+    manager.register_pass<ov::frontend::pytorch::pass::ReversepropResolver>();
     manager.register_pass();
     manager.register_pass();
     // Second pass of AlignTypesRemoval after all converting transformations
diff --git a/src/frontends/pytorch/src/helper_ops/gather_assign.hpp b/src/frontends/pytorch/src/helper_ops/gather_assign.hpp
new file mode 100644
index 00000000000000..eadc9dfda7ecdf
--- /dev/null
+++ b/src/frontends/pytorch/src/helper_ops/gather_assign.hpp
@@ -0,0 +1,40 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "helper_ops/internal_op.hpp"
+#include "openvino/frontend/decoder.hpp"
+#include "openvino/op/op.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+
+class GatherAssign : public InternalReverseOperation {
+public:
+    OPENVINO_OP("GatherAssign", "internal", InternalReverseOperation);
+
+    GatherAssign(const Output<Node>& data,
+                 const Output<Node>& updates,
+                 const Output<Node>& indices,
+                 const Output<Node>& axis)
+        : InternalReverseOperation({data, updates, indices, axis}) {
+        validate_and_infer_types();
+    }
+
+    void validate_and_infer_types() override {
+        auto data = input_value(0);
+        set_output_type(0, data.get_element_type(), data.get_partial_shape());
+    }
+
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override {
+        check_new_args_count(this, new_args);
+        return std::make_shared<GatherAssign>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
+    }
+};
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
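GatherAssign above is pure bookkeeping: it records (data, updates, indices, axis), forwards the data type and shape, and is lowered to real operations later by the ReversepropResolver pass added in this patch. A minimal sketch of how such a node could be constructed for an in-place update like x[1] = y (hypothetical helper function; assumes it is compiled inside the frontend target, where helper_ops/gather_assign.hpp is on the include path):

#include <memory>

#include "helper_ops/gather_assign.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"

using namespace ov;

// Builds a GatherAssign for `x[1] = y`, with x of shape {4, 3} and y of shape {3}.
std::shared_ptr<frontend::pytorch::GatherAssign> make_example_gather_assign() {
    auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{4, 3});
    auto updates = std::make_shared<op::v0::Parameter>(element::f32, Shape{3});
    auto index = op::v0::Constant::create(element::i64, Shape{}, {1});
    auto axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    // The node only records its operands; the actual scatter is materialized
    // later by ReversepropResolver.
    return std::make_shared<frontend::pytorch::GatherAssign>(data, updates, index, axis);
}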
diff --git a/src/frontends/pytorch/src/helper_ops/internal_op.hpp b/src/frontends/pytorch/src/helper_ops/internal_op.hpp
index 510654dce8620a..8d9cc9c727083c 100644
--- a/src/frontends/pytorch/src/helper_ops/internal_op.hpp
+++ b/src/frontends/pytorch/src/helper_ops/internal_op.hpp
@@ -8,6 +8,7 @@
 #include
 
 #include "openvino/frontend/decoder.hpp"
+#include "openvino/op/op.hpp"
 #include "pt_framework_node.hpp"
 #include "utils.hpp"
 
@@ -51,6 +52,12 @@ class InternalOperation : public PtFrameworkNode {
         set_attrs(attrs);
     }
 };
+
+class InternalReverseOperation : public ov::op::Op {
+public:
+    OPENVINO_OP("InternalReverseOperation", "internal");
+    InternalReverseOperation(const OutputVector& inputs) : ov::op::Op(inputs) {}
+};
 }  // namespace pytorch
 }  // namespace frontend
 }  // namespace ov
diff --git a/src/frontends/pytorch/src/helper_ops/slice_assign.hpp b/src/frontends/pytorch/src/helper_ops/slice_assign.hpp
new file mode 100644
index 00000000000000..238cf7aefee7aa
--- /dev/null
+++ b/src/frontends/pytorch/src/helper_ops/slice_assign.hpp
@@ -0,0 +1,64 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "helper_ops/internal_op.hpp"
+#include "openvino/frontend/decoder.hpp"
+#include "openvino/op/op.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+
+class SliceAssign : public InternalReverseOperation {
+public:
+    OPENVINO_OP("SliceAssign", "internal", InternalReverseOperation);
+
+    SliceAssign(const Output<Node>& data,
+                const Output<Node>& updates,
+                const Output<Node>& start,
+                const Output<Node>& stop,
+                const Output<Node>& step)
+        : InternalReverseOperation({data, updates, start, stop, step}) {
+        validate_and_infer_types();
+    }
+
+    SliceAssign(const Output<Node>& data,
+                const Output<Node>& updates,
+                const Output<Node>& start,
+                const Output<Node>& stop,
+                const Output<Node>& step,
+                const Output<Node>& axes)
+        : InternalReverseOperation({data, updates, start, stop, step, axes}) {
+        validate_and_infer_types();
+    }
+
+    void validate_and_infer_types() override {
+        auto data = input_value(0);
+        set_output_type(0, data.get_element_type(), data.get_partial_shape());
+    }
+
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override {
+        check_new_args_count(this, new_args);
+        if (new_args.size() == 5) {
+            return std::make_shared<SliceAssign>(new_args.at(0),
+                                                 new_args.at(1),
+                                                 new_args.at(2),
+                                                 new_args.at(3),
+                                                 new_args.at(4));
+        } else {
+            return std::make_shared<SliceAssign>(new_args.at(0),
+                                                 new_args.at(1),
+                                                 new_args.at(2),
+                                                 new_args.at(3),
+                                                 new_args.at(4),
+                                                 new_args.at(5));
+        }
+    }
+};
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
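SliceAssign mirrors GatherAssign for aten::slice; its two constructors correspond to Slice with and without an explicit axes input. A sketch of building one for x[1:3] = y (hypothetical helper, same include-path assumption as the GatherAssign example above):

#include <memory>

#include "helper_ops/slice_assign.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"

using namespace ov;

// Builds a SliceAssign for `x[1:3] = y`, with x of shape {4, 3} and y of shape {2, 3}.
std::shared_ptr<frontend::pytorch::SliceAssign> make_example_slice_assign() {
    auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{4, 3});
    auto updates = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3});
    auto start = op::v0::Constant::create(element::i64, Shape{1}, {1});
    auto stop = op::v0::Constant::create(element::i64, Shape{1}, {3});
    auto step = op::v0::Constant::create(element::i64, Shape{1}, {1});
    auto axes = op::v0::Constant::create(element::i64, Shape{1}, {0});
    // Uses the six-input constructor; the five-input form omits `axes`.
    return std::make_shared<frontend::pytorch::SliceAssign>(data, updates, start, stop, step, axes);
}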
diff --git a/src/frontends/pytorch/src/node_context.cpp b/src/frontends/pytorch/src/node_context.cpp
index 5d8a138a52f1ef..4c6bf2e9e5080f 100644
--- a/src/frontends/pytorch/src/node_context.cpp
+++ b/src/frontends/pytorch/src/node_context.cpp
@@ -65,25 +65,25 @@ void NodeContext::mutate_input(size_t index, Output<Node> ov_output) const {
     auto back_input_id = input_id;
     auto back_node_input = ov_output;
     while (m_translate_session->m_may_be_alias.count(back_input_id)) {
-        // Create node to backprop data. While loop is needed for the cases when alias to tensor point to another alias
-        // to tensor. In that case we need to create a chain of backprop ops
+        // Create node to reverseprop data. The while loop is needed for cases when an alias to a tensor points to
+        // another alias to a tensor. In that case we need to create a chain of reverseprop ops
         size_t in_tensor;
         std::shared_ptr<TorchDecoder> node;
         Output<Node> node_converted_output;
         std::tie(in_tensor, node, node_converted_output) = m_translate_session->m_may_be_alias.at(back_input_id);
-        auto backprop_node = m_translate_session->get_backprop_op(node, node_converted_output, back_node_input);
+        auto reverseprop_node = m_translate_session->get_reverseprop_op(node, node_converted_output, back_node_input);
         if (!m_tensor_map->count(in_tensor)) {
             // Tensor is not found in the scope of this body, need to get it from internal context and mark mutated
             OPENVINO_DEBUG << "Couldn't find in the current body the initial aliased tensor: " << in_tensor
                            << " for operation: " << node->get_op_type() << " creating new body input.";
             get_tensor_from_model_or_create_input(in_tensor);
         }
-        m_translate_session->encode_tensor_name(backprop_node, in_tensor);
-        (*m_tensor_map)[in_tensor] = backprop_node;
+        m_translate_session->encode_tensor_name(reverseprop_node, in_tensor);
+        (*m_tensor_map)[in_tensor] = reverseprop_node;
         m_mutated_tensors->insert(in_tensor);
         OPENVINO_DEBUG << "Propagated back data from tensor: " << back_input_id << " to tensor: " << in_tensor
                        << ".\n";
         back_input_id = in_tensor;
-        back_node_input = backprop_node;
+        back_node_input = reverseprop_node;
     }
 }
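The while loop above exists because a mutated tensor may itself be a view of another view (x[0][1] aliases x[0], which aliases x), so one reverse op is emitted per hop until the root tensor is reached. A standalone sketch of just that control flow, with simplified stand-in types and made-up tensor ids:

#include <cstddef>
#include <iostream>
#include <map>

int main() {
    // tensor id -> id of the tensor it aliases, e.g. x[0][1] -> x[0] -> x
    std::map<std::size_t, std::size_t> may_be_alias = {{7, 3}, {3, 1}};
    std::size_t back_input_id = 7;  // the tensor that was written in place
    while (may_be_alias.count(back_input_id)) {
        std::size_t in_tensor = may_be_alias.at(back_input_id);
        // here the frontend would create one SliceAssign/GatherAssign per hop
        std::cout << "reverseprop hop: " << back_input_id << " -> " << in_tensor << "\n";
        back_input_id = in_tensor;  // continue up the alias chain
    }
}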
diff --git a/src/frontends/pytorch/src/transforms/reverseprop_resolver.cpp b/src/frontends/pytorch/src/transforms/reverseprop_resolver.cpp
new file mode 100644
index 00000000000000..4bdc28b07f2fc7
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/reverseprop_resolver.cpp
@@ -0,0 +1,119 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "reverseprop_resolver.hpp"
+
+#include
+#include
+
+#include "helper_ops/gather_assign.hpp"
+#include "helper_ops/internal_op.hpp"
+#include "helper_ops/slice_assign.hpp"
+#include "openvino/core/rt_info.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/reduce_prod.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/scatter_nd_update.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+using namespace ov::pass;
+using namespace ov::op;
+
+ReversepropResolver::ReversepropResolver() {
+    auto reverse_op = pattern::wrap_type<InternalReverseOperation>();
+
+    ov::matcher_pass_callback callback = [](pattern::Matcher& m) {
+        auto base_op = m.get_match_root();
+        // Apply this transformation only to the starting reverse operation
+        if (ov::as_type_ptr<InternalReverseOperation>(base_op->get_input_node_shared_ptr(1)))
+            return false;
+
+        auto curr_op = base_op;
+        std::vector<std::shared_ptr<Node>> rev_ops;
+        while (ov::as_type_ptr<InternalReverseOperation>(curr_op)) {
+            rev_ops.push_back(curr_op);
+            auto target_inputs = curr_op->get_output_target_inputs(0);
+            if (target_inputs.size() != 1)
+                break;
+            curr_op = target_inputs.begin()->get_node()->shared_from_this();
+        }
+        if (rev_ops.size() < 1)
+            return false;
+
+        ov::pass::NodeRegistry rg;
+        auto zero = v0::Constant::create(element::i64, Shape{}, {0});
+        auto one = v0::Constant::create(element::i64, Shape{}, {1});
+        auto neg_one_1d = v0::Constant::create(element::i64, Shape{1}, {-1});
+        auto scattering_shape = v0::Constant::create(element::i64, Shape{2}, {-1, 1});
+
+        // Get 1d indices [0..numel) for the whole input tensor
+        auto start_op = rev_ops.back();
+        auto data_to_insert_into = start_op->input_value(0);
+        auto input_shape = rg.make<v3::ShapeOf>(data_to_insert_into, element::i64);
+        auto numel = rg.make<v1::ReduceProd>(input_shape, zero, false);
+        auto full_data_indices_1d = rg.make<v4::Range>(zero, numel, one, element::i64);
+        auto full_data_indices = rg.make<v1::Reshape>(full_data_indices_1d, input_shape, false);
+
+        // Cut indices in accordance with the operations
+        Output<Node> data_indices = full_data_indices;
+        for (auto it = rev_ops.rbegin(); it != rev_ops.rend(); ++it) {
+            curr_op = *it;
+            if (ov::as_type_ptr<SliceAssign>(curr_op)) {
+                if (curr_op->get_input_size() == 6) {
+                    data_indices = rg.make<v8::Slice>(data_indices,
+                                                      curr_op->input_value(2),
+                                                      curr_op->input_value(3),
+                                                      curr_op->input_value(4),
+                                                      curr_op->input_value(5));
+                } else if (curr_op->get_input_size() == 5) {
+                    data_indices = rg.make<v8::Slice>(data_indices,
+                                                      curr_op->input_value(2),
+                                                      curr_op->input_value(3),
+                                                      curr_op->input_value(4));
+                } else {
+                    return false;
+                }
+            } else if (ov::as_type_ptr<GatherAssign>(curr_op)) {
+                data_indices = rg.make<v8::Gather>(data_indices, curr_op->input_value(2), curr_op->input_value(3));
+            } else {
+                return false;
+            }
+        }
+
+        // Scatter in flattened tensor with indices and flattened data to be inserted
+        auto data_to_insert_into_1d = rg.make<v1::Reshape>(data_to_insert_into, neg_one_1d, false);
+        auto data_indices_1d = rg.make<v1::Reshape>(data_indices, scattering_shape, false);
+        auto to_be_inserted_data_1d = rg.make<v1::Reshape>(base_op->input_value(1), neg_one_1d, false);
+        auto updated_data_1d =
+            rg.make<v3::ScatterNDUpdate>(data_to_insert_into_1d, data_indices_1d, to_be_inserted_data_1d);
+
+        // Reshape back to the initial shape
+        auto res_node = rg.make<v1::Reshape>(updated_data_1d, input_shape, false);
+        copy_runtime_info_and_name(base_op, rg.get());
+        start_op->output(0).replace(res_node);
+
+        return true;
+    };
+
+    auto m =
+        std::make_shared<pattern::Matcher>(reverse_op, "ov::frontend::pytorch::pass::ReversepropResolver");
+    this->register_matcher(m, callback);
+}
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
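The core idea of the pass — track element indices through the slice/select chain instead of values, then scatter once — can be checked on plain vectors. A standalone sketch (not frontend code) for a single data[4:6] = updates assignment:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    std::vector<float> data = {0, 1, 2, 3, 4, 5};  // flattened input tensor
    std::vector<float> updates = {40, 50};         // values for data[4:6]
    // Range(0, numel): one index per element of the flattened tensor
    std::vector<std::size_t> indices(data.size());
    for (std::size_t i = 0; i < indices.size(); ++i)
        indices[i] = i;
    // Apply the same slice (start=4, stop=6, step=1) to the indices, not the data
    std::vector<std::size_t> kept(indices.begin() + 4, indices.begin() + 6);
    // ScatterNDUpdate equivalent: write the updates through the surviving indices
    for (std::size_t i = 0; i < kept.size(); ++i)
        data[kept[i]] = updates[i];
    for (float v : data)
        std::cout << v << ' ';  // prints: 0 1 2 3 40 50
}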
diff --git a/src/frontends/pytorch/src/transforms/reverseprop_resolver.hpp b/src/frontends/pytorch/src/transforms/reverseprop_resolver.hpp
new file mode 100644
index 00000000000000..d07162889c7b9e
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/reverseprop_resolver.hpp
@@ -0,0 +1,27 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+/**
+ * Replace a sequence of reverseprop operations with a single ScatterNDUpdate-based subgraph.
+ */
+class ReversepropResolver : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::ReversepropResolver");
+    ReversepropResolver();
+};
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
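The pass is driven by the pass manager exactly as the frontend.cpp hunk at the top of this patch registers it. A sketch of running it in isolation (assumes the frontend-internal header is on the include path):

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transforms/reverseprop_resolver.hpp"

void resolve_reverseprop(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    // Lowers SliceAssign/GatherAssign chains into ScatterNDUpdate subgraphs
    manager.register_pass<ov::frontend::pytorch::pass::ReversepropResolver>();
    manager.run_passes(model);
}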
diff --git a/src/frontends/pytorch/src/translate_session.cpp b/src/frontends/pytorch/src/translate_session.cpp
index f08a7d08c7a36a..16090fe0e42931 100644
--- a/src/frontends/pytorch/src/translate_session.cpp
+++ b/src/frontends/pytorch/src/translate_session.cpp
@@ -4,18 +4,11 @@
 
 #include "translate_session.hpp"
 
+#include "helper_ops/gather_assign.hpp"
+#include "helper_ops/slice_assign.hpp"
 #include "input_model.hpp"
-#include "openvino/op/constant.hpp"
 #include "openvino/op/gather.hpp"
-#include "openvino/op/parameter.hpp"
-#include "openvino/op/range.hpp"
-#include "openvino/op/reduce_prod.hpp"
-#include "openvino/op/reshape.hpp"
-#include "openvino/op/result.hpp"
-#include "openvino/op/scatter_nd_update.hpp"
-#include "openvino/op/shape_of.hpp"
 #include "openvino/op/slice.hpp"
-#include "openvino/op/transpose.hpp"
 #include "openvino/util/log.hpp"
 #include "place.hpp"
 #include "pt_framework_node.hpp"
@@ -344,92 +337,54 @@ size_t TranslateSession::decode_tensor_name(const Output<Node>& output) {
 }
 
 namespace {
-Output<Node> slice_backprop(const Output<Node>& slice_output, const Output<Node>& value) {
+Output<Node> slice_reverseprop(const Output<Node>& slice_output, const Output<Node>& value) {
     auto slice_node = slice_output.get_node_shared_ptr();
     FRONT_END_OP_CONVERSION_CHECK(ov::as_type_ptr<v8::Slice>(slice_node),
                                   "Conversion rule for aten::slice doesn't contain Slice node.");
 
-    auto zero = v0::Constant::create(element::i64, Shape{}, {0});
-    auto one = v0::Constant::create(element::i64, Shape{}, {1});
-    auto neg_one_1d = v0::Constant::create(element::i64, Shape{1}, {-1});
-    auto scattering_shape = v0::Constant::create(element::i64, Shape{2}, {-1, 1});
-
-    // Get 1d indices [0..numel)
     auto to_insert_data = slice_node->input_value(0);
-    auto input_shape = std::make_shared<v3::ShapeOf>(to_insert_data, element::i64);
-    auto numel = std::make_shared<v1::ReduceProd>(input_shape, zero, false);
-    auto full_data_indices_1d = std::make_shared<v4::Range>(zero, numel, one, element::i64);
-
-    // Slice indices by same start, stop, slice, axes as initial Slice
-    auto full_data_indices = std::make_shared<v1::Reshape>(full_data_indices_1d, input_shape, false);
-    Output<Node> data_indices;
+    Output<Node> res;
     if (slice_node->get_input_size() == 5) {
-        data_indices = std::make_shared<v8::Slice>(full_data_indices,
-                                                   slice_node->input_value(1),
-                                                   slice_node->input_value(2),
-                                                   slice_node->input_value(3),
-                                                   slice_node->input_value(4));
+        res = std::make_shared<SliceAssign>(to_insert_data,
+                                            value,
+                                            slice_node->input_value(1),
+                                            slice_node->input_value(2),
+                                            slice_node->input_value(3),
+                                            slice_node->input_value(4));
     } else if (slice_node->get_input_size() == 4) {
-        data_indices = std::make_shared<v8::Slice>(full_data_indices,
-                                                   slice_node->input_value(1),
-                                                   slice_node->input_value(2),
-                                                   slice_node->input_value(3));
+        res = std::make_shared<SliceAssign>(to_insert_data,
+                                            value,
+                                            slice_node->input_value(1),
+                                            slice_node->input_value(2),
+                                            slice_node->input_value(3));
     } else {
         FRONT_END_OP_CONVERSION_CHECK(false, "Incorrect number of Slice inputs");
     }
 
-    // Scatter in flattened tensor with indices and flattened data to be inserted
-    auto to_insert_data_1d = std::make_shared<v1::Reshape>(to_insert_data, neg_one_1d, false);
-    auto data_indices_1d = std::make_shared<v1::Reshape>(data_indices, scattering_shape, false);
-    auto to_be_inserted_data_1d = std::make_shared<v1::Reshape>(value, neg_one_1d, false);
-    auto updated_data_1d =
-        std::make_shared<v3::ScatterNDUpdate>(to_insert_data_1d, data_indices_1d, to_be_inserted_data_1d);
-
-    // Reshape to initial shape
-    return std::make_shared<v1::Reshape>(updated_data_1d, input_shape, false);
+    return res;
 }
 
-Output<Node> select_backprop(const Output<Node>& select_output, const Output<Node>& value) {
+Output<Node> select_reverseprop(const Output<Node>& select_output, const Output<Node>& value) {
    auto gather_node = select_output.get_node_shared_ptr();
    FRONT_END_OP_CONVERSION_CHECK(ov::as_type_ptr<v8::Gather>(gather_node),
                                  "Conversion rule for aten::select doesn't contain Gather node.");
 
-    auto zero = v0::Constant::create(element::i64, Shape{}, {0});
-    auto one = v0::Constant::create(element::i64, Shape{}, {1});
-    auto neg_one_1d = v0::Constant::create(element::i64, Shape{1}, {-1});
-    auto scattering_shape = v0::Constant::create(element::i64, Shape{2}, {-1, 1});
-
-    // Get 1d indices [0..numel)
     auto to_insert_data = gather_node->input_value(0);
-    auto input_shape = std::make_shared<v3::ShapeOf>(to_insert_data, element::i64);
-    auto numel = std::make_shared<v1::ReduceProd>(input_shape, zero, false);
-    auto full_data_indices_1d = std::make_shared<v4::Range>(zero, numel, one, element::i64);
-
-    // Slice indices by same start, stop, slice, axes as initial Slice
-    auto full_data_indices = std::make_shared<v1::Reshape>(full_data_indices_1d, input_shape, false);
-    Output<Node> data_indices =
-        std::make_shared<v8::Gather>(full_data_indices, gather_node->input_value(1), gather_node->input_value(2));
-
-    // Scatter in flattened tensor with indices and flattened data to be inserted
-    auto to_insert_data_1d = std::make_shared<v1::Reshape>(to_insert_data, neg_one_1d, false);
-    auto data_indices_1d = std::make_shared<v1::Reshape>(data_indices, scattering_shape, false);
-    auto to_be_inserted_data_1d = std::make_shared<v1::Reshape>(value, neg_one_1d, false);
-    auto updated_data_1d =
-        std::make_shared<v3::ScatterNDUpdate>(to_insert_data_1d, data_indices_1d, to_be_inserted_data_1d);
-
-    // Reshape to initial shape
-    return std::make_shared<v1::Reshape>(updated_data_1d, input_shape, false);
+    return std::make_shared<GatherAssign>(to_insert_data,
+                                          value,
+                                          gather_node->input_value(1),
+                                          gather_node->input_value(2));
 }
 }  // namespace
 
-using BackpropCreatorFunction = std::function<Output<Node>(const Output<Node>&, const Output<Node>&)>;
+using ReversepropCreatorFunction = std::function<Output<Node>(const Output<Node>&, const Output<Node>&)>;
 
-Output<Node> TranslateSession::get_backprop_op(const std::shared_ptr<TorchDecoder>& node,
-                                               const Output<Node>& direct_op_output,
-                                               const Output<Node>& value) {
-    std::map<std::string, BackpropCreatorFunction> backprop_map = {
-        {"aten::slice", slice_backprop},
-        {"aten::select", select_backprop},
+Output<Node> TranslateSession::get_reverseprop_op(const std::shared_ptr<TorchDecoder>& node,
+                                                  const Output<Node>& direct_op_output,
+                                                  const Output<Node>& value) {
+    std::map<std::string, ReversepropCreatorFunction> backprop_map = {
+        {"aten::slice", slice_reverseprop},
+        {"aten::select", select_reverseprop},
     };
 
     Output<Node> backprop_node;
diff --git a/src/frontends/pytorch/src/translate_session.hpp b/src/frontends/pytorch/src/translate_session.hpp
index 44ce6232caaa00..de65d1c4ed9eae 100644
--- a/src/frontends/pytorch/src/translate_session.hpp
+++ b/src/frontends/pytorch/src/translate_session.hpp
@@ -34,10 +34,10 @@ class TranslateSession {
                                            const TensorMap& external_tensor_map = {},
                                            const std::shared_ptr<InputModel>& input_model = nullptr);
 
-    /// \brief Returns backprop operations for direct operation
-    Output<Node> get_backprop_op(const std::shared_ptr<TorchDecoder>& node,
-                                 const Output<Node>& direct_op_output,
-                                 const Output<Node>& value);
+    /// \brief Returns reverseprop operation for direct operation
+    Output<Node> get_reverseprop_op(const std::shared_ptr<TorchDecoder>& node,
+                                    const Output<Node>& direct_op_output,
+                                    const Output<Node>& value);
 
     /// \brief Writes pytorch tensor index into openvino tensor
     void encode_tensor_name(Output<Node> tensor_desc,
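The creator map in get_reverseprop_op is a plain string-keyed dispatch from the decoder's op type to the factory that builds the matching reverse node. A standalone sketch of that mechanism's shape (toy creators returning strings in place of graph construction):

#include <functional>
#include <iostream>
#include <map>
#include <string>

int main() {
    using Creator = std::function<std::string(const std::string&)>;
    // mirrors backprop_map: op type string -> reverse-op factory
    std::map<std::string, Creator> reverseprop_map = {
        {"aten::slice", [](const std::string& v) { return "SliceAssign(" + v + ")"; }},
        {"aten::select", [](const std::string& v) { return "GatherAssign(" + v + ")"; }},
    };
    std::cout << reverseprop_map.at("aten::slice")("value") << "\n";  // SliceAssign(value)
}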