Support LSTMSequence with -1 sequence length (#20935)
* [GPU] Support LSTMSequence w/ -1 seq_length

Co-authored-by: Taylor Yeonbok Lee <[email protected]>
Co-authored-by: Andrew Park <[email protected]>

* Fix GetInputInfo to retrieve input pid from LSTMCell

* Use ov::PartialShape instead of cldnn::tensor in LSTMCell
* Implement lstm_elt_inst::calc_output_layouts
* Implement lstm_elt_impl::static_canonicalize_shapes

* Add functional tests

* Fix unit test failure

---------

Co-authored-by: Andrew Park <[email protected]>
ahnyoung-paul and andrew-k-park authored Nov 11, 2023
1 parent c08e01d commit 51da30b
Showing 6 changed files with 497 additions and 72 deletions.
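
For context, the "-1 sequence length" in the title is a dynamic dimension in the model's input shape. A minimal sketch of such an input using the public OpenVINO C++ API (the parameter name, element type, and sizes are illustrative, not taken from this commit):

#include <memory>
#include "openvino/openvino.hpp"

// Input X of LSTMSequence has shape [batch, seq_len, input_size];
// here seq_len is left dynamic (-1), the case this commit enables on GPU.
std::shared_ptr<ov::op::v0::Parameter> make_dynamic_seq_input() {
    return std::make_shared<ov::op::v0::Parameter>(ov::element::f32,
                                                   ov::PartialShape{1, -1, 64});
}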
46 changes: 46 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/lstm_elt.cpp
@@ -71,6 +71,52 @@ struct lstm_elt_impl : typed_primitive_impl_ocl<lstm_elt> {

         return {params, optional_params};
     }
+
+    static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) {
+        if (impl_params.get_input_layout().get_partial_shape().size() != 2) {
+            return primitive_impl::static_canonicalize_shapes(impl_params);
+        }
+        auto updated_impl_params = canonicalize_fused_shapes(impl_params);
+
+        auto& input_layout = updated_impl_params.input_layouts[0];
+        auto& weights_layout = updated_impl_params.input_layouts[1];
+        auto& output_layout = updated_impl_params.output_layouts[0];
+
+        auto input_pshape = input_layout.get_partial_shape();
+        auto weights_pshape = weights_layout.get_partial_shape();
+        auto output_pshape = output_layout.get_partial_shape();
+
+        auto lstm_input_size = static_cast<cldnn::tensor::value_type>(input_pshape[1].get_length());
+        auto lstm_batch_size = static_cast<cldnn::tensor::value_type>(input_pshape[0].get_length());
+        auto lstm_hidden_size = static_cast<cldnn::tensor::value_type>(lstm_input_size / 4);
+
+        GPU_DEBUG_LOG << "lstm_input_size : " << lstm_input_size << std::endl;
+        GPU_DEBUG_LOG << "lstm_batch_size : " << lstm_batch_size << std::endl;
+        GPU_DEBUG_LOG << "lstm_hidden_size : " << lstm_hidden_size << std::endl;
+
+        GPU_DEBUG_LOG << "origin input_pshape : " << input_layout.to_short_string() << std::endl;
+        GPU_DEBUG_LOG << "origin weights_layout : " << weights_layout.to_short_string() << std::endl;
+
+        input_pshape = {lstm_batch_size, 1, 1, lstm_input_size};
+        input_layout.set_partial_shape(input_pshape);
+
+        weights_pshape = {lstm_batch_size, 1, 1, lstm_hidden_size};  // {batch, direction, 1, hidden_size}
+        weights_layout.format = format::adjust_to_rank(weights_layout.format, weights_pshape.size());
+        weights_layout.set_partial_shape(weights_pshape);
+
+        updated_impl_params.weights_layout = weights_layout;
+
+        GPU_DEBUG_LOG << "input_layout : " << input_layout.to_short_string() << std::endl;
+        GPU_DEBUG_LOG << "weights_layout : " << weights_layout.to_short_string() << std::endl;
+        GPU_DEBUG_LOG << "output_layout : " << output_layout.to_short_string() << std::endl;
+
+        OPENVINO_ASSERT(input_pshape.size() == 4 && weights_pshape.size() == 4, "input and weights shape should be rank 4");
+        return updated_impl_params;
+    }
+
+    kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
+        return static_canonicalize_shapes(impl_params);
+    }
 };

 namespace detail {
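
The canonicalization above maps lstm_elt's 2-D dynamic-shape input [batch, 4 * hidden_size] onto the 4-D bfyx layout the OCL kernel expects. A standalone sketch of just the shape arithmetic, with plain integers in place of cldnn types (the struct and function below are hypothetical helpers, not part of the commit):

#include <array>
#include <cstdint>

// A 2-D input [batch, 4 * hidden] becomes bfyx {batch, 1, 1, 4 * hidden};
// the second input becomes {batch, 1, 1, hidden}. hidden is recovered as
// input_size / 4 because the four LSTM gates are packed along the last axis.
struct CanonicalShapes {
    std::array<int64_t, 4> input;    // {b, f, y, x}
    std::array<int64_t, 4> weights;  // name kept from the commit's code
};

CanonicalShapes canonicalize_lstm_elt(int64_t batch, int64_t input_size) {
    const int64_t hidden = input_size / 4;
    return {{batch, 1, 1, input_size}, {batch, 1, 1, hidden}};
}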
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/lstm_elt_inst.h
@@ -38,6 +38,8 @@ class typed_primitive_inst<lstm_elt> : public typed_primitive_inst_base<lstm_elt
     using parent::parent;

 public:
+    template<typename ShapeType>
+    static std::vector<layout> calc_output_layouts(lstm_elt_node const& node, kernel_impl_params const& impl_param);
     static layout calc_output_layout(lstm_elt_node const& node, kernel_impl_params const& impl_param);
     static std::string to_string(lstm_elt_node const& node);

19 changes: 19 additions & 0 deletions src/plugins/intel_gpu/src/graph/lstm_elt.cpp
@@ -27,6 +27,25 @@ layout lstm_elt_inst::calc_output_layout(lstm_elt_node const& node, kernel_impl_
     return result;
 }

+template<typename ShapeType>
+std::vector<layout> lstm_elt_inst::calc_output_layouts(lstm_elt_node const& node, kernel_impl_params const& impl_param) {
+    std::vector<layout> output_layouts;
+
+    // input partial shape [batch, input_size (= hidden_size * 4)]
+    auto input_layout = impl_param.get_input_layout();
+    auto input_pshape = input_layout.get_partial_shape();
+    OPENVINO_ASSERT(static_cast<bool>(impl_param.desc->output_data_types[0]) == false, "Output data type forcing is not supported for lstm_elt_node!");
+    OPENVINO_ASSERT(input_pshape.rank().get_length() == 2, "input_layout rank should be 2 on dynamic shape.");
+
+    auto lstm_input_size = static_cast<cldnn::tensor::value_type>(input_pshape[1].get_length());
+    auto lstm_batch_size = static_cast<cldnn::tensor::value_type>(input_pshape[0].get_length());
+    auto lstm_hidden_size = static_cast<cldnn::tensor::value_type>(lstm_input_size / 4);
+
+    return {cldnn::layout{ov::PartialShape{lstm_batch_size, 2, 1, lstm_hidden_size}, input_layout.data_type, input_layout.format}};
+}
+
+template std::vector<layout> lstm_elt_inst::calc_output_layouts<ov::PartialShape>(lstm_elt_node const& node, const kernel_impl_params& impl_param);
+
 std::string lstm_elt_inst::to_string(lstm_elt_node const& node) {
     auto desc = node.get_primitive();
     auto node_info = node.desc_to_json();
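
The {batch, 2, 1, hidden} layout returned above packs both LSTM states along the feature axis; the consumers in rnn.cpp (next file) split it with two crops. In sketch form (my annotation, not code from the commit):

// lstm_elt output {b, 2, 1, h}: feature index 0 holds the hidden state,
// feature index 1 the cell state, so the downstream crops select
//   hidden: size {b, 1, h, 1} at offset {0, 0, 0, 0}
//   cell:   size {b, 1, h, 1} at offset {0, 1, 0, 0}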
172 changes: 102 additions & 70 deletions src/plugins/intel_gpu/src/plugin/ops/rnn.cpp
@@ -76,10 +76,9 @@ static void CreateLSTMCellOp(ProgramBuilder& p, const std::shared_ptr<ov::op::v4
     const auto in_dims0 = op->get_input_shape(0);
     const auto out_dims0 = op->get_output_shape(0);

-    if (in_dims0.size() != 2 ||
-        op->get_input_shape(1).size() != 2 ||
-        op->get_input_shape(2).size() != 2)
-        OPENVINO_THROW("Wrong input shapes for LSTMCell op ", op->get_friendly_name());
+    OPENVINO_ASSERT((op->get_input_shape(0).size() == 2 &&
+                     op->get_input_shape(1).size() == 2 &&
+                     op->get_input_shape(2).size() == 2), "Wrong input shapes for LSTMCell op ", op->get_friendly_name());

     lstm_input_size = static_cast<int>(in_dims0.back());
     lstm_batch_size = static_cast<int>(in_dims0.at(in_dims0.size()-2));
@@ -91,69 +90,102 @@
     GetLSTMActivationParams(op, activations, activation_params);
     float clip = op->get_clip();

-    // LSTM primitive works with single precision for all in/out/weights tensors
-    auto lstm_dtype = cldnn::element_type_to_data_type(op->get_output_element_type(0));
-
-    cldnn::primitive_id inReshapeID = layerName + "_inReshape";
-    cldnn::primitive_id permuteID = layerName + "_inputReorder";
-    cldnn::primitive_id inHiddenReshapeID = layerName + "_inHiddenReshape";
-    cldnn::primitive_id inHiddenReorderID = layerName + "_inHiddenReorder";
-    cldnn::primitive_id gemmReshapeID = layerName + "_gemmReshape";
-    cldnn::primitive_id gemmReorderID = layerName + "_gemmReorder";
-    cldnn::primitive_id input_concatID = layerName + "_inputConcat";
-
-    cldnn::tensor inputShape = { lstm_batch_size, 1, lstm_input_size, 1 };
-    cldnn::tensor inStateShape = { lstm_batch_size, 1, lstm_hidden_size, 1 };
-    cldnn::layout inputLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, inputShape);
-    cldnn::layout hiddenLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, inStateShape);
-    p.add_primitive(*op, cldnn::reshape(inReshapeID, inputs[0], inputShape));
-    p.add_primitive(*op, cldnn::reorder(permuteID, inReshapeID, inputLayout));
-
-
-    std::string hiddenInResh = inHiddenReshapeID + "_1";
-    std::string hiddenInStr = inHiddenReorderID + "_1";
-    std::string cellInResh = inHiddenReshapeID + "_2";
-    std::string cellInStr = inHiddenReorderID + "_2";
-    p.add_primitive(*op, cldnn::reshape(hiddenInResh, inputs[1], inStateShape));
-    p.add_primitive(*op, cldnn::reorder(hiddenInStr, cldnn::input_info(hiddenInResh), hiddenLayout));
-    p.add_primitive(*op, cldnn::reshape(cellInResh, inputs[2], inStateShape));
-    p.add_primitive(*op, cldnn::reorder(cellInStr, cldnn::input_info(cellInResh), hiddenLayout));
-    p.add_primitive(*op, cldnn::concatenation(input_concatID,
-                                              { permuteID, hiddenInStr },
-                                              3));
-
-    cldnn::tensor gemmSz = cldnn::tensor{ lstm_batch_size, 1, 4 * lstm_hidden_size, 1 };
-    cldnn::layout gemmLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, gemmSz);
-    cldnn::tensor hiddenSz = cldnn::tensor{ lstm_batch_size, 1, lstm_hidden_size, 1 };
-    cldnn::tensor cellCropSz = cldnn::tensor{0, 1, 0, 0};
-
-    std::string lstm_fc_id = layerName + "_fully_connected";
-    std::string lstm_elt_id = layerName + "_lstm_elt";
-
-    cldnn::primitive_id WRconcatID = layerName + "_WRconcat";
-    p.add_primitive(*op, cldnn::concatenation(WRconcatID, { weight, recurrent }, 1));
-
-    cldnn::primitive_id FCInputReshapeID = "Reshape_bf_" + lstm_fc_id + "_for_input";
-    cldnn::tensor FCInputReshapeSz = { lstm_batch_size, inputShape.spatial[0] + inStateShape.spatial[0], 1, 1 };
-    p.add_primitive(*op, cldnn::reshape(FCInputReshapeID, cldnn::input_info(input_concatID), FCInputReshapeSz));
-
-    p.add_primitive(*op, cldnn::fully_connected(lstm_fc_id, cldnn::input_info(FCInputReshapeID), WRconcatID, bias.pid));
-    p.add_primitive(*op, cldnn::reshape(gemmReshapeID, cldnn::input_info(lstm_fc_id), gemmSz));
-    p.add_primitive(*op, cldnn::reorder(gemmReorderID, cldnn::input_info(gemmReshapeID), gemmLayout));
-    p.add_primitive(*op, cldnn::lstm_elt(lstm_elt_id, cldnn::input_info(gemmReorderID), cellInStr, clip, 0, activations,
-                                         activation_params, cldnn::lstm_weights_order::fizo, 0));
-
-
-    cldnn::tensor outSz = cldnn::tensor{ lstm_batch_size, lstm_hidden_size, 1, 1 };
-    cldnn::primitive_id outputHiddenCropID = layerName + "_hc";
-    cldnn::primitive_id outputHiddenID = layerName + ".out0";
-    p.add_primitive(*op, cldnn::crop(outputHiddenCropID, cldnn::input_info(lstm_elt_id), hiddenSz, cldnn::tensor{0, 0, 0, 0}));
-    p.add_primitive(*op, cldnn::reshape(outputHiddenID, cldnn::input_info(outputHiddenCropID), outSz), {layerName});
+    if (p.use_new_shape_infer()) {
+        cldnn::primitive_id input_concatID = layerName + "_inputConcat";
+        p.add_primitive(*op, cldnn::concatenation(input_concatID, { inputs[0], inputs[1] }, 1));
+
+        cldnn::primitive_id lstm_fc_id = layerName + "_fully_connected";
+        cldnn::primitive_id lstm_elt_id = layerName + "_lstm_elt";
+        cldnn::primitive_id wr_concat_id = layerName + "_WRconcat";
+        p.add_primitive(*op, cldnn::concatenation(wr_concat_id, { inputs[3], inputs[4] }, 1));
+        p.add_primitive(*op, cldnn::fully_connected(lstm_fc_id, cldnn::input_info(input_concatID), wr_concat_id, bias.pid));
+        p.add_primitive(*op, cldnn::lstm_elt(lstm_elt_id, cldnn::input_info(lstm_fc_id), inputs[2].pid, clip, 0, activations,
+                                             activation_params, cldnn::lstm_weights_order::fizo, 0));
+
+        auto outSz = op->get_output_partial_shape(0).to_shape();
+        std::vector<int64_t> outSzPt;
+        for (auto i : outSz) {
+            outSzPt.push_back(i);
+        }

-    cldnn::primitive_id outputCellCropID = layerName + "_cc";
-    cldnn::primitive_id outputCellID = layerName + ".out1";
-    p.add_primitive(*op, cldnn::crop(outputCellCropID, cldnn::input_info(lstm_elt_id), hiddenSz, cellCropSz));
-    p.add_primitive(*op, cldnn::reshape(outputCellID, cldnn::input_info(outputCellCropID), outSz));
+        cldnn::tensor hiddenSz = cldnn::tensor{ lstm_batch_size, 1, lstm_hidden_size, 1 };
+
+        cldnn::primitive_id outputHiddenCropID = layerName + "_hc";
+        cldnn::primitive_id outputHiddenID = layerName + ".out0";
+        p.add_primitive(*op, cldnn::crop(outputHiddenCropID, cldnn::input_info(lstm_elt_id), hiddenSz, cldnn::tensor{0, 0, 0, 0}));
+        p.add_primitive(*op, cldnn::reshape(outputHiddenID, cldnn::input_info(outputHiddenCropID),
+                                            false, outSzPt, op->get_output_partial_shape(0)), {layerName});
+
+        cldnn::primitive_id outputCellCropID = layerName + "_cc";
+        cldnn::primitive_id outputCellID = layerName + ".out1";
+        p.add_primitive(*op, cldnn::crop(outputCellCropID, cldnn::input_info(lstm_elt_id), hiddenSz, cldnn::tensor{0, 1, 0, 0}));
+        p.add_primitive(*op, cldnn::reshape(outputCellID, cldnn::input_info(outputCellCropID),
+                                            false, outSzPt, op->get_output_partial_shape(1)));
+    } else {
+        // LSTM primitive works with single precision for all in/out/weights tensors
+        auto lstm_dtype = cldnn::element_type_to_data_type(op->get_output_element_type(0));
+
+        cldnn::primitive_id inReshapeID = layerName + "_inReshape";
+        cldnn::primitive_id permuteID = layerName + "_inputReorder";
+        cldnn::primitive_id inHiddenReshapeID = layerName + "_inHiddenReshape";
+        cldnn::primitive_id inHiddenReorderID = layerName + "_inHiddenReorder";
+        cldnn::primitive_id gemmReshapeID = layerName + "_gemmReshape";
+        cldnn::primitive_id gemmReorderID = layerName + "_gemmReorder";
+        cldnn::primitive_id input_concatID = layerName + "_inputConcat";
+
+        cldnn::tensor inputShape = { lstm_batch_size, 1, lstm_input_size, 1 };
+        cldnn::tensor inStateShape = { lstm_batch_size, 1, lstm_hidden_size, 1 };
+        cldnn::layout inputLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, inputShape);
+        cldnn::layout hiddenLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, inStateShape);
+        p.add_primitive(*op, cldnn::reshape(inReshapeID, inputs[0], inputShape));
+        p.add_primitive(*op, cldnn::reorder(permuteID, inReshapeID, inputLayout));
+
+
+        std::string hiddenInResh = inHiddenReshapeID + "_1";
+        std::string hiddenInStr = inHiddenReorderID + "_1";
+        std::string cellInResh = inHiddenReshapeID + "_2";
+        std::string cellInStr = inHiddenReorderID + "_2";
+        p.add_primitive(*op, cldnn::reshape(hiddenInResh, inputs[1], inStateShape));
+        p.add_primitive(*op, cldnn::reorder(hiddenInStr, cldnn::input_info(hiddenInResh), hiddenLayout));
+        p.add_primitive(*op, cldnn::reshape(cellInResh, inputs[2], inStateShape));
+        p.add_primitive(*op, cldnn::reorder(cellInStr, cldnn::input_info(cellInResh), hiddenLayout));
+        p.add_primitive(*op, cldnn::concatenation(input_concatID,
+                                                  { permuteID, hiddenInStr },
+                                                  3));
+
+        cldnn::tensor gemmSz = cldnn::tensor{ lstm_batch_size, 1, 4 * lstm_hidden_size, 1 };
+        cldnn::layout gemmLayout = cldnn::layout(lstm_dtype, cldnn::format::bfyx, gemmSz);
+        cldnn::tensor hiddenSz = cldnn::tensor{ lstm_batch_size, 1, lstm_hidden_size, 1 };
+        cldnn::tensor cellCropSz = cldnn::tensor{0, 1, 0, 0};
+
+        std::string lstm_fc_id = layerName + "_fully_connected";
+        std::string lstm_elt_id = layerName + "_lstm_elt";
+
+        cldnn::primitive_id WRconcatID = layerName + "_WRconcat";
+        p.add_primitive(*op, cldnn::concatenation(WRconcatID, { weight, recurrent }, 1));
+
+        cldnn::primitive_id FCInputReshapeID = "Reshape_bf_" + lstm_fc_id + "_for_input";
+        cldnn::tensor FCInputReshapeSz = { lstm_batch_size, inputShape.spatial[0] + inStateShape.spatial[0], 1, 1 };
+        p.add_primitive(*op, cldnn::reshape(FCInputReshapeID, cldnn::input_info(input_concatID), FCInputReshapeSz));
+
+        p.add_primitive(*op, cldnn::fully_connected(lstm_fc_id, cldnn::input_info(FCInputReshapeID), WRconcatID, bias.pid));
+        p.add_primitive(*op, cldnn::reshape(gemmReshapeID, cldnn::input_info(lstm_fc_id), gemmSz));
+        p.add_primitive(*op, cldnn::reorder(gemmReorderID, cldnn::input_info(gemmReshapeID), gemmLayout));
+        p.add_primitive(*op, cldnn::lstm_elt(lstm_elt_id, cldnn::input_info(gemmReorderID), cellInStr, clip, 0, activations,
+                                             activation_params, cldnn::lstm_weights_order::fizo, 0));
+
+
+        cldnn::tensor outSz = cldnn::tensor{ lstm_batch_size, lstm_hidden_size, 1, 1 };
+        cldnn::primitive_id outputHiddenCropID = layerName + "_hc";
+        cldnn::primitive_id outputHiddenID = layerName + ".out0";
+        p.add_primitive(*op, cldnn::crop(outputHiddenCropID, cldnn::input_info(lstm_elt_id), hiddenSz, cldnn::tensor{0, 0, 0, 0}));
+        p.add_primitive(*op, cldnn::reshape(outputHiddenID, cldnn::input_info(outputHiddenCropID), outSz), {layerName});
+
+        cldnn::primitive_id outputCellCropID = layerName + "_cc";
+        cldnn::primitive_id outputCellID = layerName + ".out1";
+        p.add_primitive(*op, cldnn::crop(outputCellCropID, cldnn::input_info(lstm_elt_id), hiddenSz, cellCropSz));
+        p.add_primitive(*op, cldnn::reshape(outputCellID, cldnn::input_info(outputCellCropID), outSz));
+    }
 }

 static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op::v5::LSTMSequence>& op) {
@@ -217,12 +249,12 @@
     cldnn::primitive_id cellStr = inHiddenReshapeID + "_2";
     cldnn::primitive_id inputCropID = layerName + "_inputCrop";

-    cldnn::primitive_id WRconcatID = layerName + "_WRconcat";
-    p.add_primitive(*op, cldnn::concatenation(WRconcatID, { weight, recurrent }, 2));
+    cldnn::primitive_id wr_concat_id = layerName + "_WRconcat";
+    p.add_primitive(*op, cldnn::concatenation(wr_concat_id, { weight, recurrent }, 2));

     std::vector<size_t> WRreshapeSize = { 4 * size_t(lstm_hidden_size), size_t(lstm_input_size + lstm_hidden_size) };
-    cldnn::primitive_id WRreshapeID = WRconcatID + "_reshape";
-    auto reshapeInPrim = cldnn::reshape(WRreshapeID, cldnn::input_info(WRconcatID), tensor_from_dims(WRreshapeSize));
+    cldnn::primitive_id WRreshapeID = wr_concat_id + "_reshape";
+    auto reshapeInPrim = cldnn::reshape(WRreshapeID, cldnn::input_info(wr_concat_id), tensor_from_dims(WRreshapeSize));
     p.add_primitive(*op, reshapeInPrim);

     for (int i = 0; i < lstm_sequence_len; ++i) {
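
As a sanity check on the use_new_shape_infer() branch above, the shapes flow as follows (b = batch, i = input_size, h = hidden_size; annotation only, assuming the standard LSTMCell weight shapes W [4h, i] and R [4h, h]):

// concat(X [b, i], H [b, h], axis 1)         -> [b, i + h]
// concat(W [4h, i], R [4h, h], axis 1)       -> [4h, i + h]
// fully_connected([b, i + h], [4h, i + h])   -> [b, 4h]
// lstm_elt([b, 4h], C [b, h])                -> {b, 2, 1, h}  (hidden | cell)
// two crops + reshapes                       -> .out0 [b, h], .out1 [b, h]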
8 changes: 6 additions & 2 deletions src/plugins/intel_gpu/src/plugin/program_builder.cpp
@@ -7,6 +7,7 @@
 #include "openvino/op/constant.hpp"
 #include "openvino/op/split.hpp"
 #include "openvino/op/variadic_split.hpp"
+#include "openvino/op/lstm_cell.hpp"

 #include "intel_gpu/plugin/program_builder.hpp"
 #include "intel_gpu/plugin/transformations_pipeline.hpp"
@@ -250,10 +251,13 @@ std::vector<cldnn::input_info> ProgramBuilder::GetInputInfo(const std::shared_pt
     for (size_t i = 0; i < op->get_input_size(); i++) {
         auto prevOp = op->get_input_node_ptr(i);
         std::string prevName = layer_type_name_ID(prevOp);
+        // Note: Currently Split/Variadic Split are divided to multiple crops
+        // LSTMCell contains its own body network, and each output has a unique pid
+        // But there is no need to maintain output port index for the next node e.g. Result
         bool is_legacy_multiple_outputs = !allow_new_shape_infer
-                                          // Note:: Currently Split/Variadic Split are divided to multiple crops
                                           || ov::is_type<ov::op::v1::Split>(prevOp)
-                                          || ov::is_type<ov::op::v1::VariadicSplit>(prevOp);
+                                          || ov::is_type<ov::op::v1::VariadicSplit>(prevOp)
+                                          || ov::is_type<ov::op::v4::LSTMCell>(prevOp);
         if (prevOp->get_output_size() > 1 && is_legacy_multiple_outputs) {
             prevName += ".out" + std::to_string(op->get_input_source_output(i).get_index());
         }
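
The practical effect of the GetInputInfo change: when the producer is an LSTMCell decomposed as in rnn.cpp, the consumer's input pid now receives the per-output suffix even when new shape inference is enabled (sketch; the exact pid spelling is illustrative):

// prevOp is an LSTMCell with two outputs; the consumer reads port 1:
//   prevName            -> prevName + ".out1"
// which matches the ".out0"/".out1" reshape primitives registered in
// CreateLSTMCellOp, rather than treating LSTMCell as a true multi-output
// primitive with port-indexed pids.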