Skip to content

Commit

Permalink
Init ISTFT op
Browse files Browse the repository at this point in the history
  • Loading branch information
mitruska committed Jan 21, 2025
1 parent 54a4f1c commit 311e3d1
Show file tree
Hide file tree
Showing 7 changed files with 365 additions and 1 deletion.
51 changes: 51 additions & 0 deletions src/core/include/openvino/op/istft.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov {
namespace op {
namespace v16 {
/// \brief An operation ISTFT that computes the Inverse Short Time Fourier Transform.
///
/// The operation takes five inputs (data, window, frame_size, frame_step, length), carries two
/// boolean attributes (center, normalized) and produces a single output.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API ISTFT : public Op {
public:
OPENVINO_OP("ISTFT", "opset16");
/// \brief Constructs an ISTFT operation without inputs; both attributes keep their in-class default (false).
ISTFT() = default;

/// \brief Constructs an ISTFT operation.
///
/// \param data Input data
/// \param window Window values applied in STFT
/// \param frame_size Scalar value representing the size of Fourier Transform
/// \param frame_step The distance (number of samples) between successive window frames
/// \param length The length of the original signal
/// \param center Flag signaling if the signal input has been padded before STFT
/// \param normalized Flag signaling if the STFT result has been normalized.
ISTFT(const Output<Node>& data,
const Output<Node>& window,
const Output<Node>& frame_size,
const Output<Node>& frame_step,
const Output<Node>& length,
const bool center,
const bool normalized);

/// \brief Exposes the `center` and `normalized` attributes to the given visitor.
bool visit_attributes(AttributeVisitor& visitor) override;
/// \brief Validates the inputs and sets the element type and shape of the output.
void validate_and_infer_types() override;
/// \brief Creates a copy of this operation connected to `new_args`, keeping the attribute values.
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

/// \brief Returns the value of the `center` attribute.
// NOTE(review): there is no matching accessor for `normalized` in this class.
bool get_center() const;

/// \brief Evaluates the operation on the provided input tensors, writing to `outputs`.
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
/// \brief Reports whether evaluate() is available for the current input configuration.
bool has_evaluate() const override;

private:
// Attribute values; both default to false when the default constructor is used.
bool m_center = false;
bool m_normalized = false;
};
} // namespace v16
} // namespace op
} // namespace ov
1 change: 1 addition & 0 deletions src/core/include/openvino/op/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
#include "openvino/op/is_finite.hpp"
#include "openvino/op/is_inf.hpp"
#include "openvino/op/is_nan.hpp"
#include "openvino/op/istft.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/less_eq.hpp"
#include "openvino/op/log.hpp"
Expand Down
1 change: 1 addition & 0 deletions src/core/include/openvino/opsets/opset16_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ _OPENVINO_OP_REG(ShapeOf, ov::op::v3)

// New operations added in opset16
_OPENVINO_OP_REG(Identity, ov::op::v16)
_OPENVINO_OP_REG(ISTFT, ov::op::v16)
112 changes: 112 additions & 0 deletions src/core/shape_inference/include/istft_shape_inference.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "dimension_util.hpp"
#include "openvino/op/istft.hpp"
#include "utils.hpp"

