From 311e3d1355ede10536f46c01c165aebfb1257ce8 Mon Sep 17 00:00:00 2001 From: mitruska Date: Tue, 21 Jan 2025 04:36:37 +0100 Subject: [PATCH 1/7] Init ISTFT op --- src/core/include/openvino/op/istft.hpp | 51 ++++++++ src/core/include/openvino/op/ops.hpp | 1 + .../include/openvino/opsets/opset16_tbl.hpp | 1 + .../include/istft_shape_inference.hpp | 112 ++++++++++++++++ src/core/src/op/istft.cpp | 123 ++++++++++++++++++ src/core/tests/opset.cpp | 2 +- src/core/tests/type_prop/istft.cpp | 76 +++++++++++ 7 files changed, 365 insertions(+), 1 deletion(-) create mode 100644 src/core/include/openvino/op/istft.hpp create mode 100644 src/core/shape_inference/include/istft_shape_inference.hpp create mode 100644 src/core/src/op/istft.cpp create mode 100644 src/core/tests/type_prop/istft.cpp diff --git a/src/core/include/openvino/op/istft.hpp b/src/core/include/openvino/op/istft.hpp new file mode 100644 index 00000000000000..611ee10316086b --- /dev/null +++ b/src/core/include/openvino/op/istft.hpp @@ -0,0 +1,51 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/op/op.hpp" + +namespace ov { +namespace op { +namespace v16 { +/// \brief An operation ISTFT that computes the Inverse Short Time Fourier Transform. +/// \ingroup ov_ops_cpp_api +class OPENVINO_API ISTFT : public Op { +public: + OPENVINO_OP("ISTFT", "opset16"); + ISTFT() = default; + + /// \brief Constructs an ISTFT operation. + /// + /// \param data Input data + /// \param window Window values applied in STFT + /// \param frame_size Scalar value representing the size of Fourier Transform + /// \param frame_step The distance (number of samples) between successive window frames + /// \param length The length of the original signal + /// \param center Flag signaling if the signal input has been padded before STFT + /// \param normalized Flag signaling if the STFT result has been normalized. + ISTFT(const Output& data, + const Output& window, + const Output& frame_size, + const Output& frame_step, + const Output& length, + const bool center, + const bool normalized); + + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + + bool get_center() const; + + bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override; + bool has_evaluate() const override; + +private: + bool m_center = false; + bool m_normalized = false; +}; +} // namespace v16 +} // namespace op +} // namespace ov diff --git a/src/core/include/openvino/op/ops.hpp b/src/core/include/openvino/op/ops.hpp index 73510a524ef3e1..9cc7c683104a54 100644 --- a/src/core/include/openvino/op/ops.hpp +++ b/src/core/include/openvino/op/ops.hpp @@ -99,6 +99,7 @@ #include "openvino/op/is_finite.hpp" #include "openvino/op/is_inf.hpp" #include "openvino/op/is_nan.hpp" +#include "openvino/op/istft.hpp" #include "openvino/op/less.hpp" #include "openvino/op/less_eq.hpp" #include "openvino/op/log.hpp" diff --git a/src/core/include/openvino/opsets/opset16_tbl.hpp b/src/core/include/openvino/opsets/opset16_tbl.hpp index 4038aa17b72750..63afe3af430acb 100644 --- a/src/core/include/openvino/opsets/opset16_tbl.hpp +++ b/src/core/include/openvino/opsets/opset16_tbl.hpp @@ -15,3 +15,4 @@ _OPENVINO_OP_REG(ShapeOf, ov::op::v3) // New operations added in opset16 _OPENVINO_OP_REG(Identity, ov::op::v16) +_OPENVINO_OP_REG(ISTFT, ov::op::v16) diff --git a/src/core/shape_inference/include/istft_shape_inference.hpp b/src/core/shape_inference/include/istft_shape_inference.hpp new file mode 100644 index 00000000000000..d62ea60116b292 --- /dev/null +++ b/src/core/shape_inference/include/istft_shape_inference.hpp @@ -0,0 +1,112 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "dimension_util.hpp" +#include "openvino/op/istft.hpp" +#include "utils.hpp" + +namespace ov { +namespace op { +namespace v16 { +template > +std::vector shape_infer(const ISTFT* op, + const std::vector& input_shapes, + const ITensorAccessor& ta = make_tensor_accessor()) { + using TDim = typename TRShape::value_type; + using TDimVal = typename TDim::value_type; + + NODE_VALIDATION_CHECK(op, input_shapes.size() == 5); + + const auto& data_shape = input_shapes[0]; + const auto& window_shape = input_shapes[1]; + const auto& frame_size_shape = input_shapes[2]; + const auto& frame_step_shape = input_shapes[3]; + const auto& length_shape = input_shapes[4]; + + const auto data_shape_rank = data_shape.rank(); + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + data_shape_rank.compatible(3) || data_shape_rank.compatible(4), + "The shape of data must be 3D [signal_size] or 4D [batch, signal_size]."); + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + window_shape.rank().compatible(1), + "The shape of window must be 1D [window_size]."); + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + frame_size_shape.rank().compatible(0), + "The shape of frame_size must be a scalar."); + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + frame_step_shape.rank().compatible(0), + "The shape of frame_step must be a scalar."); + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + length_shape.rank().compatible(0), + "The shape of length input must be a scalar."); + + if (data_shape_rank.is_dynamic()) { + return {data_shape}; + } + + const auto frame_size = get_input_const_data_as(op, 2, ta); + const auto frame_step = get_input_const_data_as(op, 3, ta); + + const auto is_data_3D = data_shape.size() == 3; + if (!frame_size || !frame_step) { + if (is_data_3D) { + return {TRShape{TDim(ov::util::dim::inf_bound)}}; + } else { + return {TRShape{data_shape[0], TDim(ov::util::dim::inf_bound)}}; + } + } + + const auto& frame_size_val = (*frame_size)[0]; + const auto& frame_step_val = (*frame_step)[0]; + + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + 0 < frame_size_val, + "Provided frame size is ", + frame_size_val, + " but must be greater than zero."); + + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + 0 < frame_step_val, + "Provided frame step is ", + frame_step_val, + " but must be greater than zero."); + + const bool is_win_shape_correct = + window_shape.is_dynamic() || (TDimVal{0} < window_shape[0].get_length() && + window_shape[0].get_length() <= static_cast(frame_size_val)); + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + is_win_shape_correct, + "Window input dimension must be in range [1, ", + frame_size_val, + "]."); + + int64_t frames_axis = 1 + (is_data_3D ? 0 : 1); + const TDim& num_frames_dim = data_shape[frames_axis]; + TDim signal_length = (num_frames_dim - 1) * frame_step_val; + if (!op->get_center()) { + signal_length += frame_size_val; + } + + std::vector output_shapes; + output_shapes.emplace_back(TRShape{std::move(signal_length)}); + + if (!is_data_3D) { + const auto& batch_dim = data_shape[0]; + output_shapes[0].insert(output_shapes[0].begin(), batch_dim); + } + return output_shapes; +} +} // namespace v16 +} // namespace op +} // namespace ov diff --git a/src/core/src/op/istft.cpp b/src/core/src/op/istft.cpp new file mode 100644 index 00000000000000..f851b4bbb00ab2 --- /dev/null +++ b/src/core/src/op/istft.cpp @@ -0,0 +1,123 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/istft.hpp" + +#include + +#include "istft_shape_inference.hpp" +#include "itt.hpp" +// #include "openvino/reference/istft.hpp" + +namespace ov { +namespace op { +namespace v16 { +namespace { +void check_int_input_at(const Node* op, size_t input_idx) { + const auto& in_type = op->get_input_element_type(input_idx); + const auto has_valid_type = in_type.is_dynamic() || in_type == element::i32 || in_type == element::i64; + NODE_VALIDATION_CHECK(op, has_valid_type, "Expected i32 or i64 type of the input at port: ", input_idx); +} +} // namespace +ISTFT::ISTFT(const Output& data, + const Output& window, + const Output& frame_size, + const Output& frame_step, + const Output& length, + const bool center, + const bool normalized) + : Op({data, window, frame_size, frame_step, length}), + m_center(center), + m_normalized(normalized) { + constructor_validate_and_infer_types(); +} + +std::shared_ptr ISTFT::clone_with_new_inputs(const OutputVector& new_args) const { + OV_OP_SCOPE(v16_ISTFT_clone_with_new_inputs); + check_new_args_count(this, new_args); + + return std::make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + new_args.at(3), + new_args.at(4), + m_center, + m_normalized); +} + +bool ISTFT::visit_attributes(AttributeVisitor& visitor) { + OV_OP_SCOPE(v16_ISTFT_visit_attributes); + visitor.on_attribute("center", m_center); + visitor.on_attribute("normalized", m_normalized); + return true; +} + +void ISTFT::validate_and_infer_types() { + OV_OP_SCOPE(v16_ISTFT_validate_and_infer_types); + NODE_VALIDATION_CHECK(this, get_input_size() == 5, "Expected 5 inputs to be provided."); + + auto signal_type = get_input_element_type(0); + const auto& window_type = get_input_element_type(1); + + const auto has_valid_signal_type = signal_type.is_dynamic() || signal_type.is_real(); + NODE_VALIDATION_CHECK(this, has_valid_signal_type, "Expected floating point type of the 'signal' input."); + + const auto has_valid_window_type = + window_type.is_dynamic() || + (window_type.is_real() && element::Type::merge(signal_type, window_type, signal_type)); + NODE_VALIDATION_CHECK(this, + has_valid_window_type, + "Expected floating point type of the 'window' input, matching the type of `signal` input."); + + check_int_input_at(this, 2); + check_int_input_at(this, 3); + check_int_input_at(this, 4); + + const auto input_shapes = ov::util::get_node_input_partial_shapes(*this); + const auto output_shapes = shape_infer(this, input_shapes); + + set_output_type(0, signal_type, output_shapes[0]); +} + +bool ISTFT::evaluate(TensorVector& outputs, const TensorVector& inputs) const { + OV_OP_SCOPE(v16_ISTFT_evaluate); + OPENVINO_ASSERT(outputs.size() == 1); + OPENVINO_ASSERT(inputs.size() == 5); + + const auto input_shapes = ov::util::get_tensors_partial_shapes(inputs); + const auto output_shape = shape_infer(this, input_shapes, make_tensor_accessor(inputs)).front().to_shape(); + + outputs[0].set_shape(output_shape); + + const auto frame_size = ov::get_tensor_data_as(inputs[2]).front(); + const auto frame_step = ov::get_tensor_data_as(inputs[3]).front(); + const auto length = ov::get_tensor_data_as(inputs[4]).front(); + + // ov::reference::istft(inputs[0].data(), + // inputs[1].data(), + // outputs[0].data(), + // inputs[0].get_shape(), + // inputs[1].get_shape(), + // frame_size, + // frame_step, + // length, + // m_normalized, + // m_center); + return true; +} + +bool ISTFT::has_evaluate() const { + OV_OP_SCOPE(v16_ISTFT_has_evaluate); + const auto& input_0_et = get_input_element_type(0); + return input_0_et == element::f32; +} + +bool ISTFT::get_center() const { + OV_OP_SCOPE(v16_ISTFT_get_center); + return m_center; +} + +} // namespace v16 +} // namespace op +} // namespace ov diff --git a/src/core/tests/opset.cpp b/src/core/tests/opset.cpp index 81f6e80c28189f..47c8770a9b2e32 100644 --- a/src/core/tests/opset.cpp +++ b/src/core/tests/opset.cpp @@ -77,7 +77,7 @@ INSTANTIATE_TEST_SUITE_P(opset, OpsetTestParams{ov::get_opset13, 186}, OpsetTestParams{ov::get_opset14, 188}, OpsetTestParams{ov::get_opset15, 199}, - OpsetTestParams{ov::get_opset16, 4}), + OpsetTestParams{ov::get_opset16, 5}), OpsetTestNameGenerator{}); class MyOpOld : public ov::op::Op { diff --git a/src/core/tests/type_prop/istft.cpp b/src/core/tests/type_prop/istft.cpp new file mode 100644 index 00000000000000..afb83508a9d40e --- /dev/null +++ b/src/core/tests/type_prop/istft.cpp @@ -0,0 +1,76 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/istft.hpp" + +#include + +#include "common_test_utils/test_assertions.hpp" +#include "common_test_utils/type_prop.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/subtract.hpp" + +namespace ov { +namespace test { + +using op::v0::Constant; +using op::v0::Parameter; +using testing::HasSubstr; + +class TypePropISTFTTest : public TypePropOpTest { +public: + bool center = true; + bool normalized = false; +}; + +TEST_F(TypePropISTFTTest, all_inputs_as_params_static_shapes) { + const auto in_data = std::make_shared(element::f32, PartialShape{2, 1, 4, 48}); + const auto window = std::make_shared(element::f32, PartialShape{7}); + const auto frame_size = std::make_shared(element::i64, PartialShape{}); + const auto frame_step = std::make_shared(element::i64, PartialShape{}); + const auto length = std::make_shared(element::i64, PartialShape{}); + + const auto op = make_op(in_data, window, frame_size, frame_step, length, center, normalized); + EXPECT_EQ(op->get_output_size(), 1); + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, -1})); +} + +using STFTTestParam = std::tuple; +class TypePropISTFTTestP : public TypePropISTFTTest, public testing::WithParamInterface { +protected: + void SetUp() override { + std::tie(signal_shape, window_shape, frame_size_val, step_size_val, center, expected_shape) = GetParam(); + } + PartialShape signal_shape, window_shape, expected_shape; + int32_t frame_size_val, step_size_val; +}; + +INSTANTIATE_TEST_SUITE_P( + type_prop_stft_shape, + TypePropISTFTTestP, + testing::Values( + std::make_tuple(PartialShape{16}, PartialShape{16}, 16, 16, false, PartialShape{9, 1, 2}), // frames at 1 + std::make_tuple(PartialShape{48}, PartialShape{16}, 16, 16, false, PartialShape{9, 3, 2}), + std::make_tuple(PartialShape{56}, PartialShape{7}, 11, 3, false, PartialShape{6, 16, 2}), + + std::make_tuple(PartialShape::dynamic(), PartialShape::dynamic(), 11, 3, true, PartialShape::dynamic())), + testing::PrintToStringParamName()); + +TEST_P(TypePropISTFTTestP, istft_shapes) { + const auto in_data = std::make_shared(element::f32, expected_shape); + const auto window = std::make_shared(element::f32, window_shape); + const auto frame_size = Constant::create(element::i32, {}, {frame_size_val}); + const auto frame_step = Constant::create(element::i32, {}, {step_size_val}); + const auto length = std::make_shared(element::i32, Shape{}); + + const auto op = make_op(in_data, window, frame_size, frame_step, length, center, false); + EXPECT_EQ(op->get_output_size(), 1); + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), signal_shape); +} + +} // namespace test +} // namespace ov From 3fcb1ee3fafd7fe3a64da19fd946fb499b5c2502 Mon Sep 17 00:00:00 2001 From: mitruska Date: Fri, 24 Jan 2025 08:32:31 +0100 Subject: [PATCH 2/7] Init ISTFT reference and ref tests --- .../include/openvino/reference/istft.hpp | 22 + src/core/reference/src/op/istft.cpp | 162 +++++++ src/core/src/op/istft.cpp | 22 +- .../tests/functional/op_reference/istft.cpp | 396 ++++++++++++++++++ 4 files changed, 591 insertions(+), 11 deletions(-) create mode 100644 src/core/reference/include/openvino/reference/istft.hpp create mode 100644 src/core/reference/src/op/istft.cpp create mode 100644 src/plugins/template/tests/functional/op_reference/istft.cpp diff --git a/src/core/reference/include/openvino/reference/istft.hpp b/src/core/reference/include/openvino/reference/istft.hpp new file mode 100644 index 00000000000000..725548d2e70452 --- /dev/null +++ b/src/core/reference/include/openvino/reference/istft.hpp @@ -0,0 +1,22 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/shape.hpp" + +namespace ov { +namespace reference { +void istft(const float* in_data, + const float* window, + float* final_result, + const Shape& signal_shape, + const Shape& window_shape, + const int64_t frame_size, + const int64_t frame_step, + const int64_t length, + const bool center, + const bool normalized); +} // namespace reference +} // namespace ov diff --git a/src/core/reference/src/op/istft.cpp b/src/core/reference/src/op/istft.cpp new file mode 100644 index 00000000000000..c7eb144f8535cc --- /dev/null +++ b/src/core/reference/src/op/istft.cpp @@ -0,0 +1,162 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/reference/istft.hpp" + +#include +#include +#include + +#include "openvino/core/shape.hpp" +#include "openvino/reference/add.hpp" +#include "openvino/reference/divide.hpp" +#include "openvino/reference/fft.hpp" +#include "openvino/reference/irdft.hpp" +#include "openvino/reference/multiply.hpp" +#include "openvino/reference/transpose.hpp" + +namespace ov { +namespace reference { +void istft(const float* in_data, + const float* window, + float* final_result, + const Shape& data_shape, + const Shape& window_shape, + const int64_t frame_size, + const int64_t frame_step, + const int64_t length, + const bool center, + const bool normalized) { + const auto is_data_3D = data_shape.size() == 3; + const size_t frames_axis = 1 + (is_data_3D ? 0 : 1); + const size_t batch_size = is_data_3D ? 1 : data_shape[0]; + + const auto num_frames = data_shape[frames_axis]; + + const auto signal_length = (num_frames - 1) * frame_step + frame_size; + const auto final_signal_length = length > 0 ? length : (center ? (signal_length - frame_size) : signal_length); + + // auto signal_length = (num_frames - 1) * frame_step; + // if (!center) { + // signal_length += frame_size; + // } + // const auto final_signal_length = length > 0 ? length : signal_length; + + std::vector mid_result(batch_size * signal_length, 0); + float* result = mid_result.data(); + + const auto frame_size_dim = static_cast(frame_size); + const auto frame_size_dim_shape = Shape{frame_size_dim}; + const auto frame_size_dim_shape_out = Shape{frame_size_dim, 2}; + const auto fft_out_shape = Shape{static_cast((frame_size_dim / 2) + 1), 2}; + + const auto window_length = window_shape[0] < frame_size_dim ? window_shape[0] : frame_size_dim; + std::vector pad_window(frame_size, 0); + std::copy(window, window + window_shape[0], pad_window.begin() + (frame_size_dim - window_length) / 2); + + const bool transpose_frames = true; + std::vector data_t(in_data, in_data + shape_size(data_shape)); + if (transpose_frames) { + const auto stft_transp_out_shape = Shape{batch_size, num_frames, fft_out_shape[0], fft_out_shape[1]}; + transpose(reinterpret_cast(in_data), + reinterpret_cast(data_t.data()), + Shape{batch_size, fft_out_shape[0], num_frames, fft_out_shape[1]}, + sizeof(float), + {0, 2, 1, 3}, + stft_transp_out_shape); + } + + const auto fft_out_shape_size = shape_size(fft_out_shape); + + std::vector window_sum(batch_size * signal_length); + + for (size_t batch = 0, batch_in_start = 0, batch_out_start = 0; batch < batch_size; ++batch) { + for (size_t frame_idx = 0; frame_idx < num_frames; ++frame_idx) { + const auto in_frame_start = batch_in_start + frame_idx * fft_out_shape_size; + const auto in_frame_end = in_frame_start + fft_out_shape_size; + + const auto out_frame_start = batch_out_start + frame_idx * frame_step; + const auto out_frame_end = out_frame_start + frame_size; + + std::vector frame_data(data_t.data() + in_frame_start, data_t.data() + in_frame_end); + std::vector frame_signal(frame_size); + + reference::irdft(frame_data, + fft_out_shape, + {0}, + frame_signal.data(), + frame_size_dim_shape_out, + frame_size_dim_shape, + frame_size); + + reference::add(result + out_frame_start, + frame_signal.data(), + result + out_frame_start, + frame_size_dim_shape, + frame_size_dim_shape, + op::AutoBroadcastType::NUMPY); + + std::transform(window_sum.begin() + out_frame_start, + window_sum.begin() + out_frame_start + frame_size, + pad_window.begin(), + window_sum.begin() + out_frame_start, + std::plus()); + + // std::transform(result + out_frame_start, + // result + out_frame_start + frame_size, + // pad_window.begin(), + // result + out_frame_start, + // [](float a, float b) { + // if (b != 0.f) + // return a / b; + // else + // return 0.f; + // }); + + // reference::multiply(result + out_frame_start, + // pad_window.data(), + // result + out_frame_start, + // frame_size_dim_shape, + // frame_size_dim_shape, + // op::AutoBroadcastType::NUMPY); + } + + // std::transform(result + batch_out_start, + // result + batch_out_start + signal_length, + // window_sum.begin(), + // result + batch_out_start, + // [](float a, float b) { + // if (b != 0.f) + // return a / (b*b); + // else + // return 0.f; + // }); + + std::transform(result + batch_out_start, + result + batch_out_start + signal_length, + window_sum.begin(), + result + batch_out_start, + [](float a, float b) { + if (b != 0.f) + return a / b; + else + return 0.f; + }); + + if (center) { + std::copy(result + batch_out_start + (frame_size / 2), + result + batch_out_start + (frame_size / 2) + final_signal_length, + final_result + (batch * final_signal_length)); + } else { + std::copy(result + batch_out_start, + result + batch_out_start + final_signal_length, + final_result + batch_out_start); + } + + batch_in_start += (num_frames * fft_out_shape_size); + batch_out_start += signal_length; + } +} +} // namespace reference +} // namespace ov diff --git a/src/core/src/op/istft.cpp b/src/core/src/op/istft.cpp index f851b4bbb00ab2..d6bf60cc3330ce 100644 --- a/src/core/src/op/istft.cpp +++ b/src/core/src/op/istft.cpp @@ -8,7 +8,7 @@ #include "istft_shape_inference.hpp" #include "itt.hpp" -// #include "openvino/reference/istft.hpp" +#include "openvino/reference/istft.hpp" namespace ov { namespace op { @@ -94,16 +94,16 @@ bool ISTFT::evaluate(TensorVector& outputs, const TensorVector& inputs) const { const auto frame_step = ov::get_tensor_data_as(inputs[3]).front(); const auto length = ov::get_tensor_data_as(inputs[4]).front(); - // ov::reference::istft(inputs[0].data(), - // inputs[1].data(), - // outputs[0].data(), - // inputs[0].get_shape(), - // inputs[1].get_shape(), - // frame_size, - // frame_step, - // length, - // m_normalized, - // m_center); + ov::reference::istft(inputs[0].data(), + inputs[1].data(), + outputs[0].data(), + inputs[0].get_shape(), + inputs[1].get_shape(), + frame_size, + frame_step, + length, + m_center, + m_normalized); return true; } diff --git a/src/plugins/template/tests/functional/op_reference/istft.cpp b/src/plugins/template/tests/functional/op_reference/istft.cpp new file mode 100644 index 00000000000000..35f198c04032cd --- /dev/null +++ b/src/plugins/template/tests/functional/op_reference/istft.cpp @@ -0,0 +1,396 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/istft.hpp" + +#include "base_reference_test.hpp" +#include "gtest/gtest.h" +#include "openvino/op/parameter.hpp" + +namespace { +using ov::Shape; +struct ISTFTParams { + ISTFTParams(const reference_tests::Tensor& signal, + const reference_tests::Tensor& window, + const reference_tests::Tensor& frame_size, + const reference_tests::Tensor& frame_step, + const reference_tests::Tensor& length, + bool center, + bool normalized, + const reference_tests::Tensor& expected_tensor, + std::string name) + : signal{signal}, + window{window}, + frame_size{frame_size}, + frame_step{frame_step}, + length{length}, + center{center}, + normalized{normalized}, + expected_tensor(expected_tensor), + test_case_name{std::move(name)} {} + + reference_tests::Tensor signal; + reference_tests::Tensor window; + reference_tests::Tensor frame_size; + reference_tests::Tensor frame_step; + reference_tests::Tensor length; + + bool center; + bool normalized; + + reference_tests::Tensor expected_tensor; + std::string test_case_name; +}; + +class ReferenceISTFT : public testing::TestWithParam, public reference_tests::CommonReferenceTest { +public: + void SetUp() override { + const auto& params = GetParam(); + function = CreateFunction(params); + inputData = {params.signal.data, + params.window.data, + params.frame_size.data, + params.frame_step.data, + params.length.data}; + refOutData = {params.expected_tensor.data}; + abs_threshold = 1e-5f; + } + + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + std::ostringstream name; + name << obj.param.test_case_name; + name << "_signal_input_type_"; + name << obj.param.signal.type; + name << "_signal_shape_"; + name << obj.param.signal.shape; + name << "_window_input_type_"; + name << obj.param.window.type; + name << "_window_shape_"; + name << obj.param.window.shape; + name << "_frame_size_input_type_"; + name << obj.param.frame_size.type; + name << "_frame_size_shape_"; + name << obj.param.frame_size.shape; + name << "_frame_step_input_type_"; + name << obj.param.frame_step.type; + name << "_frame_step_shape_"; + name << obj.param.frame_step.shape; + name << "_lentgh_input_type_"; + name << obj.param.frame_step.type; + name << "_lebgth_shape_"; + name << obj.param.frame_step.shape; + name << "_center_"; + name << obj.param.center; + name << "_normalized_"; + name << obj.param.normalized; + return name.str(); + } + +private: + static std::shared_ptr CreateFunction(const ISTFTParams& params) { + const auto in_signal = std::make_shared(params.signal.type, params.signal.shape); + const auto in_window = std::make_shared(params.window.type, params.window.shape); + const auto in_frame_size = + std::make_shared(params.frame_size.type, params.frame_size.shape); + const auto in_frame_step = + std::make_shared(params.frame_step.type, params.frame_step.shape); + const auto in_length = std::make_shared(params.length.type, params.length.shape); + + const auto ISTFT = std::make_shared(in_signal, + in_window, + in_frame_size, + in_frame_step, + in_length, + params.center, + params.normalized); + return std::make_shared( + ISTFT->outputs(), + ov::ParameterVector{in_signal, in_window, in_frame_size, in_frame_step, in_length}); + } +}; + +template +std::vector generateISTFTParams() { + using VT = typename ov::element_type_traits::value_type; + using INT_T = typename ov::element_type_traits::value_type; + + const ov::Shape signal_48_shape{48}; + const ov::Shape signal_1_48_shape{1, 48}; + const ov::Shape signal_2_48_shape{2, 48}; + const ov::Shape signal_256_shape{1, 256}; + + const ov::Shape signal_16_shape{16}; + + reference_tests::Tensor signal_16(signal_16_shape, + ET, + std::vector{5.8779e-01, + -7.7051e-01, + -6.8455e-01, + 6.8455e-01, + 7.7051e-01, + -5.8779e-01, + -8.4433e-01, + 4.8175e-01, + 9.0483e-01, + -3.6813e-01, + -9.5106e-01, + 2.4869e-01, + 9.8229e-01, + -1.2533e-01, + -9.9803e-01, + 6.8545e-07}); + + reference_tests::Tensor signal_48( + signal_48_shape, + ET, + std::vector{-0.9511, -0.1861, 0.7722, 0.9283, 0.1200, -0.8129, -0.9014, -0.0534, 0.8500, 0.8704, + -0.0134, -0.8833, -0.8356, 0.0801, 0.9126, 0.7971, -0.1465, -0.9379, -0.7550, 0.2123, + 0.9590, 0.7095, -0.2771, -0.9758, -0.6608, 0.3406, 0.9882, 0.6092, -0.4027, -0.9962, + -0.5549, 0.4629, 0.9998, 0.4981, -0.5211, -0.9989, -0.4390, 0.5769, 0.9936, 0.3780, + -0.6302, -0.9838, -0.3154, 0.6806, 0.9696, 0.2513, -0.7281, -0.9511}); + + reference_tests::Tensor signal_1_48( + signal_1_48_shape, + ET, + std::vector{-0.9511, -0.1861, 0.7722, 0.9283, 0.1200, -0.8129, -0.9014, -0.0534, 0.8500, 0.8704, + -0.0134, -0.8833, -0.8356, 0.0801, 0.9126, 0.7971, -0.1465, -0.9379, -0.7550, 0.2123, + 0.9590, 0.7095, -0.2771, -0.9758, -0.6608, 0.3406, 0.9882, 0.6092, -0.4027, -0.9962, + -0.5549, 0.4629, 0.9998, 0.4981, -0.5211, -0.9989, -0.4390, 0.5769, 0.9936, 0.3780, + -0.6302, -0.9838, -0.3154, 0.6806, 0.9696, 0.2513, -0.7281, -0.9511}); + + reference_tests::Tensor signal_2_48( + signal_2_48_shape, + ET, + std::vector{ + -0.9511, -0.1861, 0.7722, 0.9283, 0.1200, -0.8129, -0.9014, -0.0534, 0.8500, 0.8704, -0.0134, -0.8833, + -0.8356, 0.0801, 0.9126, 0.7971, -0.1465, -0.9379, -0.7550, 0.2123, 0.9590, 0.7095, -0.2771, -0.9758, + -0.6608, 0.3406, 0.9882, 0.6092, -0.4027, -0.9962, -0.5549, 0.4629, 0.9998, 0.4981, -0.5211, -0.9989, + -0.4390, 0.5769, 0.9936, 0.3780, -0.6302, -0.9838, -0.3154, 0.6806, 0.9696, 0.2513, -0.7281, -0.9511, + -0.9511, -0.1861, 0.7722, 0.9283, 0.1200, -0.8129, -0.9014, -0.0534, 0.8500, 0.8704, -0.0134, -0.8833, + -0.8356, 0.0801, 0.9126, 0.7971, -0.1465, -0.9379, -0.7550, 0.2123, 0.9590, 0.7095, -0.2771, -0.9758, + -0.6608, 0.3406, 0.9882, 0.6092, -0.4027, -0.9962, -0.5549, 0.4629, 0.9998, 0.4981, -0.5211, -0.9989, + -0.4390, 0.5769, 0.9936, 0.3780, -0.6302, -0.9838, -0.3154, 0.6806, 0.9696, 0.2513, -0.7281, -0.9511}); + + reference_tests::Tensor ones_window_16(Shape{16}, ET, std::vector(16, 1.f)); + reference_tests::Tensor two_window_16(Shape{16}, ET, std::vector(16, 2.f)); + + reference_tests::Tensor hann_window_5(Shape{5}, ET, std::vector{0., 0.5, 1., 0.5, 0.}); + reference_tests::Tensor hann_window_7(Shape{7}, ET, std::vector{0., 0.25, 0.75, 1., 0.75, 0.25, 0.}); + reference_tests::Tensor hann_window_8( + Shape{8}, + ET, + std::vector{0., 0.18826, 0.61126, 0.95048, 0.95048, 0.61126, 0.18826, 0.}); + reference_tests::Tensor hann_window_10( + Shape{10}, + ET, + std::vector{0., 0.11698, 0.41318, 0.75, 0.96985, 0.96985, 0.75, 0.41318, 0.11698, 0.}); + reference_tests::Tensor hann_window_16(Shape{16}, + ET, + std::vector{0., + 0.04323, + 0.16543, + 0.34549, + 0.55226, + 0.75, + 0.90451, + 0.98907, + 0.98907, + 0.90451, + 0.75, + 0.55226, + 0.34549, + 0.16543, + 0.04323, + 0.}); + + reference_tests::Tensor frame_size_8(Shape{}, IT, std::vector{8}); + reference_tests::Tensor frame_size_9(Shape{}, IT, std::vector{9}); + reference_tests::Tensor frame_size_11(Shape{}, IT, std::vector{11}); + reference_tests::Tensor frame_size_16(Shape{}, IT, std::vector{16}); + + reference_tests::Tensor frame_step_2(Shape{}, IT, std::vector{2}); + reference_tests::Tensor frame_step_3(Shape{}, IT, std::vector{3}); + reference_tests::Tensor frame_step_4(Shape{}, IT, std::vector{4}); + reference_tests::Tensor frame_step_8(Shape{}, IT, std::vector{8}); + reference_tests::Tensor frame_step_16(Shape{}, IT, std::vector{16}); + reference_tests::Tensor frame_step_100(Shape{}, IT, std::vector{100}); + + reference_tests::Tensor output_stft_9_1_2_transp_win_one(Shape{9, 1, 2}, + ET, + std::vector{-0.6693, + 0.0000, + -0.7103, + -0.0912, + -0.8803, + -0.2251, + -1.5651, + -0.5924, + 6.7234, + 3.2667, + 0.7715, + 0.4254, + 0.3599, + 0.1884, + 0.2358, + 0.0796, + 0.2042, + 0.0000}); + + reference_tests::Tensor output_stft_9_1_2_transp_win_two(Shape{9, 1, 2}, + ET, + std::vector{-1.3386, + 0.0000, + -1.4207, + -0.1823, + -1.7606, + -0.4502, + -3.1302, + -1.1848, + 13.4467, + 6.5335, + 1.5429, + 0.8508, + 0.7199, + 0.3768, + 0.4716, + 0.1591, + 0.4084, + 0.0000}); + + reference_tests::Tensor output_stft_9_3_2_transp_win_two( + Shape{9, 3, 2}, + ET, + std::vector{1.3873, 0.0000, -2.8503, 0.0000, -0.4391, 0.0000, 1.7637, -0.6945, -3.1429, + -0.9896, -0.7182, 1.0237, 4.2213, -2.5114, -5.0535, -3.5783, -2.5402, 3.7017, + -12.4337, 7.5925, 7.8944, 10.8180, 9.8076, -11.1912, -3.1735, 1.6741, 0.6953, + 2.3853, 2.9422, -2.4676, -2.1233, 0.8610, -0.1211, 1.2268, 2.1636, -1.2691, + -1.7631, 0.4790, -0.4011, 0.6825, 1.8965, -0.7060, -1.6151, 0.2192, -0.5161, + 0.3123, 1.7869, -0.3231, -1.5735, 0.0000, -0.5485, 0.0000, 1.7559, 0.0000}); + + reference_tests::Tensor output_stft_9_4_2_transp_win_two_center( + Shape{9, 4, 2}, + ET, + std::vector{ + -7.3526e-01, 0.0000e+00, 1.1330e+00, 0.0000e+00, 2.5475e+00, 0.0000e+00, -4.1695e+00, 0.0000e+00, + -3.3068e+00, 2.3842e-07, 1.0681e+00, 1.3042e+00, 2.9902e+00, -2.6441e-02, -2.2545e+00, -9.3384e-01, + -1.6860e+00, 0.0000e+00, 6.4448e-01, 4.7161e+00, 5.8809e+00, -9.5572e-02, -6.7604e+00, -6.7604e+00, + 1.4974e+01, 0.0000e+00, 3.5154e+00, -1.4258e+01, -1.3709e+01, 2.8896e-01, 4.2285e+00, 1.0209e+01, + 7.9468e-01, 0.0000e+00, 1.9192e+00, -3.1437e+00, -2.8170e+00, 6.3723e-02, 0.0000e+00, 4.5065e+00, + 1.6981e+00, 0.0000e+00, 1.7382e+00, -1.6169e+00, -1.5818e+00, 3.2759e-02, -4.7954e-01, 1.1577e+00, + 3.2156e-01, 0.0000e+00, 1.6761e+00, -8.9949e-01, -1.1581e+00, 1.8237e-02, -1.2894e+00, 1.2894e+00, + 1.0437e+00, 2.3842e-07, 1.6506e+00, -4.1167e-01, -9.8408e-01, 8.3490e-03, -7.1159e-01, 2.9475e-01, + 2.5795e-01, 0.0000e+00, 1.6434e+00, 0.0000e+00, -9.3507e-01, 0.0000e+00, -1.4628e+00, 0.0000e+00}); + + reference_tests::Tensor output_stft_2_9_4_2_transp_win_two_center( + Shape{2, 9, 4, 2}, + ET, + std::vector{ + -7.3526e-01, 0.0000e+00, 1.1330e+00, 0.0000e+00, 2.5475e+00, 0.0000e+00, -4.1695e+00, 0.0000e+00, + -3.3068e+00, 2.3842e-07, 1.0681e+00, 1.3042e+00, 2.9902e+00, -2.6441e-02, -2.2545e+00, -9.3384e-01, + -1.6860e+00, 0.0000e+00, 6.4448e-01, 4.7161e+00, 5.8809e+00, -9.5572e-02, -6.7604e+00, -6.7604e+00, + 1.4974e+01, 0.0000e+00, 3.5154e+00, -1.4258e+01, -1.3709e+01, 2.8896e-01, 4.2285e+00, 1.0209e+01, + 7.9468e-01, 0.0000e+00, 1.9192e+00, -3.1437e+00, -2.8170e+00, 6.3723e-02, 0.0000e+00, 4.5065e+00, + 1.6981e+00, 0.0000e+00, 1.7382e+00, -1.6169e+00, -1.5818e+00, 3.2759e-02, -4.7954e-01, 1.1577e+00, + 3.2156e-01, 0.0000e+00, 1.6761e+00, -8.9949e-01, -1.1581e+00, 1.8237e-02, -1.2894e+00, 1.2894e+00, + 1.0437e+00, 2.3842e-07, 1.6506e+00, -4.1167e-01, -9.8408e-01, 8.3490e-03, -7.1159e-01, 2.9475e-01, + 2.5795e-01, 0.0000e+00, 1.6434e+00, 0.0000e+00, -9.3507e-01, 0.0000e+00, -1.4628e+00, 0.0000e+00, + -7.3526e-01, 0.0000e+00, 1.1330e+00, 0.0000e+00, 2.5475e+00, 0.0000e+00, -4.1695e+00, 0.0000e+00, + -3.3068e+00, 2.3842e-07, 1.0681e+00, 1.3042e+00, 2.9902e+00, -2.6441e-02, -2.2545e+00, -9.3384e-01, + -1.6860e+00, 0.0000e+00, 6.4448e-01, 4.7161e+00, 5.8809e+00, -9.5572e-02, -6.7604e+00, -6.7604e+00, + 1.4974e+01, 0.0000e+00, 3.5154e+00, -1.4258e+01, -1.3709e+01, 2.8896e-01, 4.2285e+00, 1.0209e+01, + 7.9468e-01, 0.0000e+00, 1.9192e+00, -3.1437e+00, -2.8170e+00, 6.3723e-02, 0.0000e+00, 4.5065e+00, + 1.6981e+00, 0.0000e+00, 1.7382e+00, -1.6169e+00, -1.5818e+00, 3.2759e-02, -4.7954e-01, 1.1577e+00, + 3.2156e-01, 0.0000e+00, 1.6761e+00, -8.9949e-01, -1.1581e+00, 1.8237e-02, -1.2894e+00, 1.2894e+00, + 1.0437e+00, 2.3842e-07, 1.6506e+00, -4.1167e-01, -9.8408e-01, 8.3490e-03, -7.1159e-01, 2.9475e-01, + 2.5795e-01, 0.0000e+00, 1.6434e+00, 0.0000e+00, -9.3507e-01, 0.0000e+00, -1.4628e+00, 0.0000e+00}); + + reference_tests::Tensor auto_length(Shape{}, IT, std::vector{-1}); + reference_tests::Tensor length_16(Shape{}, IT, std::vector{16}); + reference_tests::Tensor length_48(Shape{}, IT, std::vector{48}); + + std::vector params; + params.emplace_back(output_stft_9_1_2_transp_win_one, + ones_window_16, + frame_size_16, + frame_step_16, + auto_length, + false, + false, + signal_16, + "basic_1D_transp_ones_win_step_16"); + params.emplace_back(output_stft_9_1_2_transp_win_one, + ones_window_16, + frame_size_16, + frame_step_4, + auto_length, + false, + false, + signal_16, + "basic_1D_transp_ones_win_step_4"); + params.emplace_back(output_stft_9_1_2_transp_win_two, + two_window_16, + frame_size_16, + frame_step_16, + auto_length, + false, + false, + signal_16, + "basic_1D_transp_two_win_step_16"); + params.emplace_back(output_stft_9_1_2_transp_win_two, + two_window_16, + frame_size_16, + frame_step_4, + auto_length, + false, + false, + signal_16, + "basic_1D_transp_two_win_step_4"); + params.emplace_back(output_stft_9_3_2_transp_win_two, + two_window_16, + frame_size_16, + frame_step_16, + auto_length, + false, + false, + signal_48, + "basic_1D_transp_two_win_step_16"); + params.emplace_back(output_stft_9_4_2_transp_win_two_center, + two_window_16, + frame_size_16, + frame_step_16, + auto_length, + true, + false, + signal_48, + "basic_1D_transp_two_win_step_16_center"); + params.emplace_back(output_stft_2_9_4_2_transp_win_two_center, + two_window_16, + frame_size_16, + frame_step_16, + auto_length, + true, + false, + signal_2_48, + "basic_2D_transp_two_win_step_16_center"); + + return params; +} + +std::vector generateISTFTParams() { + std::vector> combo_params{generateISTFTParams()}; + std::vector test_params; + for (auto& params : combo_params) + std::move(params.begin(), params.end(), std::back_inserter(test_params)); + return test_params; +} +} // namespace + +TEST_P(ReferenceISTFT, CompareWithRefs) { + Exec(); +} + +INSTANTIATE_TEST_SUITE_P(smoke, + ReferenceISTFT, + ::testing::ValuesIn(generateISTFTParams()), + ReferenceISTFT::getTestCaseName); From 92526c6f095ffb22b530dc584d998f80fe2d6362 Mon Sep 17 00:00:00 2001 From: mitruska Date: Fri, 24 Jan 2025 08:34:53 +0100 Subject: [PATCH 3/7] Code cleanup --- src/core/reference/src/op/istft.cpp | 37 +---------------------------- 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/src/core/reference/src/op/istft.cpp b/src/core/reference/src/op/istft.cpp index c7eb144f8535cc..83983c606d9d17 100644 --- a/src/core/reference/src/op/istft.cpp +++ b/src/core/reference/src/op/istft.cpp @@ -37,12 +37,6 @@ void istft(const float* in_data, const auto signal_length = (num_frames - 1) * frame_step + frame_size; const auto final_signal_length = length > 0 ? length : (center ? (signal_length - frame_size) : signal_length); - // auto signal_length = (num_frames - 1) * frame_step; - // if (!center) { - // signal_length += frame_size; - // } - // const auto final_signal_length = length > 0 ? length : signal_length; - std::vector mid_result(batch_size * signal_length, 0); float* result = mid_result.data(); @@ -98,41 +92,12 @@ void istft(const float* in_data, op::AutoBroadcastType::NUMPY); std::transform(window_sum.begin() + out_frame_start, - window_sum.begin() + out_frame_start + frame_size, + window_sum.begin() + out_frame_end, pad_window.begin(), window_sum.begin() + out_frame_start, std::plus()); - - // std::transform(result + out_frame_start, - // result + out_frame_start + frame_size, - // pad_window.begin(), - // result + out_frame_start, - // [](float a, float b) { - // if (b != 0.f) - // return a / b; - // else - // return 0.f; - // }); - - // reference::multiply(result + out_frame_start, - // pad_window.data(), - // result + out_frame_start, - // frame_size_dim_shape, - // frame_size_dim_shape, - // op::AutoBroadcastType::NUMPY); } - // std::transform(result + batch_out_start, - // result + batch_out_start + signal_length, - // window_sum.begin(), - // result + batch_out_start, - // [](float a, float b) { - // if (b != 0.f) - // return a / (b*b); - // else - // return 0.f; - // }); - std::transform(result + batch_out_start, result + batch_out_start + signal_length, window_sum.begin(), From b57762b3e17f8618004adc979621b87d20d904f9 Mon Sep 17 00:00:00 2001 From: mitruska Date: Fri, 24 Jan 2025 14:43:39 +0100 Subject: [PATCH 4/7] Add normalization --- src/core/reference/src/op/istft.cpp | 11 ++++- .../tests/functional/op_reference/istft.cpp | 47 +++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/src/core/reference/src/op/istft.cpp b/src/core/reference/src/op/istft.cpp index 83983c606d9d17..4764aa2c0591dd 100644 --- a/src/core/reference/src/op/istft.cpp +++ b/src/core/reference/src/op/istft.cpp @@ -32,6 +32,7 @@ void istft(const float* in_data, const size_t frames_axis = 1 + (is_data_3D ? 0 : 1); const size_t batch_size = is_data_3D ? 1 : data_shape[0]; + const auto sqrt_frame_size = std::sqrt(frame_size); const auto num_frames = data_shape[frames_axis]; const auto signal_length = (num_frames - 1) * frame_step + frame_size; @@ -62,7 +63,6 @@ void istft(const float* in_data, } const auto fft_out_shape_size = shape_size(fft_out_shape); - std::vector window_sum(batch_size * signal_length); for (size_t batch = 0, batch_in_start = 0, batch_out_start = 0; batch < batch_size; ++batch) { @@ -98,6 +98,15 @@ void istft(const float* in_data, std::plus()); } + if (normalized) { + std::transform(result + batch_out_start, + result + batch_out_start + signal_length, + result + batch_out_start, + [sqrt_frame_size](float a) { + return a * sqrt_frame_size; + }); + } + std::transform(result + batch_out_start, result + batch_out_start + signal_length, window_sum.begin(), diff --git a/src/plugins/template/tests/functional/op_reference/istft.cpp b/src/plugins/template/tests/functional/op_reference/istft.cpp index 35f198c04032cd..63ad83e19ebca7 100644 --- a/src/plugins/template/tests/functional/op_reference/istft.cpp +++ b/src/plugins/template/tests/functional/op_reference/istft.cpp @@ -268,6 +268,21 @@ std::vector generateISTFTParams() { -1.7631, 0.4790, -0.4011, 0.6825, 1.8965, -0.7060, -1.6151, 0.2192, -0.5161, 0.3123, 1.7869, -0.3231, -1.5735, 0.0000, -0.5485, 0.0000, 1.7559, 0.0000}); + reference_tests::Tensor output_stft_2_9_3_2_transp_win_two( + Shape{2, 9, 3, 2}, + ET, + std::vector{1.3873, 0.0000, -2.8503, 0.0000, -0.4391, 0.0000, 1.7637, -0.6945, -3.1429, -0.9896, + -0.7182, 1.0237, 4.2213, -2.5114, -5.0535, -3.5783, -2.5402, 3.7017, -12.4337, 7.5925, + 7.8944, 10.8180, 9.8076, -11.1912, -3.1735, 1.6741, 0.6953, 2.3853, 2.9422, -2.4676, + -2.1233, 0.8610, -0.1211, 1.2268, 2.1636, -1.2691, -1.7631, 0.4790, -0.4011, 0.6825, + 1.8965, -0.7060, -1.6151, 0.2192, -0.5161, 0.3123, 1.7869, -0.3231, -1.5735, 0.0000, + -0.5485, 0.0000, 1.7559, 0.0000, 1.3873, 0.0000, -2.8503, 0.0000, -0.4391, 0.0000, + 1.7637, -0.6945, -3.1429, -0.9896, -0.7182, 1.0237, 4.2213, -2.5114, -5.0535, -3.5783, + -2.5402, 3.7017, -12.4337, 7.5925, 7.8944, 10.8180, 9.8076, -11.1912, -3.1735, 1.6741, + 0.6953, 2.3853, 2.9422, -2.4676, -2.1233, 0.8610, -0.1211, 1.2268, 2.1636, -1.2691, + -1.7631, 0.4790, -0.4011, 0.6825, 1.8965, -0.7060, -1.6151, 0.2192, -0.5161, 0.3123, + 1.7869, -0.3231, -1.5735, 0.0000, -0.5485, 0.0000, 1.7559, 0.0000}); + reference_tests::Tensor output_stft_9_4_2_transp_win_two_center( Shape{9, 4, 2}, ET, @@ -305,6 +320,20 @@ std::vector generateISTFTParams() { 1.0437e+00, 2.3842e-07, 1.6506e+00, -4.1167e-01, -9.8408e-01, 8.3490e-03, -7.1159e-01, 2.9475e-01, 2.5795e-01, 0.0000e+00, 1.6434e+00, 0.0000e+00, -9.3507e-01, 0.0000e+00, -1.4628e+00, 0.0000e+00}); + reference_tests::Tensor output_stft_9_4_2_transp_win_two_center_norm( + Shape{9, 4, 2}, + ET, + std::vector{ + -1.8382e-01, 0.0000e+00, 2.8325e-01, 0.0000e+00, 6.3686e-01, 0.0000e+00, -1.0424e+00, 0.0000e+00, + -8.2669e-01, 5.9605e-08, 2.6703e-01, 3.2606e-01, 7.4754e-01, -6.6102e-03, -5.6362e-01, -2.3346e-01, + -4.2149e-01, 0.0000e+00, 1.6112e-01, 1.1790e+00, 1.4702e+00, -2.3893e-02, -1.6901e+00, -1.6901e+00, + 3.7434e+00, 0.0000e+00, 8.7886e-01, -3.5645e+00, -3.4273e+00, 7.2241e-02, 1.0571e+00, 2.5522e+00, + 1.9867e-01, 0.0000e+00, 4.7980e-01, -7.8593e-01, -7.0426e-01, 1.5931e-02, 0.0000e+00, 1.1266e+00, + 4.2452e-01, 0.0000e+00, 4.3454e-01, -4.0423e-01, -3.9546e-01, 8.1899e-03, -1.1988e-01, 2.8943e-01, + 8.0389e-02, 0.0000e+00, 4.1901e-01, -2.2487e-01, -2.8953e-01, 4.5592e-03, -3.2235e-01, 3.2235e-01, + 2.6093e-01, 5.9605e-08, 4.1264e-01, -1.0292e-01, -2.4602e-01, 2.0872e-03, -1.7790e-01, 7.3688e-02, + 6.4488e-02, 0.0000e+00, 4.1084e-01, 0.0000e+00, -2.3377e-01, 0.0000e+00, -3.6569e-01, 0.0000e+00}); + reference_tests::Tensor auto_length(Shape{}, IT, std::vector{-1}); reference_tests::Tensor length_16(Shape{}, IT, std::vector{16}); reference_tests::Tensor length_48(Shape{}, IT, std::vector{48}); @@ -355,6 +384,15 @@ std::vector generateISTFTParams() { false, signal_48, "basic_1D_transp_two_win_step_16"); + params.emplace_back(output_stft_2_9_3_2_transp_win_two, + two_window_16, + frame_size_16, + frame_step_16, + auto_length, + false, + false, + signal_2_48, + "basic_2D_transp_two_win_step_16"); params.emplace_back(output_stft_9_4_2_transp_win_two_center, two_window_16, frame_size_16, @@ -364,6 +402,15 @@ std::vector generateISTFTParams() { false, signal_48, "basic_1D_transp_two_win_step_16_center"); + params.emplace_back(output_stft_9_4_2_transp_win_two_center_norm, + two_window_16, + frame_size_16, + frame_step_16, + auto_length, + true, + true, + signal_48, + "basic_1D_transp_two_win_step_16_center_norm"); params.emplace_back(output_stft_2_9_4_2_transp_win_two_center, two_window_16, frame_size_16, From 977fe0ae1f47397e26fc35f911918683738433b2 Mon Sep 17 00:00:00 2001 From: mitruska Date: Fri, 24 Jan 2025 16:57:06 +0100 Subject: [PATCH 5/7] Enable signal length input --- .../include/istft_shape_inference.hpp | 96 ++++++++++--------- .../tests/functional/op_reference/istft.cpp | 19 ++++ 2 files changed, 70 insertions(+), 45 deletions(-) diff --git a/src/core/shape_inference/include/istft_shape_inference.hpp b/src/core/shape_inference/include/istft_shape_inference.hpp index d62ea60116b292..d68a8fd7cb0016 100644 --- a/src/core/shape_inference/include/istft_shape_inference.hpp +++ b/src/core/shape_inference/include/istft_shape_inference.hpp @@ -48,63 +48,69 @@ std::vector shape_infer(const ISTFT* op, length_shape.rank().compatible(0), "The shape of length input must be a scalar."); - if (data_shape_rank.is_dynamic()) { - return {data_shape}; - } - const auto frame_size = get_input_const_data_as(op, 2, ta); const auto frame_step = get_input_const_data_as(op, 3, ta); - - const auto is_data_3D = data_shape.size() == 3; - if (!frame_size || !frame_step) { - if (is_data_3D) { - return {TRShape{TDim(ov::util::dim::inf_bound)}}; - } else { - return {TRShape{data_shape[0], TDim(ov::util::dim::inf_bound)}}; - } + const auto sig_len_in = get_input_const_data_as_shape(op, 4, ta); + + if (frame_size) { + const auto& frame_size_val = (*frame_size)[0]; + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + 0 < frame_size_val, + "Provided frame size is ", + frame_size_val, + " but must be greater than zero."); + const bool is_win_shape_correct = + window_shape.is_dynamic() || (TDimVal{0} < window_shape[0].get_length() && + window_shape[0].get_length() <= static_cast(frame_size_val)); + + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + is_win_shape_correct, + "Window input dimension must be in range [1, ", + frame_size_val, + "]."); } - const auto& frame_size_val = (*frame_size)[0]; - const auto& frame_step_val = (*frame_step)[0]; - - NODE_SHAPE_INFER_CHECK(op, - input_shapes, - 0 < frame_size_val, - "Provided frame size is ", - frame_size_val, - " but must be greater than zero."); - - NODE_SHAPE_INFER_CHECK(op, - input_shapes, - 0 < frame_step_val, - "Provided frame step is ", - frame_step_val, - " but must be greater than zero."); - - const bool is_win_shape_correct = - window_shape.is_dynamic() || (TDimVal{0} < window_shape[0].get_length() && - window_shape[0].get_length() <= static_cast(frame_size_val)); - NODE_SHAPE_INFER_CHECK(op, - input_shapes, - is_win_shape_correct, - "Window input dimension must be in range [1, ", - frame_size_val, - "]."); + if (frame_step) { + const auto& frame_step_val = (*frame_step)[0]; + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + 0 < frame_step_val, + "Provided frame step is ", + frame_step_val, + " but must be greater than zero."); + } - int64_t frames_axis = 1 + (is_data_3D ? 0 : 1); - const TDim& num_frames_dim = data_shape[frames_axis]; - TDim signal_length = (num_frames_dim - 1) * frame_step_val; - if (!op->get_center()) { - signal_length += frame_size_val; + // For the input with dynamic rank, output shape is also fully dynamic + if (data_shape_rank.is_dynamic()) { + return {data_shape}; } + const auto is_data_3D = data_shape.size() == 3; std::vector output_shapes; - output_shapes.emplace_back(TRShape{std::move(signal_length)}); + if (sig_len_in && (*sig_len_in)[0].is_static()) { // Set desired length of the signal dimension, if provided + output_shapes.emplace_back(TRShape{(*sig_len_in)[0]}); + } else if (frame_size && frame_step) { // Otherwise infer the length of the signal + const auto& frame_size_val = (*frame_size)[0]; + const auto& frame_step_val = (*frame_step)[0]; + + const int64_t frames_axis = 1 + (is_data_3D ? 0 : 1); + const TDim& num_frames_dim = data_shape[frames_axis]; + TDim signal_length = (num_frames_dim - 1) * frame_step_val; + if (!op->get_center()) { + signal_length += frame_size_val; + } + output_shapes.emplace_back(TRShape{std::move(signal_length)}); + } else { // Not enough info to infer the signal lenght, set dynamic dimension + output_shapes.emplace_back(TRShape{TDim(ov::util::dim::inf_bound)}); + } - if (!is_data_3D) { + if (!is_data_3D) { // Copy batch dimension const auto& batch_dim = data_shape[0]; output_shapes[0].insert(output_shapes[0].begin(), batch_dim); } + return output_shapes; } } // namespace v16 diff --git a/src/plugins/template/tests/functional/op_reference/istft.cpp b/src/plugins/template/tests/functional/op_reference/istft.cpp index 63ad83e19ebca7..6937337a109e1c 100644 --- a/src/plugins/template/tests/functional/op_reference/istft.cpp +++ b/src/plugins/template/tests/functional/op_reference/istft.cpp @@ -116,6 +116,7 @@ std::vector generateISTFTParams() { using INT_T = typename ov::element_type_traits::value_type; const ov::Shape signal_48_shape{48}; + const ov::Shape signal_39_shape{39}; const ov::Shape signal_1_48_shape{1, 48}; const ov::Shape signal_2_48_shape{2, 48}; const ov::Shape signal_256_shape{1, 256}; @@ -150,6 +151,14 @@ std::vector generateISTFTParams() { -0.5549, 0.4629, 0.9998, 0.4981, -0.5211, -0.9989, -0.4390, 0.5769, 0.9936, 0.3780, -0.6302, -0.9838, -0.3154, 0.6806, 0.9696, 0.2513, -0.7281, -0.9511}); + reference_tests::Tensor signal_39( + signal_39_shape, + ET, + std::vector{-0.9511, -0.1861, 0.7722, 0.9283, 0.1200, -0.8129, -0.9014, -0.0534, 0.8500, 0.8704, + -0.0134, -0.8833, -0.8356, 0.0801, 0.9126, 0.7971, -0.1465, -0.9379, -0.7550, 0.2123, + 0.9590, 0.7095, -0.2771, -0.9758, -0.6608, 0.3406, 0.9882, 0.6092, -0.4027, -0.9962, + -0.5549, 0.4629, 0.9998, 0.4981, -0.5211, -0.9989, -0.4390, 0.5769, 0.9936}); + reference_tests::Tensor signal_1_48( signal_1_48_shape, ET, @@ -336,6 +345,7 @@ std::vector generateISTFTParams() { reference_tests::Tensor auto_length(Shape{}, IT, std::vector{-1}); reference_tests::Tensor length_16(Shape{}, IT, std::vector{16}); + reference_tests::Tensor length_39(Shape{}, IT, std::vector{39}); reference_tests::Tensor length_48(Shape{}, IT, std::vector{48}); std::vector params; @@ -420,6 +430,15 @@ std::vector generateISTFTParams() { false, signal_2_48, "basic_2D_transp_two_win_step_16_center"); + params.emplace_back(output_stft_9_4_2_transp_win_two_center_norm, + two_window_16, + frame_size_16, + frame_step_16, + length_39, + true, + true, + signal_39, + "basic_1D_transp_two_win_step_16_center_norm_length_39"); return params; } From 13faa4d2b46e05baa97fcd0912c2ee0bed6f3f59 Mon Sep 17 00:00:00 2001 From: mitruska Date: Fri, 24 Jan 2025 19:03:51 +0100 Subject: [PATCH 6/7] Make signal_length optional input --- src/core/include/openvino/op/istft.hpp | 21 ++++++- .../include/istft_shape_inference.hpp | 24 +++++--- src/core/src/op/istft.cpp | 53 ++++++++++++----- .../tests/functional/op_reference/istft.cpp | 58 ++++++++++++++----- 4 files changed, 113 insertions(+), 43 deletions(-) diff --git a/src/core/include/openvino/op/istft.hpp b/src/core/include/openvino/op/istft.hpp index 611ee10316086b..89f98740370e5b 100644 --- a/src/core/include/openvino/op/istft.hpp +++ b/src/core/include/openvino/op/istft.hpp @@ -16,20 +16,35 @@ class OPENVINO_API ISTFT : public Op { OPENVINO_OP("ISTFT", "opset16"); ISTFT() = default; - /// \brief Constructs an ISTFT operation. + /// \brief Constructs an ISTFT operation with signal length to be inferred /// /// \param data Input data /// \param window Window values applied in STFT /// \param frame_size Scalar value representing the size of Fourier Transform /// \param frame_step The distance (number of samples) between successive window frames - /// \param length The length of the original signal /// \param center Flag signaling if the signal input has been padded before STFT /// \param normalized Flag signaling if the STFT result has been normalized. ISTFT(const Output& data, const Output& window, const Output& frame_size, const Output& frame_step, - const Output& length, + const bool center, + const bool normalized); + + /// \brief Constructs an ISTFT operation with signal length provided + /// + /// \param data Input data + /// \param window Window values applied in STFT + /// \param frame_size Scalar value representing the size of Fourier Transform + /// \param frame_step The distance (number of samples) between successive window frames + /// \param signal_length The signal length of the original signal + /// \param center Flag signaling if the signal input has been padded before STFT + /// \param normalized Flag signaling if the STFT result has been normalized. + ISTFT(const Output& data, + const Output& window, + const Output& frame_size, + const Output& frame_step, + const Output& signal_length, const bool center, const bool normalized); diff --git a/src/core/shape_inference/include/istft_shape_inference.hpp b/src/core/shape_inference/include/istft_shape_inference.hpp index d68a8fd7cb0016..bd1bf5bacaaf96 100644 --- a/src/core/shape_inference/include/istft_shape_inference.hpp +++ b/src/core/shape_inference/include/istft_shape_inference.hpp @@ -18,13 +18,14 @@ std::vector shape_infer(const ISTFT* op, using TDim = typename TRShape::value_type; using TDimVal = typename TDim::value_type; - NODE_VALIDATION_CHECK(op, input_shapes.size() == 5); + const auto inputs_count = input_shapes.size(); + const auto is_in_count_correct = inputs_count == 4 || inputs_count == 5; + NODE_VALIDATION_CHECK(op, is_in_count_correct); const auto& data_shape = input_shapes[0]; const auto& window_shape = input_shapes[1]; const auto& frame_size_shape = input_shapes[2]; const auto& frame_step_shape = input_shapes[3]; - const auto& length_shape = input_shapes[4]; const auto data_shape_rank = data_shape.rank(); NODE_SHAPE_INFER_CHECK(op, @@ -43,14 +44,9 @@ std::vector shape_infer(const ISTFT* op, input_shapes, frame_step_shape.rank().compatible(0), "The shape of frame_step must be a scalar."); - NODE_SHAPE_INFER_CHECK(op, - input_shapes, - length_shape.rank().compatible(0), - "The shape of length input must be a scalar."); const auto frame_size = get_input_const_data_as(op, 2, ta); const auto frame_step = get_input_const_data_as(op, 3, ta); - const auto sig_len_in = get_input_const_data_as_shape(op, 4, ta); if (frame_size) { const auto& frame_size_val = (*frame_size)[0]; @@ -89,8 +85,18 @@ std::vector shape_infer(const ISTFT* op, const auto is_data_3D = data_shape.size() == 3; std::vector output_shapes; - if (sig_len_in && (*sig_len_in)[0].is_static()) { // Set desired length of the signal dimension, if provided - output_shapes.emplace_back(TRShape{(*sig_len_in)[0]}); + if (inputs_count == 5) { + const auto& length_shape = input_shapes[4]; + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + length_shape.rank().compatible(0), + "The shape of length input must be a scalar."); + const auto sig_len_in = get_input_const_data_as_shape(op, 4, ta); + if (sig_len_in) { // Set desired length of the signal dimension, if provided + output_shapes.emplace_back(TRShape{(*sig_len_in)[0]}); + } else { + output_shapes.emplace_back(TRShape{TDim(ov::util::dim::inf_bound)}); + } } else if (frame_size && frame_step) { // Otherwise infer the length of the signal const auto& frame_size_val = (*frame_size)[0]; const auto& frame_step_val = (*frame_step)[0]; diff --git a/src/core/src/op/istft.cpp b/src/core/src/op/istft.cpp index d6bf60cc3330ce..a8903aef0af631 100644 --- a/src/core/src/op/istft.cpp +++ b/src/core/src/op/istft.cpp @@ -24,10 +24,22 @@ ISTFT::ISTFT(const Output& data, const Output& window, const Output& frame_size, const Output& frame_step, - const Output& length, const bool center, const bool normalized) - : Op({data, window, frame_size, frame_step, length}), + : Op({data, window, frame_size, frame_step}), + m_center(center), + m_normalized(normalized) { + constructor_validate_and_infer_types(); +} + +ISTFT::ISTFT(const Output& data, + const Output& window, + const Output& frame_size, + const Output& frame_step, + const Output& signal_length, + const bool center, + const bool normalized) + : Op({data, window, frame_size, frame_step, signal_length}), m_center(center), m_normalized(normalized) { constructor_validate_and_infer_types(); @@ -37,6 +49,14 @@ std::shared_ptr ISTFT::clone_with_new_inputs(const OutputVector& new_args) OV_OP_SCOPE(v16_ISTFT_clone_with_new_inputs); check_new_args_count(this, new_args); + if (new_args.size() == 4) { + return std::make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + new_args.at(3), + m_center, + m_normalized); + } return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), @@ -55,35 +75,35 @@ bool ISTFT::visit_attributes(AttributeVisitor& visitor) { void ISTFT::validate_and_infer_types() { OV_OP_SCOPE(v16_ISTFT_validate_and_infer_types); - NODE_VALIDATION_CHECK(this, get_input_size() == 5, "Expected 5 inputs to be provided."); + const auto input_size = get_input_size(); + const auto is_in_count_correct = input_size == 4 || input_size == 5; + NODE_VALIDATION_CHECK(this, is_in_count_correct, "Expected 4 or 5 inputs to be provided."); - auto signal_type = get_input_element_type(0); + auto data_type = get_input_element_type(0); const auto& window_type = get_input_element_type(1); - const auto has_valid_signal_type = signal_type.is_dynamic() || signal_type.is_real(); - NODE_VALIDATION_CHECK(this, has_valid_signal_type, "Expected floating point type of the 'signal' input."); + const auto has_valid_data_type = data_type.is_dynamic() || data_type.is_real(); + NODE_VALIDATION_CHECK(this, has_valid_data_type, "Expected floating point type of the 'data' input."); const auto has_valid_window_type = - window_type.is_dynamic() || - (window_type.is_real() && element::Type::merge(signal_type, window_type, signal_type)); - NODE_VALIDATION_CHECK(this, - has_valid_window_type, - "Expected floating point type of the 'window' input, matching the type of `signal` input."); + window_type.is_dynamic() || (window_type.is_real() && element::Type::merge(data_type, window_type, data_type)); + NODE_VALIDATION_CHECK(this, has_valid_window_type, "Expected floating point type of the 'window' input."); check_int_input_at(this, 2); check_int_input_at(this, 3); - check_int_input_at(this, 4); + if (input_size == 5) { + check_int_input_at(this, 4); + } const auto input_shapes = ov::util::get_node_input_partial_shapes(*this); const auto output_shapes = shape_infer(this, input_shapes); - set_output_type(0, signal_type, output_shapes[0]); + set_output_type(0, data_type, output_shapes[0]); } bool ISTFT::evaluate(TensorVector& outputs, const TensorVector& inputs) const { OV_OP_SCOPE(v16_ISTFT_evaluate); OPENVINO_ASSERT(outputs.size() == 1); - OPENVINO_ASSERT(inputs.size() == 5); const auto input_shapes = ov::util::get_tensors_partial_shapes(inputs); const auto output_shape = shape_infer(this, input_shapes, make_tensor_accessor(inputs)).front().to_shape(); @@ -92,7 +112,10 @@ bool ISTFT::evaluate(TensorVector& outputs, const TensorVector& inputs) const { const auto frame_size = ov::get_tensor_data_as(inputs[2]).front(); const auto frame_step = ov::get_tensor_data_as(inputs[3]).front(); - const auto length = ov::get_tensor_data_as(inputs[4]).front(); + int64_t length = -1; + if (inputs.size() == 5) { + length = ov::get_tensor_data_as(inputs[4]).front(); + } ov::reference::istft(inputs[0].data(), inputs[1].data(), diff --git a/src/plugins/template/tests/functional/op_reference/istft.cpp b/src/plugins/template/tests/functional/op_reference/istft.cpp index 6937337a109e1c..933721a1b890fa 100644 --- a/src/plugins/template/tests/functional/op_reference/istft.cpp +++ b/src/plugins/template/tests/functional/op_reference/istft.cpp @@ -48,11 +48,16 @@ class ReferenceISTFT : public testing::TestWithParam, public refere void SetUp() override { const auto& params = GetParam(); function = CreateFunction(params); - inputData = {params.signal.data, - params.window.data, - params.frame_size.data, - params.frame_step.data, - params.length.data}; + if (shape_size(params.length.shape) == 0) { // Ignore signal length + inputData = {params.signal.data, params.window.data, params.frame_size.data, params.frame_step.data}; + } else { + inputData = {params.signal.data, + params.window.data, + params.frame_size.data, + params.frame_step.data, + params.length.data}; + } + refOutData = {params.expected_tensor.data}; abs_threshold = 1e-5f; } @@ -97,16 +102,28 @@ class ReferenceISTFT : public testing::TestWithParam, public refere std::make_shared(params.frame_step.type, params.frame_step.shape); const auto in_length = std::make_shared(params.length.type, params.length.shape); - const auto ISTFT = std::make_shared(in_signal, - in_window, - in_frame_size, - in_frame_step, - in_length, - params.center, - params.normalized); - return std::make_shared( - ISTFT->outputs(), - ov::ParameterVector{in_signal, in_window, in_frame_size, in_frame_step, in_length}); + std::shared_ptr ISTFT; + if (shape_size(params.length.shape) == 0) { + ISTFT = std::make_shared(in_signal, + in_window, + in_frame_size, + in_frame_step, + params.center, + params.normalized); + return std::make_shared(ISTFT->outputs(), + ov::ParameterVector{in_signal, in_window, in_frame_size, in_frame_step}); + } else { + ISTFT = std::make_shared(in_signal, + in_window, + in_frame_size, + in_frame_step, + in_length, + params.center, + params.normalized); + return std::make_shared( + ISTFT->outputs(), + ov::ParameterVector{in_signal, in_window, in_frame_size, in_frame_step, in_length}); + } } }; @@ -343,7 +360,7 @@ std::vector generateISTFTParams() { 2.6093e-01, 5.9605e-08, 4.1264e-01, -1.0292e-01, -2.4602e-01, 2.0872e-03, -1.7790e-01, 7.3688e-02, 6.4488e-02, 0.0000e+00, 4.1084e-01, 0.0000e+00, -2.3377e-01, 0.0000e+00, -3.6569e-01, 0.0000e+00}); - reference_tests::Tensor auto_length(Shape{}, IT, std::vector{-1}); + reference_tests::Tensor auto_length(Shape{0}, IT, std::vector{}); reference_tests::Tensor length_16(Shape{}, IT, std::vector{16}); reference_tests::Tensor length_39(Shape{}, IT, std::vector{39}); reference_tests::Tensor length_48(Shape{}, IT, std::vector{48}); @@ -430,6 +447,15 @@ std::vector generateISTFTParams() { false, signal_2_48, "basic_2D_transp_two_win_step_16_center"); + params.emplace_back(output_stft_2_9_4_2_transp_win_two_center, + two_window_16, + frame_size_16, + frame_step_16, + length_48, + true, + false, + signal_2_48, + "basic_2D_transp_two_win_step_16_center_length_48"); params.emplace_back(output_stft_9_4_2_transp_win_two_center_norm, two_window_16, frame_size_16, From ee66cde0b87fa8b7e9437041c918c79dab14226b Mon Sep 17 00:00:00 2001 From: mitruska Date: Fri, 24 Jan 2025 19:23:08 +0100 Subject: [PATCH 7/7] Add tests for bigger length --- .../tests/functional/op_reference/istft.cpp | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/plugins/template/tests/functional/op_reference/istft.cpp b/src/plugins/template/tests/functional/op_reference/istft.cpp index 933721a1b890fa..107f486ead98c9 100644 --- a/src/plugins/template/tests/functional/op_reference/istft.cpp +++ b/src/plugins/template/tests/functional/op_reference/istft.cpp @@ -176,6 +176,28 @@ std::vector generateISTFTParams() { 0.9590, 0.7095, -0.2771, -0.9758, -0.6608, 0.3406, 0.9882, 0.6092, -0.4027, -0.9962, -0.5549, 0.4629, 0.9998, 0.4981, -0.5211, -0.9989, -0.4390, 0.5769, 0.9936}); + reference_tests::Tensor signal_55( + Shape{55}, + ET, + std::vector{-0.9511, -0.1861, 0.7722, 0.9283, 0.1200, -0.8129, -0.9014, -0.0534, 0.8500, 0.8704, + -0.0134, -0.8833, -0.8356, 0.0801, 0.9126, 0.7971, -0.1465, -0.9379, -0.7550, 0.2123, + 0.9590, 0.7095, -0.2771, -0.9758, -0.6608, 0.3406, 0.9882, 0.6092, -0.4027, -0.9962, + -0.5549, 0.4629, 0.9998, 0.4981, -0.5211, -0.9989, -0.4390, 0.5769, 0.9936, 0.3780, + -0.6302, -0.9838, -0.3154, 0.6806, 0.9696, 0.2513, -0.7281, -0.9511, -0.7281, 0.2513, + 0.9696, 0.6806, -0.3154, -0.9838, -0.6302}); + + reference_tests::Tensor signal_60( + Shape{60}, + ET, + std::vector{-0.9511, -0.1861, 0.7722, 0.9283, 0.1200, -0.8129, -0.9014, -0.0534, + 0.8500, 0.8704, -0.0134, -0.8833, -0.8356, 0.0801, 0.9126, 0.7971, + -0.1465, -0.9379, -0.7550, 0.2123, 0.9590, 0.7095, -0.2771, -0.9758, + -0.6608, 0.3406, 0.9882, 0.6092, -0.4027, -0.9962, -0.5549, 0.4629, + 0.9998, 0.4981, -0.5211, -0.9989, -0.4390, 0.5769, 0.9936, 0.3780, + -0.6302, -0.9838, -0.3154, 0.6806, 0.9696, 0.2513, -0.7281, -0.9511, + -0.7281, 0.2513, 0.9696, 0.6806, -0.3154, -0.9838, -0.6302, 0.3780, + 0.0000, 0.0000, 0.0000, 0.0000}); + reference_tests::Tensor signal_1_48( signal_1_48_shape, ET, @@ -364,6 +386,8 @@ std::vector generateISTFTParams() { reference_tests::Tensor length_16(Shape{}, IT, std::vector{16}); reference_tests::Tensor length_39(Shape{}, IT, std::vector{39}); reference_tests::Tensor length_48(Shape{}, IT, std::vector{48}); + reference_tests::Tensor length_55(Shape{}, IT, std::vector{55}); + reference_tests::Tensor length_60(Shape{}, IT, std::vector{60}); std::vector params; params.emplace_back(output_stft_9_1_2_transp_win_one, @@ -466,6 +490,26 @@ std::vector generateISTFTParams() { signal_39, "basic_1D_transp_two_win_step_16_center_norm_length_39"); + params.emplace_back(output_stft_9_4_2_transp_win_two_center_norm, + two_window_16, + frame_size_16, + frame_step_16, + length_55, + true, + true, + signal_55, + "basic_1D_transp_two_win_step_16_center_norm_length_55"); + + params.emplace_back(output_stft_9_4_2_transp_win_two_center_norm, + two_window_16, + frame_size_16, + frame_step_16, + length_60, + true, + true, + signal_60, + "basic_1D_transp_two_win_step_16_center_norm_length_60"); + return params; }