Skip to content

Commit

Permalink
[PT FE] Optimize reverseprop in pytorch frontend (#20989)
Browse files Browse the repository at this point in the history
* [PT FE] Optimize reverseprop in pytorch frontend

* Add transformation

* Improve readability

---------

Co-authored-by: Alina Kladieva <[email protected]>
  • Loading branch information
mvafin and akladiev authored Nov 10, 2023
1 parent e446dac commit c08e01d
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 83 deletions.
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "transforms/prim_list_unpack_replacer.hpp"
#include "transforms/prim_tuple_unpack_parameter_replacer.hpp"
#include "transforms/quantized_node_remover.hpp"
#include "transforms/reverseprop_resolver.hpp"
#include "transforms/rfftn_complex_replacer.hpp"
#include "transforms/softmax_reshape_elimination.hpp"
#include "transforms/string_equality_replacer.hpp"
Expand Down Expand Up @@ -204,6 +205,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
manager.register_pass<ov::frontend::pytorch::pass::U4BlockRepack>();
manager.register_pass<ov::frontend::pytorch::pass::ReversepropResolver>();
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
// Second pass of AlignTypesRemoval after all converting transformations
Expand Down
40 changes: 40 additions & 0 deletions src/frontends/pytorch/src/helper_ops/gather_assign.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "helper_ops/internal_op.hpp"
#include "openvino/frontend/decoder.hpp"
#include "openvino/op/op.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {

class GatherAssign : public InternalReverseOperation {
public:
OPENVINO_OP("GatherAssign", "internal", InternalReverseOperation);

GatherAssign(const Output<Node>& data,
const Output<Node>& updates,
const Output<Node>& indices,
const Output<Node>& axis)
: InternalReverseOperation({data, updates, indices, axis}) {
validate_and_infer_types();
}

void validate_and_infer_types() override {
auto data = input_value(0);
set_output_type(0, data.get_element_type(), data.get_partial_shape());
}

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override {
check_new_args_count(this, new_args);
return std::make_shared<GatherAssign>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
}
};
} // namespace pytorch
} // namespace frontend
} // namespace ov
7 changes: 7 additions & 0 deletions src/frontends/pytorch/src/helper_ops/internal_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>

#include "openvino/frontend/decoder.hpp"
#include "openvino/op/op.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

Expand Down Expand Up @@ -51,6 +52,12 @@ class InternalOperation : public PtFrameworkNode {
set_attrs(attrs);
}
};

class InternalReverseOperation : public ov::op::Op {
public:
OPENVINO_OP("InternalReverseOperation", "internal");
InternalReverseOperation(const OutputVector& inputs) : ov::op::Op(inputs) {}
};
} // namespace pytorch
} // namespace frontend
} // namespace ov
64 changes: 64 additions & 0 deletions src/frontends/pytorch/src/helper_ops/slice_assign.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "helper_ops/internal_op.hpp"
#include "openvino/frontend/decoder.hpp"
#include "openvino/op/op.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {

