From 311e3d1355ede10536f46c01c165aebfb1257ce8 Mon Sep 17 00:00:00 2001
From: mitruska
Date: Tue, 21 Jan 2025 04:36:37 +0100
Subject: [PATCH] Init ISTFT op

---
 src/core/include/openvino/op/istft.hpp        |  51 ++++++++
 src/core/include/openvino/op/ops.hpp          |   1 +
 .../include/openvino/opsets/opset16_tbl.hpp   |   1 +
 .../include/istft_shape_inference.hpp         | 112 ++++++++++++++++
 src/core/src/op/istft.cpp                     | 123 ++++++++++++++++++
 src/core/tests/opset.cpp                      |   2 +-
 src/core/tests/type_prop/istft.cpp            |  76 +++++++++++
 7 files changed, 365 insertions(+), 1 deletion(-)
 create mode 100644 src/core/include/openvino/op/istft.hpp
 create mode 100644 src/core/shape_inference/include/istft_shape_inference.hpp
 create mode 100644 src/core/src/op/istft.cpp
 create mode 100644 src/core/tests/type_prop/istft.cpp

diff --git a/src/core/include/openvino/op/istft.hpp b/src/core/include/openvino/op/istft.hpp
new file mode 100644
index 00000000000000..611ee10316086b
--- /dev/null
+++ b/src/core/include/openvino/op/istft.hpp
@@ -0,0 +1,51 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/op/op.hpp"
+
+namespace ov {
+namespace op {
+namespace v16 {
+/// \brief An operation ISTFT that computes the Inverse Short Time Fourier Transform.
+/// \ingroup ov_ops_cpp_api
+class OPENVINO_API ISTFT : public Op {
+public:
+    OPENVINO_OP("ISTFT", "opset16");
+    ISTFT() = default;
+
+    /// \brief Constructs an ISTFT operation.
+    ///
+    /// \param data        STFT result to invert: 3D [fft_results, frames, 2] or 4D [batch, fft_results, frames, 2]
+    /// \param window      1D window values applied in STFT, floating point type matching the type of data
+    /// \param frame_size  Scalar (i32 or i64) representing the size of Fourier Transform
+    /// \param frame_step  Scalar (i32 or i64), the distance (number of samples) between successive window frames
+    /// \param length      Scalar (i32 or i64), the length of the original signal
+    /// \param center      Flag signaling if the signal input has been padded before STFT
+    /// \param normalized  Flag signaling if the STFT result has been normalized.
+    ISTFT(const Output<Node>& data,
+          const Output<Node>& window,
+          const Output<Node>& frame_size,
+          const Output<Node>& frame_step,
+          const Output<Node>& length,
+          const bool center,
+          const bool normalized);
+
+    bool visit_attributes(AttributeVisitor& visitor) override;
+    void validate_and_infer_types() override;
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
+
+    bool get_center() const;  ///< Returns the value of the `center` attribute.
+
+    bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
+    bool has_evaluate() const override;
+
+private:
+    bool m_center = false;
+    bool m_normalized = false;
+};
+}  // namespace v16
+}  // namespace op
+}  // namespace ov
diff --git a/src/core/include/openvino/op/ops.hpp b/src/core/include/openvino/op/ops.hpp
index 73510a524ef3e1..9cc7c683104a54 100644
--- a/src/core/include/openvino/op/ops.hpp
+++ b/src/core/include/openvino/op/ops.hpp
@@ -99,6 +99,7 @@
 #include "openvino/op/is_finite.hpp"
 #include "openvino/op/is_inf.hpp"
 #include "openvino/op/is_nan.hpp"
+#include "openvino/op/istft.hpp"
 #include "openvino/op/less.hpp"
 #include "openvino/op/less_eq.hpp"
 #include "openvino/op/log.hpp"
