Skip to content

Commit

Permalink
Add RaggedToSparse Op (#49)
Browse files Browse the repository at this point in the history
Add RaggedToSparse Operation


Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
apaniukov and rkazants authored Mar 4, 2024
1 parent 1f863af commit 0517c49
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/ov_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ OPENVINO_CREATE_EXTENSIONS(
std::make_shared<ov::OpExtension<BytesToChars>>(),
std::make_shared<ov::OpExtension<CombineSegments>>(),
std::make_shared<ov::OpExtension<RaggedToDense>>(),
std::make_shared<ov::OpExtension<RaggedToSparse>>(),
std::make_shared<ov::OpExtension<VocabEncoder>>(),
std::make_shared<ov::OpExtension<VocabDecoder>>(),
std::make_shared<ov::OpExtension<CharsToBytes>>(),
Expand Down
47 changes: 47 additions & 0 deletions src/ragged_to_sparse.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <openvino/op/constant.hpp>

#include "ragged_to_sparse.hpp"
#include "utils.hpp"

using namespace ov;
using op::v0::Constant;

void RaggedToSparse::validate_and_infer_types() {
OPENVINO_ASSERT(get_input_size() == 2);

auto starts_type = this->get_input_element_type(0);
auto ends_type = this->get_input_element_type(1);

OPENVINO_ASSERT(starts_type == element::i32, "Expected an i32 starts tensor ragged representation.");
OPENVINO_ASSERT(ends_type == element::i32, "Expected an i32 starts tensor ragged representation.");
OPENVINO_ASSERT(get_input_partial_shape(0) == get_input_partial_shape(1), "starts and ends tensors should be the same shape.");

set_output_type(0, get_input_element_type(0), PartialShape({Dimension::dynamic(), 2}));
}


bool RaggedToSparse::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
auto begins = inputs[0].data<const int32_t>();
auto ends = inputs[1].data<const int32_t>();

const auto last_element_index = inputs[1].get_size() - 1;
const uint64_t num_elements = static_cast<uint64_t>(ends[last_element_index] - begins[0]);
outputs[0].set_shape(ov::Shape{num_elements, 2});

auto batch_size = inputs[0].get_size();

auto output = outputs[0].data<int32_t>();
size_t current_idx = 0;
for (size_t i = 0; i < batch_size; ++i) {
auto num_row_elements = ends[i] - begins[i];
for (size_t j = 0; j < num_row_elements; ++j) {
output[current_idx++] = i;
output[current_idx++] = j;
};
};
return true;
}
36 changes: 36 additions & 0 deletions src/ragged_to_sparse.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <openvino/op/op.hpp>

// Takes one ragged dimension (starts, ends) and produces 2D tensor of sparse coordinates [(row, col), ...]
class RaggedToSparse : public ov::op::Op {
public:
OPENVINO_OP("RaggedToSparse");

RaggedToSparse () = default;

RaggedToSparse(const ov::OutputVector& arguments) :
ov::op::Op(arguments) {
constructor_validate_and_infer_types();
}

void validate_and_infer_types() override;

std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override {
return std::make_shared<RaggedToSparse>(inputs);
}

bool visit_attributes(ov::AttributeVisitor& visitor) override {
return true;
}

bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override;

bool has_evaluate() const override {
return true;
}
};
3 changes: 2 additions & 1 deletion src/tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
#include "wordpiece_tokenizer.hpp"
#include "bpe_tokenizer.hpp"
#include "ragged_to_dense.hpp"
#include "vocab_encoder.hpp"
#include "ragged_to_sparse.hpp"
#include "vocab_decoder.hpp"
#include "vocab_encoder.hpp"
#include "chars_to_bytes.hpp"

#include "tensorflow_translators.hpp"

0 comments on commit 0517c49

Please sign in to comment.