class SliceAssign : public InternalReverseOperation {
public:
OPENVINO_OP("SliceAssign", "internal", InternalReverseOperation);

SliceAssign(const Output<Node>& data,
const Output<Node>& updates,
const Output<Node>& start,
const Output<Node>& stop,
const Output<Node>& step)
: InternalReverseOperation({data, updates, start, stop, step}) {
validate_and_infer_types();
}

SliceAssign(const Output<Node>& data,
const Output<Node>& updates,
const Output<Node>& start,
const Output<Node>& stop,
const Output<Node>& step,
const Output<Node>& axes)
: InternalReverseOperation({data, updates, start, stop, step, axes}) {
validate_and_infer_types();
}

void validate_and_infer_types() override {
auto data = input_value(0);
set_output_type(0, data.get_element_type(), data.get_partial_shape());
}

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override {
check_new_args_count(this, new_args);
if (new_args.size() == 5) {
return std::make_shared<SliceAssign>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4));
} else {
return std::make_shared<SliceAssign>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5));
}
}
};
} // namespace pytorch
} // namespace frontend
} // namespace ov
12 changes: 6 additions & 6 deletions src/frontends/pytorch/src/node_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,25 @@ void NodeContext::mutate_input(size_t index, Output<Node> ov_output) const {
auto back_input_id = input_id;
auto back_node_input = ov_output;
while (m_translate_session->m_may_be_alias.count(back_input_id)) {
// Create node to backprop data. While loop is needed for the cases when alias to tensor point to another alias
// to tensor. In that case we need to create a chain of backprop ops
// Create node to reverseprop data. The while loop is needed for cases where an alias to a tensor points to
// another alias to a tensor. In that case we need to create a chain of reverseprop ops
size_t in_tensor;
std::shared_ptr<TorchDecoder> node;
Output<Node> node_converted_output;
std::tie(in_tensor, node, node_converted_output) = m_translate_session->m_may_be_alias.at(back_input_id);
auto backprop_node = m_translate_session->get_backprop_op(node, node_converted_output, back_node_input);
auto reverseprop_node = m_translate_session->get_reverseprop_op(node, node_converted_output, back_node_input);
if (m_tensor_map->count(in_tensor)) {
// Tensor is not found in the scope of this body, need to get it from internal context and mark mutated
OPENVINO_DEBUG << "Couldn't find in the current body the initial aliased tensor: " << in_tensor
<< " for operation: " << node->get_op_type() << " creating new body input.";
get_tensor_from_model_or_create_input(in_tensor);
}
m_translate_session->encode_tensor_name(backprop_node, in_tensor);
(*m_tensor_map)[in_tensor] = backprop_node;
m_translate_session->encode_tensor_name(reverseprop_node, in_tensor);
(*m_tensor_map)[in_tensor] = reverseprop_node;
m_mutated_tensors->insert(in_tensor);
OPENVINO_DEBUG << "Propagated back data from tensor: " << back_input_id << " to tensor: " << in_tensor << ".\n";
back_input_id = in_tensor;
back_node_input = backprop_node;
back_node_input = reverseprop_node;
}
}

Expand Down
119 changes: 119 additions & 0 deletions src/frontends/pytorch/src/transforms/reverseprop_resolver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "reverseprop_resolver.hpp"

#include <memory>
#include <utility>

#include "helper_ops/gather_assign.hpp"
#include "helper_ops/internal_op.hpp"
#include "helper_ops/slice_assign.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

using namespace ov::pass;
using namespace ov::op;