namespace ov {
namespace op {
namespace v16 {
/// \brief Infers the output shape of the ISTFT operation.
///
/// The result is a signal of shape [signal_length] for 3D data or [batch, signal_length] for 4D data.
/// The signal length is derived from the number of frames in `data` and from the constant values of
/// the `frame_size` and `frame_step` inputs; when either of them is not available as a constant the
/// signal dimension is left unbounded. The `length` input (port 4) is only validated to be a scalar
/// here, its value does not influence the inferred shape.
///
/// \param op            ISTFT node the shapes are inferred for (also used for error reporting).
/// \param input_shapes  Shapes of the five inputs: data, window, frame_size, frame_step, length.
/// \param ta            Accessor used to read constant values of the frame_size/frame_step inputs.
/// \return              One-element vector holding the inferred output shape.
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const ISTFT* op,
                                 const std::vector<TShape>& input_shapes,
                                 const ITensorAccessor& ta = make_tensor_accessor()) {
    using TDim = typename TRShape::value_type;
    using TDimVal = typename TDim::value_type;

    NODE_VALIDATION_CHECK(op, input_shapes.size() == 5);

    const auto& data_shape = input_shapes[0];
    const auto& window_shape = input_shapes[1];
    const auto& frame_size_shape = input_shapes[2];
    const auto& frame_step_shape = input_shapes[3];
    const auto& length_shape = input_shapes[4];

    // Rank validation of every input.
    const auto data_shape_rank = data_shape.rank();
    NODE_SHAPE_INFER_CHECK(
        op,
        input_shapes,
        data_shape_rank.compatible(3) || data_shape_rank.compatible(4),
        "The shape of data must be 3D [fft_results, frames, 2] or 4D [batch, fft_results, frames, 2].");
    NODE_SHAPE_INFER_CHECK(op,
                           input_shapes,
                           window_shape.rank().compatible(1),
                           "The shape of window must be 1D [window_size].");
    NODE_SHAPE_INFER_CHECK(op,
                           input_shapes,
                           frame_size_shape.rank().compatible(0),
                           "The shape of frame_size must be a scalar.");
    NODE_SHAPE_INFER_CHECK(op,
                           input_shapes,
                           frame_step_shape.rank().compatible(0),
                           "The shape of frame_step must be a scalar.");
    NODE_SHAPE_INFER_CHECK(op,
                           input_shapes,
                           length_shape.rank().compatible(0),
                           "The shape of length input must be a scalar.");

    // Without a known data rank it is not possible to tell whether a batch dimension is present.
    if (data_shape_rank.is_dynamic()) {
        return {data_shape};
    }

    const auto frame_size = get_input_const_data_as<TRShape, int64_t>(op, 2, ta);
    const auto frame_step = get_input_const_data_as<TRShape, int64_t>(op, 3, ta);

    const auto is_data_3D = data_shape.size() == 3;
    if (!frame_size || !frame_step) {
        // The signal length cannot be computed, only the (optional) batch dimension is propagated.
        if (is_data_3D) {
            return {TRShape{TDim(ov::util::dim::inf_bound)}};
        } else {
            return {TRShape{data_shape[0], TDim(ov::util::dim::inf_bound)}};
        }
    }

    const auto& frame_size_val = (*frame_size)[0];
    const auto& frame_step_val = (*frame_step)[0];

    NODE_SHAPE_INFER_CHECK(op,
                           input_shapes,
                           0 < frame_size_val,
                           "Provided frame size is ",
                           frame_size_val,
                           " but must be greater than zero.");

    NODE_SHAPE_INFER_CHECK(op,
                           input_shapes,
                           0 < frame_step_val,
                           "Provided frame step is ",
                           frame_step_val,
                           " but must be greater than zero.");

    // The window length is verified only when the window shape is fully static.
    const bool is_win_shape_correct =
        window_shape.is_dynamic() || (TDimVal{0} < window_shape[0].get_length() &&
                                      window_shape[0].get_length() <= static_cast<TDimVal>(frame_size_val));
    NODE_SHAPE_INFER_CHECK(op,
                           input_shapes,
                           is_win_shape_correct,
                           "Window input dimension must be in range [1, ",
                           frame_size_val,
                           "].");

    // The frames axis is 1 for 3D data and 2 for 4D (batched) data.
    int64_t frames_axis = 1 + (is_data_3D ? 0 : 1);
    const TDim& num_frames_dim = data_shape[frames_axis];
    TDim signal_length = (num_frames_dim - 1) * frame_step_val;
    if (!op->get_center()) {
        // The frame size is added only when the `center` attribute is not set.
        signal_length += frame_size_val;
    }

    std::vector<TRShape> output_shapes;
    output_shapes.emplace_back(TRShape{std::move(signal_length)});

    if (!is_data_3D) {
        // Batched data: prepend the batch dimension to the output shape.
        const auto& batch_dim = data_shape[0];
        output_shapes[0].insert(output_shapes[0].begin(), batch_dim);
    }
    return output_shapes;
}
} // namespace v16
} // namespace op
} // namespace ov
123 changes: 123 additions & 0 deletions src/core/src/op/istft.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/istft.hpp"

