Skip to content

Commit

Permalink
Updated quantile.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
geeky33 authored Jan 22, 2025
1 parent 4f15abe commit 156d2f0
Showing 1 changed file with 38 additions and 43 deletions.
81 changes: 38 additions & 43 deletions src/frontends/pytorch/src/op/quantile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,64 +16,59 @@
#include "openvino/op/minimum.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_quantile(const NodeContext& context) {
num_inputs_check(context, 2, 4);
num_inputs_check(context, 2, 5);

auto input = context.get_input(0);
auto quantiles = context.get_input(1);
auto q = context.get_input(1); // Quantile(s), can be float or tensor

auto dim = context.input_is_none(2) ? -1 : context.get_input<int64_t>(2);
auto keepdim = context.input_is_none(3) ? false : context.get_input<bool>(3);
auto interpolation = context.input_is_none(4) ? "linear" : context.get_input<std::string>(4);


if (dim == -1) {
input = context.mark_node(std::make_shared<v0::Reshape>(
input, context.mark_node(v0::Constant::create(element::i64, {1}, {-1})), true));
input, context.mark_node(std::make_shared<v0::Range>(0, input.get_shape().size(), 1)), true));
dim = 0;
}

auto sort_result = context.mark_node(std::make_shared<v0::Sort>(input, dim, true));
auto sorted_tensor = sort_result->output(0);

auto input_shape = context.mark_node(std::make_shared<v0::ShapeOf>(input));
auto dim_size = context.mark_node(std::make_shared<v0::Gather>(
input_shape, context.mark_node(v0::Constant::create(element::i64, {}, {dim})),
v0::Constant::create(element::i64, {}, {0})));

auto scaled_q = context.mark_node(std::make_shared<v1::Multiply>(
quantiles, context.mark_node(std::make_shared<v1::Subtract>(
dim_size, v0::Constant::create(element::i64, {}, {1})))));
auto lower_indices = context.mark_node(std::make_shared<v0::Floor>(scaled_q));
auto upper_indices = context.mark_node(std::make_shared<v1::Add>(
lower_indices, v0::Constant::create(element::i64, {}, {1})));

lower_indices = context.mark_node(std::make_shared<v1::Maximum>(
lower_indices, v0::Constant::create(element::i64, {}, {0})));
upper_indices = context.mark_node(std::make_shared<v1::Minimum>(
upper_indices, context.mark_node(std::make_shared<v1::Subtract>(
dim_size, v0::Constant::create(element::i64, {}, {1})))));

auto lower_values = context.mark_node(std::make_shared<v1::Gather>(sorted_tensor, lower_indices, dim));
auto upper_values = context.mark_node(std::make_shared<v1::Gather>(sorted_tensor, upper_indices, dim));

auto weights = context.mark_node(std::make_shared<v1::Subtract>(scaled_q, lower_indices));

auto result = context.mark_node(std::make_shared<v1::Add>(
lower_values, context.mark_node(std::make_shared<v1::Multiply>(weights, context.mark_node(std::make_shared<v1::Subtract>(upper_values, lower_values))))));

auto sorted = context.mark_node(std::make_shared<v0::Sort>(input, dim, true)); // Ascending order

auto dim_size = input.get_shape()[dim];

auto indices = context.mark_node(std::make_shared<v0::Multiply>(q, dim_size - 1));
auto lower_indices = context.mark_node(std::make_shared<v0::Floor>(indices));
auto upper_indices = context.mark_node(std::make_shared<v1::Add>(lower_indices, 1));
auto weights = context.mark_node(std::make_shared<v1::Subtract>(indices, lower_indices));
auto lower_values = context.mark_node(std::make_shared<v1::Gather>(sorted, lower_indices, dim));
auto upper_values = context.mark_node(std::make_shared<v1::Gather>(sorted, upper_indices, dim));

Output<Node> result;
if (interpolation == "linear") {
result = context.mark_node(std::make_shared<v1::Add>(
lower_values, context.mark_node(std::make_shared<v1::Multiply>(weights, upper_values))));
} else if (interpolation == "lower") {
result = lower_values;
} else if (interpolation == "higher") {
result = upper_values;
} else if (interpolation == "nearest") {
auto nearest_indices = context.mark_node(std::make_shared<v0::Round>(indices));
result = context.mark_node(std::make_shared<v1::Gather>(sorted, nearest_indices, dim));
} else if (interpolation == "midpoint") {
result = context.mark_node(std::make_shared<v1::Add>(
lower_values, context.mark_node(std::make_shared<v1::Multiply>(
context.mark_node(std::make_shared<v0::Constant>(element::f32, Shape{}, 0.5)),
context.mark_node(std::make_shared<v1::Subtract>(upper_values, lower_values))))));
} else {
throw std::runtime_error("Unsupported interpolation method: " + interpolation);
}
if (!keepdim) {
auto input_shape = context.mark_node(std::make_shared<v0::ShapeOf>(input));
auto output_shape = context.mark_node(std::make_shared<v1::Gather>(
input_shape,
context.mark_node(v0::Constant::create(element::i64, {1}, {dim})),
v0::Constant::create(element::i64, {}, {0})));
result = context.mark_node(std::make_shared<v0::Reshape>(result, output_shape, true));
auto reshape_dims = input.get_shape();
reshape_dims.erase(reshape_dims.begin() + dim);
result = context.mark_node(std::make_shared<v0::Reshape>(result, reshape_dims, true));
}

return {result};
Expand Down

0 comments on commit 156d2f0

Please sign in to comment.