// Registers a matcher that collapses a chain of InternalReverseOperation nodes (SliceAssign /
// GatherAssign) into a single ScatterNDUpdate over the flattened tensor:
//   1. walk from the innermost reverse op outwards through the "updates" inputs of its consumers;
//   2. build a tensor of flat element indices for the outermost data tensor and apply the same
//      Slice/Gather operations to it, outermost first;
//   3. scatter the flattened updates into the flattened data at those indices and reshape back.
ReversepropResolver::ReversepropResolver() {
    auto reverse_op = pattern::wrap_type<InternalReverseOperation>();

    ov::matcher_pass_callback callback = [](pattern::Matcher& m) {
        auto base_op = m.get_match_root();
        // Apply this transformation only to the starting reverse operation, i.e. the one whose
        // "updates" input (port 1) is not itself produced by a reverse operation.
        if (ov::as_type_ptr<InternalReverseOperation>(base_op->get_input_node_shared_ptr(1)))
            return false;

        // Collect the chain base_op -> ... -> outermost reverse op.
        auto curr_op = base_op;
        std::vector<std::shared_ptr<Node>> rev_ops;
        while (ov::as_type_ptr<InternalReverseOperation>(curr_op)) {
            rev_ops.push_back(curr_op);
            const auto target_inputs = curr_op->get_output_target_inputs(0);
            if (target_inputs.size() != 1)
                break;
            const auto& consumer = *target_inputs.begin();
            // The chain only continues when the result is consumed as "updates" (port 1) of the
            // next operation. A reverse op that takes it as "data" (port 0) or as a slice/gather
            // parameter is an independent update and must not be folded into this chain.
            if (consumer.get_index() != 1)
                break;
            curr_op = consumer.get_node()->shared_from_this();
        }
        if (rev_ops.empty())
            return false;

        ov::pass::NodeRegistry rg;
        auto zero = v0::Constant::create(element::i64, Shape{}, {0});
        auto one = v0::Constant::create(element::i64, Shape{}, {1});
        auto neg_one_1d = v0::Constant::create(element::i64, Shape{1}, {-1});
        auto scattering_shape = v0::Constant::create(element::i64, Shape{2}, {-1, 1});

        // Get 1d indices [0..numel) for whole input tensor
        auto start_op = rev_ops.back();
        auto data_to_insert_into = start_op->input_value(0);
        auto input_shape = rg.make<v3::ShapeOf>(data_to_insert_into, element::i64);
        auto numel = rg.make<v1::ReduceProd>(input_shape, zero, false);
        auto full_data_indices_1d = rg.make<v4::Range>(zero, numel, one, element::i64);
        auto full_data_indices = rg.make<v1::Reshape>(full_data_indices_1d, input_shape, false);

        // Cut indices in accordance with operations, from the outermost op down to base_op
        Output<Node> data_indices = full_data_indices;
        for (auto it = rev_ops.rbegin(); it != rev_ops.rend(); ++it) {
            curr_op = *it;
            if (ov::as_type_ptr<SliceAssign>(curr_op)) {
                if (curr_op->get_input_size() == 6) {
                    data_indices = rg.make<v8::Slice>(data_indices,
                                                      curr_op->input_value(2),
                                                      curr_op->input_value(3),
                                                      curr_op->input_value(4),
                                                      curr_op->input_value(5));
                } else if (curr_op->get_input_size() == 5) {
                    data_indices = rg.make<v8::Slice>(data_indices,
                                                      curr_op->input_value(2),
                                                      curr_op->input_value(3),
                                                      curr_op->input_value(4));
                } else {
                    return false;
                }
            } else if (ov::as_type_ptr<GatherAssign>(curr_op)) {
                data_indices = rg.make<v8::Gather>(data_indices, curr_op->input_value(2), curr_op->input_value(3));
            } else {
                // Unknown kind of reverse operation: leave the graph untouched
                return false;
            }
        }

        // Scatter in flattened tensor with indices and flattened data to be inserted
        auto data_to_insert_into_1d = rg.make<v1::Reshape>(data_to_insert_into, neg_one_1d, false);
        auto data_indices_1d = rg.make<v1::Reshape>(data_indices, scattering_shape, false);
        auto to_be_inserted_data_1d = rg.make<v1::Reshape>(base_op->input_value(1), neg_one_1d, false);
        auto updated_data_1d =
            rg.make<v3::ScatterNDUpdate>(data_to_insert_into_1d, data_indices_1d, to_be_inserted_data_1d);

        // Reshape to initial shape
        auto res_node = rg.make<v1::Reshape>(updated_data_1d, input_shape, false);
        copy_runtime_info_and_name(base_op, rg.get());
        start_op->output(0).replace(res_node);

        return true;
    };

    auto m =
        std::make_shared<ov::pass::pattern::Matcher>(reverse_op, "ov::frontend::pytorch::pass::ReversepropResolver");
    this->register_matcher(m, callback);
}

} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
27 changes: 27 additions & 0 deletions src/frontends/pytorch/src/transforms/reverseprop_resolver.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

/**
 * Replace a sequence of reverseprop operations (InternalReverseOperation nodes such as
 * SliceAssign and GatherAssign) with a single ScatterNDUpdate.
 *
 * The matcher is registered in the constructor; the pass is a regular MatcherPass and is meant
 * to be added to a pass manager via register_pass.
 */
class ReversepropResolver : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::ReversepropResolver");
// Builds the pattern and registers the matcher callback.
ReversepropResolver();
};

} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
Loading

0 comments on commit c08e01d

Please sign in to comment.