#include <memory>

#include "istft_shape_inference.hpp"
#include "itt.hpp"
// #include "openvino/reference/istft.hpp"

namespace ov {
namespace op {
namespace v16 {
namespace {
// Verifies that the input at port `input_idx` of `op` has a dynamic, i32 or i64 element type.
// Any other element type fails the node validation check.
void check_int_input_at(const Node* op, size_t input_idx) {
    const auto& port_type = op->get_input_element_type(input_idx);
    // A dynamic type is accepted as is; otherwise only the two integer types are allowed.
    const bool has_valid_type = port_type.is_dynamic() || port_type == element::i32 || port_type == element::i64;
    NODE_VALIDATION_CHECK(op, has_valid_type, "Expected i32 or i64 type of the input at port: ", input_idx);
}
} // namespace
ISTFT::ISTFT(const Output<Node>& data,
const Output<Node>& window,
const Output<Node>& frame_size,
const Output<Node>& frame_step,
const Output<Node>& length,
const bool center,
const bool normalized)
: Op({data, window, frame_size, frame_step, length}),
m_center(center),
m_normalized(normalized) {
constructor_validate_and_infer_types();
}

// Creates a new ISTFT node wired to `new_args`, carrying over the `center` and `normalized` attributes.
std::shared_ptr<Node> ISTFT::clone_with_new_inputs(const OutputVector& new_args) const {
    OV_OP_SCOPE(v16_ISTFT_clone_with_new_inputs);
    check_new_args_count(this, new_args);

    // Inputs in port order; at() keeps the bounds checking of the element access.
    const auto& data = new_args.at(0);
    const auto& window = new_args.at(1);
    const auto& frame_size = new_args.at(2);
    const auto& frame_step = new_args.at(3);
    const auto& length = new_args.at(4);

    return std::make_shared<ISTFT>(data, window, frame_size, frame_step, length, m_center, m_normalized);
}

// Hands both attributes of the operation to the visitor under the names
// "center" and "normalized"; always reports success.
bool ISTFT::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v16_ISTFT_visit_attributes);
visitor.on_attribute("center", m_center);
visitor.on_attribute("normalized", m_normalized);
return true;
}

// Validates the number and element types of the inputs and sets the output type and shape.
//
// - exactly five inputs are required,
// - `data` (port 0) must be of a dynamic or floating point type,
// - `window` (port 1) must be of a dynamic type or a floating point type mergeable with the `data` type,
// - `frame_size`, `frame_step` and `length` (ports 2-4) must be of a dynamic, i32 or i64 type.
//
// The output element type follows the (possibly merged) `data` type, the output shape comes from shape_infer.
void ISTFT::validate_and_infer_types() {
    OV_OP_SCOPE(v16_ISTFT_validate_and_infer_types);
    NODE_VALIDATION_CHECK(this, get_input_size() == 5, "Expected 5 inputs to be provided.");

    // Held by value: it is passed to merge() below as the destination of the merged type.
    auto data_type = get_input_element_type(0);
    const auto& window_type = get_input_element_type(1);

    const auto has_valid_data_type = data_type.is_dynamic() || data_type.is_real();
    NODE_VALIDATION_CHECK(this, has_valid_data_type, "Expected floating point type of the 'data' input.");

    const auto has_valid_window_type =
        window_type.is_dynamic() ||
        (window_type.is_real() && element::Type::merge(data_type, window_type, data_type));
    NODE_VALIDATION_CHECK(this,
                          has_valid_window_type,
                          "Expected floating point type of the 'window' input, matching the type of `data` input.");

    check_int_input_at(this, 2);
    check_int_input_at(this, 3);
    check_int_input_at(this, 4);

    const auto input_shapes = ov::util::get_node_input_partial_shapes(*this);
    const auto output_shapes = shape_infer(this, input_shapes);

    set_output_type(0, data_type, output_shapes[0]);
}

// Evaluates the operation on host tensors.
//
// The reference implementation is not connected yet, so no output data can be produced.
// Until it is, the function reports `false` (not evaluated) and leaves `outputs` untouched,
// instead of returning `true` with a resized but uninitialized output tensor.
//
// \param outputs  Expected to hold exactly one tensor.
// \param inputs   Expected to hold exactly five tensors (data, window, frame_size, frame_step, length).
// \return         false, as no evaluation is performed.
bool ISTFT::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
    OV_OP_SCOPE(v16_ISTFT_evaluate);
    OPENVINO_ASSERT(outputs.size() == 1);
    OPENVINO_ASSERT(inputs.size() == 5);