diff --git a/src/core/include/openvino/opsets/opset16_tbl.hpp b/src/core/include/openvino/opsets/opset16_tbl.hpp
index 4038aa17b72750..63afe3af430acb 100644
--- a/src/core/include/openvino/opsets/opset16_tbl.hpp
+++ b/src/core/include/openvino/opsets/opset16_tbl.hpp
@@ -15,3 +15,4 @@ _OPENVINO_OP_REG(ShapeOf, ov::op::v3)
 
 // New operations added in opset16
 _OPENVINO_OP_REG(Identity, ov::op::v16)
+_OPENVINO_OP_REG(ISTFT, ov::op::v16)
diff --git a/src/core/shape_inference/include/istft_shape_inference.hpp b/src/core/shape_inference/include/istft_shape_inference.hpp
new file mode 100644
index 00000000000000..d62ea60116b292
--- /dev/null
+++ b/src/core/shape_inference/include/istft_shape_inference.hpp
@@ -0,0 +1,112 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "dimension_util.hpp"
+#include "openvino/op/istft.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace op {
+namespace v16 {
+/// \brief Infers the ISTFT output shape: [signal_length] for 3D data, [batch, signal_length] for 4D data.
+/// \note The `length` input is only checked to be a scalar here; it does not influence the inferred shape.
+template <class TShape, class TRShape = result_shape_t<TShape>>
+std::vector<TRShape> shape_infer(const ISTFT* op,
+                                 const std::vector<TShape>& input_shapes,
+                                 const ITensorAccessor& ta = make_tensor_accessor()) {
+    using TDim = typename TRShape::value_type;
+    using TDimVal = typename TDim::value_type;
+
+    NODE_VALIDATION_CHECK(op, input_shapes.size() == 5);
+
+    const auto& data_shape = input_shapes[0];
+    const auto& window_shape = input_shapes[1];
+    const auto& frame_size_shape = input_shapes[2];
+    const auto& frame_step_shape = input_shapes[3];
+    const auto& length_shape = input_shapes[4];
+
+    const auto data_shape_rank = data_shape.rank();
+    NODE_SHAPE_INFER_CHECK(op,
+                           input_shapes,
+                           data_shape_rank.compatible(3) || data_shape_rank.compatible(4),
+                           "The shape of data must be 3D [fft_results, frames, 2] or 4D [batch, fft_results, frames, 2].");
+    NODE_SHAPE_INFER_CHECK(op,
+                           input_shapes,
+                           window_shape.rank().compatible(1),
+                           "The shape of window must be 1D [window_size].");
+    NODE_SHAPE_INFER_CHECK(op,
+                           input_shapes,
+                           frame_size_shape.rank().compatible(0),
+                           "The shape of frame_size must be a scalar.");
+    NODE_SHAPE_INFER_CHECK(op,
+                           input_shapes,
+                           frame_step_shape.rank().compatible(0),
+                           "The shape of frame_step must be a scalar.");
+    NODE_SHAPE_INFER_CHECK(op,
+                           input_shapes,
+                           length_shape.rank().compatible(0),
+                           "The shape of length input must be a scalar.");
+
+    if (data_shape_rank.is_dynamic()) {  // rank unknown: hand the dynamic-rank data shape back unchanged
+        return {data_shape};
+    }
+
+    const auto frame_size = get_input_const_data_as<TRShape, int64_t>(op, 2, ta);
+    const auto frame_step = get_input_const_data_as<TRShape, int64_t>(op, 3, ta);
+
+    const auto is_data_3D = data_shape.size() == 3;
+    if (!frame_size || !frame_step) {  // frame parameters are not constant: signal length stays unbounded
+        if (is_data_3D) {
+            return {TRShape{TDim(ov::util::dim::inf_bound)}};
+        }
+        return {TRShape{data_shape[0], TDim(ov::util::dim::inf_bound)}};
+    }
+
+    const auto& frame_size_val = (*frame_size)[0];
+    const auto& frame_step_val = (*frame_step)[0];
+
+    NODE_SHAPE_INFER_CHECK(op,
+                           input_shapes,
+                           0 < frame_size_val,
+                           "Provided frame size is ",
+                           frame_size_val,
+                           " but must be greater than zero.");
+
+    NODE_SHAPE_INFER_CHECK(op,
+                           input_shapes,
+                           0 < frame_step_val,
+                           "Provided frame step is ",
+                           frame_step_val,
+                           " but must be greater than zero.");
+
+    const bool is_win_shape_correct =
+        window_shape.is_dynamic() || (TDimVal{0} < window_shape[0].get_length() &&
+                                      window_shape[0].get_length() <= static_cast<TDimVal>(frame_size_val));
+    NODE_SHAPE_INFER_CHECK(op,
+                           input_shapes,
+                           is_win_shape_correct,
+                           "Window input dimension must be in range [1, ",
+                           frame_size_val,
+                           "].");
+
+    const int64_t frames_axis = is_data_3D ? 1 : 2;  // 4D data carries a leading batch dimension
+    const TDim& num_frames_dim = data_shape[frames_axis];
+    TDim signal_length = (num_frames_dim - 1) * frame_step_val;
+    if (!op->get_center()) {
+        signal_length += frame_size_val;
+    }
+
+    std::vector<TRShape> output_shapes;
+    output_shapes.emplace_back(TRShape{std::move(signal_length)});
+
+    if (!is_data_3D) {
+        output_shapes[0].insert(output_shapes[0].begin(), data_shape[0]);
+    }
+    return output_shapes;
+}
+}  // namespace v16
+}  // namespace op
+}  // namespace ov
diff --git a/src/core/src/op/istft.cpp b/src/core/src/op/istft.cpp
new file mode 100644
index 00000000000000..f851b4bbb00ab2
--- /dev/null
+++ b/src/core/src/op/istft.cpp
@@ -0,0 +1,123 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/op/istft.hpp"
+
+#include <memory>
+
+#include "istft_shape_inference.hpp"
+#include "itt.hpp"
+// TODO: enable once the reference kernel is available: #include "openvino/reference/istft.hpp"
+
+namespace ov {
+namespace op {
+namespace v16 {
+namespace {
+void check_int_input_at(const Node* op, size_t input_idx) {  // accepts dynamic, i32 or i64 element type
+    const auto& in_type = op->get_input_element_type(input_idx);
+    const auto has_valid_type = in_type.is_dynamic() || in_type == element::i32 || in_type == element::i64;
+    NODE_VALIDATION_CHECK(op, has_valid_type, "Expected i32 or i64 type of the input at port: ", input_idx);
+}
+}  // namespace
+ISTFT::ISTFT(const Output<Node>& data,
+             const Output<Node>& window,
+             const Output<Node>& frame_size,
+             const Output<Node>& frame_step,
+             const Output<Node>& length,
+             const bool center,
+             const bool normalized)
+    : Op({data, window, frame_size, frame_step, length}),
+      m_center(center),
+      m_normalized(normalized) {
+    constructor_validate_and_infer_types();
+}
+
+std::shared_ptr<Node> ISTFT::clone_with_new_inputs(const OutputVector& new_args) const {
+    OV_OP_SCOPE(v16_ISTFT_clone_with_new_inputs);
+    check_new_args_count(this, new_args);
+
+    return std::make_shared<ISTFT>(new_args.at(0),
+                                   new_args.at(1),
+                                   new_args.at(2),
+                                   new_args.at(3),
+                                   new_args.at(4),
+                                   m_center,
+                                   m_normalized);
+}
+
+bool ISTFT::visit_attributes(AttributeVisitor& visitor) {
+    OV_OP_SCOPE(v16_ISTFT_visit_attributes);
+    visitor.on_attribute("center", m_center);
+    visitor.on_attribute("normalized", m_normalized);
+    return true;
+}
+
+void ISTFT::validate_and_infer_types() {
+    OV_OP_SCOPE(v16_ISTFT_validate_and_infer_types);
+    NODE_VALIDATION_CHECK(this, get_input_size() == 5, "Expected 5 inputs to be provided.");
+
+    auto data_type = get_input_element_type(0);
+    const auto& window_type = get_input_element_type(1);
+
+    const auto has_valid_data_type = data_type.is_dynamic() || data_type.is_real();
+    NODE_VALIDATION_CHECK(this, has_valid_data_type, "Expected floating point type of the 'data' input.");
+    // merge() writes the reconciled type back into data_type, which is why data_type is not const.
+    const auto has_valid_window_type =
+        window_type.is_dynamic() ||
+        (window_type.is_real() && element::Type::merge(data_type, window_type, data_type));
+    NODE_VALIDATION_CHECK(this,
+                          has_valid_window_type,
+                          "Expected floating point type of the 'window' input, matching the type of 'data' input.");
+
+    check_int_input_at(this, 2);
+    check_int_input_at(this, 3);
+    check_int_input_at(this, 4);
+
+    const auto input_shapes = ov::util::get_node_input_partial_shapes(*this);
+    const auto output_shapes = shape_infer(this, input_shapes);
+
+    set_output_type(0, data_type, output_shapes[0]);
+}
+
+bool ISTFT::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
+    OV_OP_SCOPE(v16_ISTFT_evaluate);
+    OPENVINO_ASSERT(outputs.size() == 1);
+    OPENVINO_ASSERT(inputs.size() == 5);
+
+    // The reference kernel is not wired in yet, so no data can be written to `outputs`.
+    // Report failure instead of `true`: success with an unwritten output would be consumed as a valid result.
+    // TODO: enable the block below together with "openvino/reference/istft.hpp" and return true.
+    // const auto input_shapes = ov::util::get_tensors_partial_shapes(inputs);
+    // const auto output_shape = shape_infer(this, input_shapes, make_tensor_accessor(inputs)).front().to_shape();
+    // outputs[0].set_shape(output_shape);
+    // const auto frame_size = ov::get_tensor_data_as<int64_t>(inputs[2]).front();
+    // const auto frame_step = ov::get_tensor_data_as<int64_t>(inputs[3]).front();
+    // const auto length = ov::get_tensor_data_as<int64_t>(inputs[4]).front();
+    // ov::reference::istft(inputs[0].data<const float>(),
+    //                      inputs[1].data<const float>(),
+    //                      outputs[0].data<float>(),
+    //                      inputs[0].get_shape(),
+    //                      inputs[1].get_shape(),
+    //                      frame_size,
+    //                      frame_step,
+    //                      length,
+    //                      m_normalized,
+    //                      m_center);
+    return false;
+}
+
+bool ISTFT::has_evaluate() const {
+    OV_OP_SCOPE(v16_ISTFT_has_evaluate);
+    // evaluate() cannot produce data yet; advertise f32 support only once the reference kernel is called.
+    return false;
+}
+
+bool ISTFT::get_center() const {
+    OV_OP_SCOPE(v16_ISTFT_get_center);
+    return m_center;
+}
+
+}  // namespace v16
+}  // namespace op
+}  // namespace ov
diff --git a/src/core/tests/opset.cpp b/src/core/tests/opset.cpp
index 81f6e80c28189f..47c8770a9b2e32 100644
--- a/src/core/tests/opset.cpp
+++ b/src/core/tests/opset.cpp
@@ -77,7 +77,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
                                          OpsetTestParams{ov::get_opset13, 186},
                                          OpsetTestParams{ov::get_opset14, 188},
                                          OpsetTestParams{ov::get_opset15, 199},
