Merge branch 'master' into logaddexp

itsbharatj authored Jan 20, 2025
2 parents 1b9f935 + d757efd commit a7c2f28
Showing 10 changed files with 112 additions and 49 deletions.
17 changes: 8 additions & 9 deletions samples/cpp/benchmark_app/main.cpp
@@ -523,9 +523,7 @@ int main(int argc, char* argv[]) {
}
}
auto result = std::find_if(config.begin(), config.end(), [&](const std::pair<std::string, ov::AnyMap>& item) {
- if (device_name.find(item.first) == 0)
-     return true;
- return false;
+ return device_name.find(item.first) == 0;
});
ov::AnyMap device_config = {};
if (result != config.end())
@@ -548,6 +546,11 @@ int main(int argc, char* argv[]) {
}

bool isDynamicNetwork = false;
+ auto areNetworkInputsDynamic = [](const benchmark_app::InputsInfo& input_info) {
+     return std::any_of(input_info.begin(), input_info.end(), [](const auto& info) {
+         return info.second.partialShape.is_dynamic();
+     });
+ };

if (FLAGS_load_from_file && !isNetworkCompiled) {
if (!FLAGS_mean_values.empty() || !FLAGS_scale_values.empty()) {
@@ -722,12 +725,7 @@ int main(int argc, char* argv[]) {
model = preproc.build();

// Check if network has dynamic shapes
- auto input_info = app_inputs_info[0];
- isDynamicNetwork = std::any_of(input_info.begin(),
-                                input_info.end(),
-                                [](const std::pair<std::string, benchmark_app::InputInfo>& i) {
-                                    return i.second.partialShape.is_dynamic();
-                                });
+ isDynamicNetwork = areNetworkInputsDynamic(app_inputs_info.at(0));

topology_name = model->get_friendly_name();

@@ -789,6 +787,7 @@ int main(int argc, char* argv[]) {
FLAGS_scale_values,
FLAGS_mean_values,
compiledModel.inputs());
+ isDynamicNetwork = areNetworkInputsDynamic(app_inputs_info.at(0));

batchSize = get_batch_size(app_inputs_info.at(0));
warn_if_no_batch(app_inputs_info.at(0));
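For context on the hunks above: the dynamic-shape check now lives in a reusable areNetworkInputsDynamic lambda, and the second assignment covers the compiled-model path, which previously appears to have left isDynamicNetwork at its false default. Below is a minimal standalone sketch of the same check (compiles against the OpenVINO runtime headers); the map type stands in for benchmark_app::InputsInfo, which pairs each input name with parsed info carrying a partialShape.

```cpp
// Sketch only: detect whether any model input has a dynamic partial shape.
#include <algorithm>
#include <iostream>
#include <map>
#include <string>

#include "openvino/core/partial_shape.hpp"

int main() {
    // Stand-in for benchmark_app::InputsInfo (name -> info with a partial shape).
    std::map<std::string, ov::PartialShape> inputs{
        {"data", ov::PartialShape{1, 3, 224, 224}},      // fully static
        {"ids", ov::PartialShape{ov::Dimension(), 128}}  // dynamic batch dimension
    };
    bool is_dynamic = std::any_of(inputs.begin(), inputs.end(), [](const auto& item) {
        return item.second.is_dynamic();
    });
    std::cout << std::boolalpha << is_dynamic << std::endl;  // true
}
```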
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op/index_put_.cpp
@@ -10,7 +10,7 @@ namespace frontend {
namespace pytorch {
namespace op {

- OutputVector translate_index_put_(const NodeContext& context) {
+ OutputVector translate_index_put(const NodeContext& context) {
// Pass as PtFrameworkNode to register as `inplace_op`. Conversion to OV operators is done as transformation.
auto node = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
return {context.mark_node(node)};
8 changes: 6 additions & 2 deletions src/frontends/pytorch/src/op/log.cpp
@@ -77,16 +77,20 @@ OutputVector translate_log10(const NodeContext& context) {
};

OutputVector translate_logsumexp(const NodeContext& context) {
- num_inputs_check(context, 1, 2);
+ num_inputs_check(context, 1, 3);
auto input = context.get_input(0);
ov::Output<ov::Node> dim;
if (!context.input_is_none(1)) {
dim = context.get_input(1);
} else {
dim = context.mark_node(get_axes_range(context, 0));
}
+ bool keepdim = false;
+ if (!context.input_is_none(2)) {
+     keepdim = context.const_input<bool>(2);
+ }
auto exp = context.mark_node(std::make_shared<v0::Exp>(input));
- auto sum = context.mark_node(std::make_shared<v1::ReduceSum>(exp, dim, false));
+ auto sum = context.mark_node(std::make_shared<v1::ReduceSum>(exp, dim, keepdim));
auto log = context.mark_node(std::make_shared<v0::Log>(sum));
return {log};
};
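The converter decomposes aten::logsumexp with the identity logsumexp(x, dim, keepdim) = log(sum(exp(x), dim, keepdim)): Exp, then ReduceSum (now honoring the optional keepdim argument at input 2), then Log. A plain-C++ numeric sketch of the identity over one axis, with illustrative values:

```cpp
// Sketch only: logsumexp(x) == log(sum(exp(x))) over the reduced elements.
#include <cmath>
#include <iostream>
#include <vector>

int main() {
    const std::vector<double> x{0.5, 1.5, -2.0};
    double sum_exp = 0.0;
    for (double v : x) {
        sum_exp += std::exp(v);  // Exp followed by ReduceSum
    }
    std::cout << std::log(sum_exp) << std::endl;  // Log of the sum, ~1.835
}
```

Note that this direct Exp/ReduceSum/Log form can overflow for large inputs; PyTorch's eager kernel subtracts the per-slice maximum first, while the conversion above maps the operator one-to-one onto the graph and leaves numerics to the plugin.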
8 changes: 6 additions & 2 deletions src/frontends/pytorch/src/op_table.cpp
@@ -116,7 +116,7 @@ OP_CONVERTER(translate_index);
OP_CONVERTER(translate_index_add);
OP_CONVERTER(translate_index_copy_);
OP_CONVERTER(translate_index_fill_);
- OP_CONVERTER(translate_index_put_);
+ OP_CONVERTER(translate_index_put);
OP_CONVERTER(translate_index_select);
OP_CONVERTER(translate_instance_norm);
OP_CONVERTER(translate_int);
@@ -433,6 +433,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::col2im", op::translate_col2im},
{"aten::complex", op::translate_complex},
{"aten::concat", op::translate_cat},
{"aten::concatenate", op::translate_cat},
{"aten::contiguous", op::skip_node}, // In openvino how tensors are stored in memory is internal plugin detail,
// we assume all tensors are contiguous
{"aten::conv_transpose1d", op::translate_conv_transposend},
@@ -465,6 +466,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::empty", op::translate_empty},
{"aten::empty_like", op::translate_empty_like},
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten::equal", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten::erf", op::translate_erf},
{"aten::erfc", op::translate_erfc},
{"aten::exp", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>, 1>},
@@ -508,7 +510,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
// aten::index - Supported in limited set of patterns
{"aten::index_copy_", op::inplace_op<op::translate_index_copy_>},
{"aten::index_fill_", op::inplace_op<op::translate_index_fill_>},
{"aten::index_put_", op::inplace_op<op::translate_index_put_>},
{"aten::index_put", op::translate_index_put},
{"aten::index_add", op::translate_index_add},
{"aten::index_select", op::translate_index_select},
{"aten::instance_norm", op::translate_instance_norm},
@@ -552,6 +554,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::log2_", op::inplace_op<op::translate_log2>},
{"aten::log10", op::optional_out<op::translate_log10, 1>},
{"aten::log10_", op::inplace_op<op::translate_log10>},
{"aten::logsumexp", op::translate_logsumexp},
{"aten::lstm", op::translate_lstm},
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
{"aten::masked_fill", op::translate_masked_fill},
@@ -716,6 +719,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"ov_ext::embedding", op::translate_embedding_ext},
{"ov_ext::conv1d", op::translate_conv1d_ext},
{"ov_ext::linear", op::translate_linear},
{"prim::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
{"prim::Constant", op::translate_constant},
{"prim::device", op::translate_constant},
// prim::DictConstruct - Supported in limited set of patterns
5 changes: 2 additions & 3 deletions src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp
@@ -54,9 +54,8 @@ void check_level_zero_attributes_match(const IODescriptor& ioDescriptor, const A
'\n' + "Given: " + std::to_string(ovDimensions.size()));

for (size_t index = 0; index < ovDimensions.size(); ++index) {
- OPENVINO_ASSERT(
-     ioDescriptor.shapeFromCompiler.is_dynamic() || ovDimensions[index] == zeDescriptor.info.dims[index],
-     "Shape mismatch for input/output named " + ioDescriptor.nameFromCompiler);
+ OPENVINO_ASSERT(ovDimensions[index] == zeDescriptor.info.dims[index],
+                 "Shape mismatch for input/output named " + ioDescriptor.nameFromCompiler);
}
for (size_t index = ovDimensions.size(); index < ZE_MAX_GRAPH_ARGUMENT_DIMENSIONS_SIZE; ++index) {
OPENVINO_ASSERT(zeDescriptor.info.dims[index] == 0 || zeDescriptor.info.dims[index] == 1,
18 changes: 15 additions & 3 deletions
@@ -12,7 +12,9 @@
#include "intel_npu/utils/zero/zero_api.hpp"
#include "intel_npu/utils/zero/zero_result.hpp"
#include "intel_npu/utils/zero/zero_wrappers.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/partial_shape.hpp"

#define NotSupportQuery(T) (T <= ZE_GRAPH_EXT_VERSION_1_2)

@@ -400,7 +402,8 @@ ze_graph_handle_t ZeGraphExtWrappers::getGraphHandle(const std::vector<uint8_t>&
static IODescriptor getIODescriptor(const ze_graph_argument_properties_3_t& arg,
const std::optional<ze_graph_argument_metadata_t>& metadata) {
ov::element::Type_t precision = toOVElementType(arg.devicePrecision);
- ov::Shape shapeFromCompiler, shapeFromIRModel;
+ ov::Shape shapeFromCompiler;
+ ov::PartialShape shapeFromIRModel;
std::unordered_set<std::string> outputTensorNames;

for (uint32_t id = 0; id < arg.associated_tensor_names_count; id++) {
@@ -410,8 +413,17 @@ static IODescriptor getIODescriptor(const ze_graph_argument_properties_3_t& arg,
shapeFromCompiler.push_back(arg.dims[id]);
}
if (metadata.has_value()) {
+ const auto dynamicDim = std::numeric_limits<uint64_t>::max();
+ shapeFromIRModel.reserve(metadata->shape_size);
for (uint32_t id = 0; id < metadata->shape_size; id++) {
- shapeFromIRModel.push_back(metadata->shape[id]);
+ if (metadata->shape[id] != dynamicDim) {
+     shapeFromIRModel.push_back(metadata->shape[id]);
+ } else {
+     // lower bound is ignored, so we set it to 1 just to satisfy the Dimension constructor,
+     // upper bound is set to the value from shapeFromCompiler as it is filled with upper bounds
+     // in case of dynamic dimensions
+     shapeFromIRModel.push_back(ov::Dimension(1, shapeFromCompiler[id]));
+ }
}
}

@@ -433,7 +445,7 @@ static IODescriptor getIODescriptor(const ze_graph_argument_properties_3_t& arg,

return {std::move(nameFromCompiler),
precision,
- std::move(shapeFromCompiler),
+ shapeFromCompiler,
isStateInput,
isStateOutput,
isShapeTensor,
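The metadata loop above treats a dimension equal to the maximum uint64_t value as the compiler's dynamic-dimension sentinel and rebuilds it as an interval ov::Dimension whose upper bound comes from shapeFromCompiler. A self-contained sketch of that reconstruction (compiles against the OpenVINO runtime headers; the sample dims and bounds are illustrative, not taken from the commit):

```cpp
// Sketch only: rebuild a PartialShape from static dims plus a dynamic-dim sentinel.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

#include "openvino/core/dimension.hpp"
#include "openvino/core/partial_shape.hpp"

int main() {
    const uint64_t dynamicDim = std::numeric_limits<uint64_t>::max();
    const std::vector<uint64_t> fromMetadata{dynamicDim, 3, 224, 224};  // IR says {?, 3, 224, 224}
    const std::vector<uint64_t> upperBounds{8, 3, 224, 224};            // compiler-reported upper bounds

    ov::PartialShape shape;
    for (size_t i = 0; i < fromMetadata.size(); ++i) {
        if (fromMetadata[i] != dynamicDim) {
            shape.push_back(ov::Dimension(static_cast<int64_t>(fromMetadata[i])));
        } else {
            // Lower bound 1 is only a placeholder; the upper bound carries the information.
            shape.push_back(ov::Dimension(1, static_cast<int64_t>(upperBounds[i])));
        }
    }
    std::cout << shape << std::endl;  // prints a bounded shape, e.g. {1..8,3,224,224}
}
```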
26 changes: 16 additions & 10 deletions src/plugins/intel_npu/tools/single-image-test/main.cpp
@@ -1569,8 +1569,8 @@ std::pair<TensorMap, ProfVec> runInfer(ov::InferRequest& inferRequest, ov::Compi

TensorMap out;
for (const auto& outputInfo : compiledModel.outputs()) {
- const std::string layer_name = outputInfo.get_any_name();
- out.insert({layer_name, inferRequest.get_tensor(layer_name)});
+ const std::string layerName = outputInfo.get_any_name();
+ out.insert({layerName, inferRequest.get_tensor(layerName)});
}

ProfVec profData{};
@@ -1807,11 +1807,17 @@ bool testMeanIoU(const TensorMap& outputs, const TensorMap& references, const La
}

static ov::Shape parseDataShape(const std::string& dataShapeStr) {
- std::vector<size_t> dataShape;
- std::istringstream ss(dataShapeStr);
- std::string token;
- while (std::getline(ss, token, ',')) {
-     dataShape.push_back(std::stoul(token));
+ std::vector<uint64_t> dataShape;
+ std::stringstream ss(dataShapeStr);
+
+ char ch; // To discard non-numeric characters
+ int64_t dim;
+ while (ss >> ch) {
+     if (std::isdigit(ch)) {
+         ss.putback(ch);
+         ss >> dim;
+         dataShape.push_back(dim);
+     }
}
return ov::Shape(dataShape);
}
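Extracted as a standalone helper, the rewritten parser scans for digit runs and discards everything else, so decorated inputs such as "[1,3,224,224]" parse the same as "1,3,224,224"; the old getline/stoul loop would throw std::invalid_argument on the leading '['. In the sketch below, parseDims is a hypothetical name for the extraction, and std::isdigit gets an unsigned char cast that the committed code omits:

```cpp
// Sketch only: the digit-driven scan from parseDataShape, lifted into a free function.
#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

static std::vector<uint64_t> parseDims(const std::string& s) {
    std::vector<uint64_t> dims;
    std::stringstream ss(s);
    char ch;      // consumes separators, brackets, spaces...
    int64_t dim;  // ...while digits are pushed back and re-read as numbers
    while (ss >> ch) {
        if (std::isdigit(static_cast<unsigned char>(ch))) {
            ss.putback(ch);
            ss >> dim;
            dims.push_back(dim);
        }
    }
    return dims;
}

int main() {
    for (uint64_t d : parseDims("[1,3,224,224]")) {
        std::cout << d << " ";  // 1 3 224 224
    }
    std::cout << std::endl;
}
```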
@@ -1906,11 +1912,11 @@ static int runSingleImageTest() {
auto model = core.read_model(FLAGS_network);
nameIOTensors(model);

- auto inputs_info = std::const_pointer_cast<ov::Model>(model)->inputs();
- InputsInfo info_map;
+ auto inputsInfo = std::const_pointer_cast<ov::Model>(model)->inputs();
+ InputsInfo infoMap;

std::cout << "Performing reshape" << std::endl;
- reshape(std::move(inputs_info), info_map, model, FLAGS_shape,
+ reshape(std::move(inputsInfo), infoMap, model, FLAGS_shape,
FLAGS_override_model_batch_size, FLAGS_device);

ov::preprocess::PrePostProcessor ppp(model);
34 changes: 34 additions & 0 deletions tests/layer_tests/pytorch_tests/test_logsumexp.py
@@ -0,0 +1,34 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class aten_logsumexp(torch.nn.Module):
def __init__(self, dim, keepdim) -> None:
super().__init__()
self.dim = dim
self.keepdim = keepdim

def forward(self, input_tensor):
return torch.logsumexp(input_tensor, dim=self.dim, keepdim=self.keepdim)


class TestLogsumexp(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(2, 5, 9, 7),)

@pytest.mark.parametrize("dim", [
0, 1, 2, 3, -1, -2, -3, -4
])
@pytest.mark.parametrize("keepdim", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_fx_backend
def test_logsumexp(self, dim, keepdim, ie_device, precision, ir_version):
self._test(aten_logsumexp(dim, keepdim), None, "aten::logsumexp",
ie_device, precision, ir_version)
27 changes: 24 additions & 3 deletions tests/layer_tests/pytorch_tests/test_unary_ops.py
@@ -75,7 +75,7 @@

class unary_op_net(torch.nn.Module):
def __init__(self, op, dtype):
- super(unary_op_net, self).__init__()
+ super().__init__()
self.dtype = dtype
self.op = op

@@ -87,7 +87,7 @@ def forward(self, x):

class unary_op_out_net(torch.nn.Module):
def __init__(self, op, dtype):
- super(unary_op_out_net, self).__init__()
+ super().__init__()
self.dtype = dtype
self.op = op

@@ -101,7 +101,7 @@ def forward(self, x):

class unary_func_op_inplace_net(torch.nn.Module):
def __init__(self, op, dtype):
- super(unary_func_op_inplace_net, self).__init__()
+ super().__init__()
self.dtype = dtype
self.op = op

@@ -111,6 +111,17 @@ def forward(self, x):
return y, x1


+ class prim_abs_net(torch.nn.Module):
+     def __init__(self, dtype):
+         super().__init__()
+         self.dtype = dtype
+
+     def forward(self, x):
+         x1 = x.to(self.dtype)
+         y = abs(x1)
+         return y, x1


class TestUnaryOp(PytorchLayerTest):
def _prepare_input(self):
# random number in range [1, 11)
@@ -265,3 +276,13 @@ def test_unary_func_op_inplace(self, op_type, dtype, ie_device, precision, ir_ve
self.dtype = dtype
self._test(unary_func_op_inplace_net(OPS[op_type], dtype), None, op_type + "_",
ie_device, precision, ir_version)

+     @pytest.mark.nightly
+     @pytest.mark.precommit
+     @pytest.mark.precommit_torch_export
+     @pytest.mark.precommit_fx_backend
+     @pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.int8, torch.uint8, torch.int32, torch.int64])
+     def test_prim_abs(self, dtype, ie_device, precision, ir_version):
+         self.dtype = dtype
+         self._test(prim_abs_net(dtype), None, "prim::abs",
+                    ie_device, precision, ir_version)
16 changes: 0 additions & 16 deletions tools/constraints.txt

This file was deleted.