    // TODO: enable once ov::reference::istft is available, then return true:
    //
    // const auto input_shapes = ov::util::get_tensors_partial_shapes(inputs);
    // const auto output_shape = shape_infer(this, input_shapes, make_tensor_accessor(inputs)).front().to_shape();
    // outputs[0].set_shape(output_shape);
    //
    // const auto frame_size = ov::get_tensor_data_as<int64_t>(inputs[2]).front();
    // const auto frame_step = ov::get_tensor_data_as<int64_t>(inputs[3]).front();
    // const auto length = ov::get_tensor_data_as<int64_t>(inputs[4]).front();
    //
    // ov::reference::istft(inputs[0].data<const float>(),
    //                      inputs[1].data<const float>(),
    //                      outputs[0].data<float>(),
    //                      inputs[0].get_shape(),
    //                      inputs[1].get_shape(),
    //                      frame_size,
    //                      frame_step,
    //                      length,
    //                      m_normalized,
    //                      m_center);
    return false;
}

// Reports evaluate() as available only when the element type of input 0 is f32.
bool ISTFT::has_evaluate() const {
    OV_OP_SCOPE(v16_ISTFT_has_evaluate);
    return get_input_element_type(0) == element::f32;
}

// Returns the stored value of the `center` attribute.
bool ISTFT::get_center() const {
OV_OP_SCOPE(v16_ISTFT_get_center);
return m_center;
}

} // namespace v16
} // namespace op
} // namespace ov
2 changes: 1 addition & 1 deletion src/core/tests/opset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
OpsetTestParams{ov::get_opset13, 186},
OpsetTestParams{ov::get_opset14, 188},
OpsetTestParams{ov::get_opset15, 199},
OpsetTestParams{ov::get_opset16, 4}),
OpsetTestParams{ov::get_opset16, 5}),
OpsetTestNameGenerator{});

class MyOpOld : public ov::op::Op {
Expand Down
76 changes: 76 additions & 0 deletions src/core/tests/type_prop/istft.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/istft.hpp"

#include <gtest/gtest.h>

#include "common_test_utils/test_assertions.hpp"
#include "common_test_utils/type_prop.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"

namespace ov {
namespace test {

using op::v0::Constant;
using op::v0::Parameter;
using testing::HasSubstr;

// Common fixture for the ISTFT type propagation tests; holds the attribute values
// handed to the operation under test.
class TypePropISTFTTest : public TypePropOpTest<op::v16::ISTFT> {
public:
    bool center{true};
    bool normalized{false};
};

// All inputs are Parameters, so frame_size and frame_step have no constant values:
// the batch dimension is propagated and the signal dimension stays dynamic.
TEST_F(TypePropISTFTTest, all_inputs_as_params_static_shapes) {
    // Batched data in the [batch, fft_results, frames, 2] layout used by the other tests in this file.
    const auto in_data = std::make_shared<Parameter>(element::f32, PartialShape{2, 9, 4, 2});
    const auto window = std::make_shared<Parameter>(element::f32, PartialShape{7});
    const auto frame_size = std::make_shared<Parameter>(element::i64, PartialShape{});
    const auto frame_step = std::make_shared<Parameter>(element::i64, PartialShape{});
    const auto length = std::make_shared<Parameter>(element::i64, PartialShape{});

    const auto op = make_op(in_data, window, frame_size, frame_step, length, center, normalized);
    EXPECT_EQ(op->get_output_size(), 1);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, -1}));
}