-                                         OpsetTestParams{ov::get_opset16, 4}),
+                                         OpsetTestParams{ov::get_opset16, 5}),
                          OpsetTestNameGenerator{});
 
 class MyOpOld : public ov::op::Op {
diff --git a/src/core/tests/type_prop/istft.cpp b/src/core/tests/type_prop/istft.cpp
new file mode 100644
index 00000000000000..afb83508a9d40e
--- /dev/null
+++ b/src/core/tests/type_prop/istft.cpp
@@ -0,0 +1,76 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/op/istft.hpp"
+
+#include <gtest/gtest.h>
+
+#include "common_test_utils/test_assertions.hpp"
+#include "common_test_utils/type_prop.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/subtract.hpp"
+
+namespace ov {
+namespace test {
+
+using op::v0::Constant;
+using op::v0::Parameter;
+using testing::HasSubstr;
+
+class TypePropISTFTTest : public TypePropOpTest<op::v16::ISTFT> {
+public:
+    bool center = true;
+    bool normalized = false;
+};
+
+TEST_F(TypePropISTFTTest, all_inputs_as_params_static_shapes) {
+    const auto in_data = std::make_shared<Parameter>(element::f32, PartialShape{2, 9, 4, 2});
+    const auto window = std::make_shared<Parameter>(element::f32, PartialShape{7});
+    const auto frame_size = std::make_shared<Parameter>(element::i64, PartialShape{});
+    const auto frame_step = std::make_shared<Parameter>(element::i64, PartialShape{});
+    const auto length = std::make_shared<Parameter>(element::i64, PartialShape{});
+
+    const auto op = make_op(in_data, window, frame_size, frame_step, length, center, normalized);
+    EXPECT_EQ(op->get_output_size(), 1);
+    EXPECT_EQ(op->get_output_element_type(0), element::f32);
+    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, -1}));
+}
+// Params: expected output shape, window shape, frame_size, frame_step, center, data (input) shape.
+using ISTFTTestParam = std::tuple<PartialShape, PartialShape, int32_t, int32_t, bool, PartialShape>;
+class TypePropISTFTTestP : public TypePropISTFTTest, public testing::WithParamInterface<ISTFTTestParam> {
+protected:
+    void SetUp() override {
+        std::tie(expected_signal_shape, window_shape, frame_size_val, step_size_val, center, data_shape) = GetParam();
+    }
+    PartialShape expected_signal_shape, window_shape, data_shape;
+    int32_t frame_size_val, step_size_val;
+};
+
+INSTANTIATE_TEST_SUITE_P(
+    type_prop_istft_shape,
+    TypePropISTFTTestP,
+    testing::Values(
+        std::make_tuple(PartialShape{16}, PartialShape{16}, 16, 16, false, PartialShape{9, 1, 2}),  // single frame
+        std::make_tuple(PartialShape{48}, PartialShape{16}, 16, 16, false, PartialShape{9, 3, 2}),
+        std::make_tuple(PartialShape{56}, PartialShape{7}, 11, 3, false, PartialShape{6, 16, 2}),
+        // Dynamic-rank data: the inferred output shape stays fully dynamic.
+        std::make_tuple(PartialShape::dynamic(), PartialShape::dynamic(), 11, 3, true, PartialShape::dynamic())),
+    testing::PrintToStringParamName());
+
+TEST_P(TypePropISTFTTestP, istft_shapes) {
+    const auto in_data = std::make_shared<Parameter>(element::f32, data_shape);
+    const auto window = std::make_shared<Parameter>(element::f32, window_shape);
+    const auto frame_size = Constant::create(element::i32, {}, {frame_size_val});
+    const auto frame_step = Constant::create(element::i32, {}, {step_size_val});
+    const auto length = std::make_shared<Parameter>(element::i32, Shape{});
+
+    const auto op = make_op(in_data, window, frame_size, frame_step, length, center, false);
+    EXPECT_EQ(op->get_output_size(), 1);
+    EXPECT_EQ(op->get_output_element_type(0), element::f32);
+    EXPECT_EQ(op->get_output_partial_shape(0), expected_signal_shape);
+}
+
+}  // namespace test
+}  // namespace ov