Skip to content

Commit

Permalink
Created SoftmaxCrossEntropyLoss source files for ONNX FE
Browse files Browse the repository at this point in the history
  • Loading branch information
AJThePro99 committed Feb 1, 2025
1 parent 9380826 commit 451d0cb
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ namespace ov {
namespace frontend {
namespace onnx {
namespace ai_onnx{
namespace opset_12 {
OutputVector softmax_cross_entropy_loss(const Node& node);
} // namespace opset_12
namespace opset_13 {
OutputVector softmax_cross_entropy_loss(const Node& node);
namespace opset_12 {
OutputVector softmax_cross_entropy_loss(const Node& node);
} // namespace opset_12
namespace opset_13 {
OutputVector softmax_cross_entropy_loss(const Node& node);
} // namespace opset_13
} // namespace ai_onnx
} // namespace onnx
Expand Down
120 changes: 60 additions & 60 deletions src/frontends/onnx/frontend/src/op/softmax_crossentropy_loss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,93 +3,93 @@
//

#include "core/operator_set.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/log.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/negative.hpp"
#include "openvino/op/reduce_mean.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/softmax.hpp"
#include "softmax_cross_entropy_loss.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/divide.hpp"

namespace ov {
namespace frontend {
namespace onnx {
namespace {
// Shared lowering for SoftmaxCrossEntropyLoss (opsets 12 and 13).
// Decomposes the operator into Softmax -> Log -> Gather -> Negative with
// optional per-class weights and the "none" / "sum" / "mean" reductions
// described in the ONNX specification.
//
// \param node          ONNX node wrapper providing inputs and attributes.
// \param axis_default  Default value of the "axis" attribute when absent.
// \return              Single-element vector holding the loss output.
OutputVector impl_softmax_cross_entropy(const Node& node, int64_t axis_default) {
    const auto inputs = node.get_ov_inputs();

    const auto scores = inputs[0];
    const auto labels = inputs[1];

    // Optional per-class weights parameter (input 2). When present, each
    // sample's loss is scaled by the weight of its target class.
    const bool has_weights = inputs.size() > 2;
    std::shared_ptr<ov::Node> weights_gather = nullptr;

    if (has_weights) {
        const auto weights = inputs[2];
        const auto axis_for_weights = ov::op::v0::Constant::create(element::i64, {}, {0});
        weights_gather = std::make_shared<ov::op::v8::Gather>(weights, labels, axis_for_weights);
    }

    // Attributes: class axis and reduction mode ("none" | "sum" | "mean").
    const auto axis = node.get_attribute_value<int64_t>("axis", axis_default);
    const auto reduction = node.get_attribute_value<std::string>("reduction", "mean");

    // log(softmax(scores)) along the class axis.
    const auto softmax = std::make_shared<ov::op::v8::Softmax>(scores, axis);
    const auto log_softmax = std::make_shared<ov::op::v0::Log>(softmax);

    // Pick the log-probability of the target class for every sample.
    const auto axis_const = ov::op::v0::Constant::create(element::i64, {}, {axis});
    const auto gathered = std::make_shared<ov::op::v8::Gather>(log_softmax, labels, axis_const);

    // Negative log-likelihood, optionally weighted.
    std::shared_ptr<ov::Node> loss = std::make_shared<ov::op::v0::Negative>(gathered);

    if (has_weights) {
        loss = std::make_shared<ov::op::v1::Multiply>(loss, weights_gather);
    }

    // Applying reduction as described in
    // https://github.com/onnx/onnx/blob/main/docs/Changelog.md#softmaxcrossentropyloss-12
    // NOTE: the spec spells the attribute value in lowercase ("none"); the
    // previous comparison against "None" never matched any valid value, so the
    // guard was effectively dead (it only worked because neither inner branch
    // fired for reduction == "none").
    if (reduction != "none") {
        // Reduce over the axis corresponding to each sample.
        // Reducing over axis 0, assuming the loss tensor shape is [batch_size]
        // -- TODO(review): confirm for inputs with trailing D1..Dk dims.
        const auto reduce_axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});

        if (reduction == "mean") {
            if (has_weights) {
                // Weighted mean: sum(loss) / sum(gathered weights), per spec.
                auto loss_sum = std::make_shared<ov::op::v1::ReduceSum>(loss->output(0), reduce_axis, true);
                auto weight_sum = std::make_shared<ov::op::v1::ReduceSum>(weights_gather->output(0), reduce_axis, true);
                loss = std::make_shared<ov::op::v1::Divide>(loss_sum, weight_sum);
            } else {
                loss = std::make_shared<ov::op::v1::ReduceMean>(loss->output(0), reduce_axis, true);
            }
        } else if (reduction == "sum") {
            loss = std::make_shared<ov::op::v1::ReduceSum>(loss->output(0), reduce_axis, true);
        }
    }

    return {loss};
}
}  // namespace
namespace ai_onnx {
namespace opset_12 {
// SoftmaxCrossEntropyLoss-12: default class axis is 1.
// NOTE: definitions must use unqualified names here -- defining
// `ov::frontend::onnx::ai_onnx::opset_12::...` while already inside these
// namespaces is ill-formed C++.
OutputVector softmax_cross_entropy_loss(const Node& node) {
    return impl_softmax_cross_entropy(node, 1);
}
// Bound to opset 12 only; the opset-13 registration below covers 13 onwards.
// The previous OPSET_SINCE(12) overlapped with the OPSET_SINCE(13) entry.
ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_RANGE(12, 12), ai_onnx::opset_12::softmax_cross_entropy_loss);
}  // namespace opset_12
namespace opset_13 {
// SoftmaxCrossEntropyLoss-13: same decomposition, same default axis.
OutputVector softmax_cross_entropy_loss(const Node& node) {
    return impl_softmax_cross_entropy(node, 1);
}
ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_SINCE(13), ai_onnx::opset_13::softmax_cross_entropy_loss);
}  // namespace opset_13
}  // namespace ai_onnx
} // namespace onnx
} // namespace frontend
} // namespace ov

0 comments on commit 451d0cb

Please sign in to comment.