// Test parameter layout:
//   <0> shape stored in `signal_shape`, <1> shape stored in `window_shape`,
//   <2> frame size value, <3> frame step value, <4> `center` flag,
//   <5> shape stored in `expected_shape`.
using STFTTestParam = std::tuple<PartialShape, PartialShape, int32_t, int32_t, bool, PartialShape>;

// Parametrized variant of the ISTFT type propagation fixture.
class TypePropISTFTTestP : public TypePropISTFTTest, public testing::WithParamInterface<STFTTestParam> {
protected:
    // Unpacks the current test parameter into the fixture members.
    void SetUp() override {
        const auto& params = GetParam();
        signal_shape = std::get<0>(params);
        window_shape = std::get<1>(params);
        frame_size_val = std::get<2>(params);
        step_size_val = std::get<3>(params);
        center = std::get<4>(params);
        expected_shape = std::get<5>(params);
    }

    PartialShape signal_shape;
    PartialShape window_shape;
    PartialShape expected_shape;
    int32_t frame_size_val;
    int32_t step_size_val;
};

// Tuple order: signal shape, window shape, frame size, frame step, center flag, data shape.
// In TypePropISTFTTestP/istft_shapes the last element is the shape of the `data` input and
// the first element is the shape expected on the output.
INSTANTIATE_TEST_SUITE_P(
    type_prop_istft_shape,
    TypePropISTFTTestP,
    testing::Values(
        // Single frame: (1 - 1) * 16 + 16 = 16
        std::make_tuple(PartialShape{16}, PartialShape{16}, 16, 16, false, PartialShape{9, 1, 2}),
        // (3 - 1) * 16 + 16 = 48
        std::make_tuple(PartialShape{48}, PartialShape{16}, 16, 16, false, PartialShape{9, 3, 2}),
        // (16 - 1) * 3 + 11 = 56, window shorter than the frame size
        std::make_tuple(PartialShape{56}, PartialShape{7}, 11, 3, false, PartialShape{6, 16, 2}),

        // Data of dynamic rank results in an output of dynamic rank
        std::make_tuple(PartialShape::dynamic(), PartialShape::dynamic(), 11, 3, true, PartialShape::dynamic())),
    testing::PrintToStringParamName());

// Builds ISTFT with constant frame_size/frame_step and checks the inferred output.
// Note: the fixture member `expected_shape` is used as the shape of the `data` input,
// while `signal_shape` is the shape expected on the output.
TEST_P(TypePropISTFTTestP, istft_shapes) {
    const auto data_param = std::make_shared<Parameter>(element::f32, expected_shape);
    const auto window_param = std::make_shared<Parameter>(element::f32, window_shape);
    const auto frame_size_const = Constant::create<int32_t>(element::i32, {}, {frame_size_val});
    const auto frame_step_const = Constant::create<int32_t>(element::i32, {}, {step_size_val});
    const auto length_param = std::make_shared<Parameter>(element::i32, Shape{});

    const auto istft =
        make_op(data_param, window_param, frame_size_const, frame_step_const, length_param, center, false);

    EXPECT_EQ(istft->get_output_size(), 1);
    EXPECT_EQ(istft->get_output_element_type(0), element::f32);
    EXPECT_EQ(istft->get_output_partial_shape(0), signal_shape);
}

} // namespace test
} // namespace ov

0 comments on commit 311e3d1

Please sign in to comment.