From a7fe2a2b8ac18839e825bba90326ebb8933eba7a Mon Sep 17 00:00:00 2001 From: Pawel Raasz Date: Fri, 31 Jan 2025 16:27:13 +0100 Subject: [PATCH] [core] Improve ov::element::Type to support constexpr where possible (#28643) ### Details: - Make Type class members `constexpr` if possible - Make function `ov::element::from` constexpr instead templtae specialization which allowe use it in compile time - Make global element types as `inline` to use single address - Refactor internals of element Type class - Deprecate function `Type fundamental_type_for(const Type& type);` as not used in project and there element traits can be used as alternative - Reduce code maintenance as add new types requires update less places in code. - This changes have positive impact on binary size of most of OV libraries (negative value means reduced size) | Library | Diff [KiB] | |------------|------------| | OV | -12.825 | | CPU | -0.709 | | NPU | -1.188 | | IR FE | -0.114 | | JAX FE | 0.013 | | ONNX FE | -1.954 | | Paddle FE | -0.907 | | pytorch FE | -0.947 | | TF FE | -2.313 | | TF LITE FE | -0.738 | ### Tickets: - CVS-160757 --------- Signed-off-by: Raasz, Pawel Co-authored-by: Tomasz Jankowski --- .../include/openvino/util/common_util.hpp | 7 + .../openvino/core/type/element_type.hpp | 150 +++--- src/core/src/type/element_type.cpp | 462 +++++++----------- src/inference/src/dev/core_impl.cpp | 12 - .../shared_test_classes/base/utils/ranges.hpp | 10 +- .../include/common_test_utils/type_ranges.hpp | 19 +- 6 files changed, 276 insertions(+), 384 deletions(-) diff --git a/src/common/util/include/openvino/util/common_util.hpp b/src/common/util/include/openvino/util/common_util.hpp index a11adf29cd14f1..15ec5d8f27d588 100644 --- a/src/common/util/include/openvino/util/common_util.hpp +++ b/src/common/util/include/openvino/util/common_util.hpp @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include #include @@ -173,5 +174,11 @@ inline void erase_if(Container& data, const PredicateT& predicate) { std::string filter_lines_by_prefix(const std::string& str, const std::string& prefix); +template +constexpr std::array, std::common_type_t, T>, sizeof...(Args)> make_array( + Args&&... args) { + return {std::forward(args)...}; +} + } // namespace util } // namespace ov diff --git a/src/core/include/openvino/core/type/element_type.hpp b/src/core/include/openvino/core/type/element_type.hpp index 960b318b81262c..b454d886107e7c 100644 --- a/src/core/include/openvino/core/type/element_type.hpp +++ b/src/core/include/openvino/core/type/element_type.hpp @@ -70,11 +70,11 @@ enum class Type_t { /// \ingroup ov_element_cpp_api class OPENVINO_API Type { public: - Type() = default; - Type(const Type&) = default; + constexpr Type() = default; + constexpr Type(const Type&) = default; constexpr Type(const Type_t t) : m_type{t} {} explicit Type(const std::string& type); - Type& operator=(const Type&) = default; + constexpr Type& operator=(const Type&) = default; std::string c_type_string() const; size_t size() const; size_t hash() const; @@ -95,6 +95,8 @@ class OPENVINO_API Type { // The name of this type, the enum name of this type std::string get_type_name() const; friend OPENVINO_API std::ostream& operator<<(std::ostream&, const Type&); + + OPENVINO_DEPRECATED("This function is deprecated and will be removed in 2026.0.") static std::vector get_known_types(); /// \brief Checks whether this element type is merge-compatible with `t`. @@ -137,129 +139,131 @@ using TypeVector = std::vector; /// \brief undefined element type /// \ingroup ov_element_cpp_api -constexpr Type undefined(Type_t::undefined); +inline constexpr Type undefined(Type_t::undefined); /// \brief dynamic element type /// \ingroup ov_element_cpp_api -constexpr Type dynamic(Type_t::dynamic); +inline constexpr Type dynamic(Type_t::dynamic); /// \brief boolean element type /// \ingroup ov_element_cpp_api -constexpr Type boolean(Type_t::boolean); +inline constexpr Type boolean(Type_t::boolean); /// \brief bf16 element type /// \ingroup ov_element_cpp_api -constexpr Type bf16(Type_t::bf16); +inline constexpr Type bf16(Type_t::bf16); /// \brief f16 element type /// \ingroup ov_element_cpp_api -constexpr Type f16(Type_t::f16); +inline constexpr Type f16(Type_t::f16); /// \brief f32 element type /// \ingroup ov_element_cpp_api -constexpr Type f32(Type_t::f32); +inline constexpr Type f32(Type_t::f32); /// \brief f64 element type /// \ingroup ov_element_cpp_api -constexpr Type f64(Type_t::f64); +inline constexpr Type f64(Type_t::f64); /// \brief i4 element type /// \ingroup ov_element_cpp_api -constexpr Type i4(Type_t::i4); +inline constexpr Type i4(Type_t::i4); /// \brief i8 element type /// \ingroup ov_element_cpp_api -constexpr Type i8(Type_t::i8); +inline constexpr Type i8(Type_t::i8); /// \brief i16 element type /// \ingroup ov_element_cpp_api -constexpr Type i16(Type_t::i16); +inline constexpr Type i16(Type_t::i16); /// \brief i32 element type /// \ingroup ov_element_cpp_api -constexpr Type i32(Type_t::i32); +inline constexpr Type i32(Type_t::i32); /// \brief i64 element type /// \ingroup ov_element_cpp_api -constexpr Type i64(Type_t::i64); +inline constexpr Type i64(Type_t::i64); /// \brief binary element type /// \ingroup ov_element_cpp_api -constexpr Type u1(Type_t::u1); +inline constexpr Type u1(Type_t::u1); /// \brief u2 element type /// \ingroup ov_element_cpp_api -constexpr Type u2(Type_t::u2); +inline constexpr Type u2(Type_t::u2); /// \brief u3 element type /// \ingroup ov_element_cpp_api -constexpr Type u3(Type_t::u3); +inline constexpr Type u3(Type_t::u3); /// \brief u4 element type /// \ingroup ov_element_cpp_api -constexpr Type u4(Type_t::u4); +inline constexpr Type u4(Type_t::u4); /// \brief u6 element type /// \ingroup ov_element_cpp_api -constexpr Type u6(Type_t::u6); +inline constexpr Type u6(Type_t::u6); /// \brief u8 element type /// \ingroup ov_element_cpp_api -constexpr Type u8(Type_t::u8); +inline constexpr Type u8(Type_t::u8); /// \brief u16 element type /// \ingroup ov_element_cpp_api -constexpr Type u16(Type_t::u16); +inline constexpr Type u16(Type_t::u16); /// \brief u32 element type /// \ingroup ov_element_cpp_api -constexpr Type u32(Type_t::u32); +inline constexpr Type u32(Type_t::u32); /// \brief u64 element type /// \ingroup ov_element_cpp_api -constexpr Type u64(Type_t::u64); +inline constexpr Type u64(Type_t::u64); /// \brief nf4 element type /// \ingroup ov_element_cpp_api -constexpr Type nf4(Type_t::nf4); +inline constexpr Type nf4(Type_t::nf4); /// \brief f8e4m3 element type /// \ingroup ov_element_cpp_api -constexpr Type f8e4m3(Type_t::f8e4m3); +inline constexpr Type f8e4m3(Type_t::f8e4m3); /// \brief f8e4m3 element type /// \ingroup ov_element_cpp_api -constexpr Type f8e5m2(Type_t::f8e5m2); +inline constexpr Type f8e5m2(Type_t::f8e5m2); /// \brief string element type /// \ingroup ov_element_cpp_api -constexpr Type string(Type_t::string); +inline constexpr Type string(Type_t::string); /// \brief f4e2m1 element type /// \ingroup ov_element_cpp_api -constexpr Type f4e2m1(Type_t::f4e2m1); +inline constexpr Type f4e2m1(Type_t::f4e2m1); /// \brief f8e8m0 element type /// \ingroup ov_element_cpp_api -constexpr Type f8e8m0(Type_t::f8e8m0); +inline constexpr Type f8e8m0(Type_t::f8e8m0); -template -Type from() { - OPENVINO_THROW("Unknown type"); +template +constexpr Type from() { + if constexpr (std::is_same_v || std::is_same_v) { + return boolean; + } else if constexpr (std::is_same_v) { + return f16; + } else if constexpr (std::is_same_v) { + return f32; + } else if constexpr (std::is_same_v) { + return f64; + } else if constexpr (std::is_same_v) { + return i8; + } else if constexpr (std::is_same_v) { + return i16; + } else if constexpr (std::is_same_v) { + return i32; + } else if constexpr (std::is_same_v) { + return i64; + } else if constexpr (std::is_same_v) { + return u8; + } else if constexpr (std::is_same_v) { + return u16; + } else if constexpr (std::is_same_v) { + return u32; + } else if constexpr (std::is_same_v) { + return u64; + } else if constexpr (std::is_same_v) { + return bf16; + } else if constexpr (std::is_same_v) { + return f8e4m3; + } else if constexpr (std::is_same_v) { + return f8e5m2; + } else if constexpr (std::is_same_v) { + return string; + } else if constexpr (std::is_same_v) { + return f4e2m1; + } else if constexpr (std::is_same_v) { + return f8e8m0; + } else { + OPENVINO_THROW("Unknown type"); + } } -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); -template <> -OPENVINO_API Type from(); +OPENVINO_DEPRECATED( + "This function is deprecated and will be removed in 2026.0. Use ov::fundamental_type_for instead") OPENVINO_API Type fundamental_type_for(const Type& type); OPENVINO_API @@ -281,12 +285,12 @@ template <> class OPENVINO_API AttributeAdapter : public ValueAccessor { public: OPENVINO_RTTI("AttributeAdapter"); - AttributeAdapter(ov::element::Type& value) : m_ref(value) {} + constexpr AttributeAdapter(ov::element::Type& value) : m_ref(value) {} const std::string& get() override; void set(const std::string& value) override; - operator ov::element::Type&() { + constexpr operator ov::element::Type&() { return m_ref; } diff --git a/src/core/src/type/element_type.cpp b/src/core/src/type/element_type.cpp index 3fdda4d7f55cf8..bd61d1c985ea97 100644 --- a/src/core/src/type/element_type.cpp +++ b/src/core/src/type/element_type.cpp @@ -7,11 +7,21 @@ #include #include #include +#include #include #include "openvino/core/type/element_type_traits.hpp" +#include "openvino/util/common_util.hpp" +namespace ov::element { namespace { +constexpr size_t idx(Type_t e) noexcept { + return static_cast>(e); +} + +// Update it when new type is added +constexpr size_t enum_types_size = idx(f8e8m0) + 1; + struct TypeInfo { size_t m_bitwidth; bool m_is_real; @@ -19,252 +29,179 @@ struct TypeInfo { bool m_is_quantized; const char* m_cname; const char* m_type_name; -}; - -struct ElementTypes { - struct TypeHash { - size_t operator()(ov::element::Type_t t) const { - return static_cast(t); + const char* const* aliases; + size_t alias_count; + + bool has_name(const std::string& type) const { + if (type == m_type_name) { + return true; + } else { + const auto last = aliases + alias_count; + return std::find(aliases, last, type) != last; } - }; - - using ElementsMap = std::unordered_map; -}; + } -inline TypeInfo get_type_info(ov::element::Type_t type) { - switch (type) { - case ov::element::Type_t::undefined: - return {std::numeric_limits::max(), false, false, false, "undefined", "undefined"}; - case ov::element::Type_t::dynamic: - return {0, false, false, false, "dynamic", "dynamic"}; - case ov::element::Type_t::boolean: - return {8, false, true, false, "char", "boolean"}; - case ov::element::Type_t::bf16: - return {16, true, true, false, "bfloat16", "bf16"}; - case ov::element::Type_t::f16: - return {16, true, true, false, "float16", "f16"}; - case ov::element::Type_t::f32: - return {32, true, true, false, "float", "f32"}; - case ov::element::Type_t::f64: - return {64, true, true, false, "double", "f64"}; - case ov::element::Type_t::i4: - return {4, false, true, true, "int4_t", "i4"}; - case ov::element::Type_t::i8: - return {8, false, true, true, "int8_t", "i8"}; - case ov::element::Type_t::i16: - return {16, false, true, false, "int16_t", "i16"}; - case ov::element::Type_t::i32: - return {32, false, true, true, "int32_t", "i32"}; - case ov::element::Type_t::i64: - return {64, false, true, false, "int64_t", "i64"}; - case ov::element::Type_t::u1: - return {1, false, false, false, "uint1_t", "u1"}; - case ov::element::Type_t::u2: - return {2, false, false, false, "uint2_t", "u2"}; - case ov::element::Type_t::u3: - return {3, false, false, false, "uint3_t", "u3"}; - case ov::element::Type_t::u4: - return {4, false, false, false, "uint4_t", "u4"}; - case ov::element::Type_t::u6: - return {6, false, false, false, "uint6_t", "u6"}; - case ov::element::Type_t::u8: - return {8, false, false, true, "uint8_t", "u8"}; - case ov::element::Type_t::u16: - return {16, false, false, false, "uint16_t", "u16"}; - case ov::element::Type_t::u32: - return {32, false, false, false, "uint32_t", "u32"}; - case ov::element::Type_t::u64: - return {64, false, false, false, "uint64_t", "u64"}; - case ov::element::Type_t::nf4: - return {4, false, false, true, "nfloat4", "nf4"}; - case ov::element::Type_t::f8e4m3: - return {8, true, true, true, "f8e4m3", "f8e4m3"}; - case ov::element::Type_t::f8e5m2: - return {8, true, true, true, "f8e5m2", "f8e5m2"}; - case ov::element::Type_t::string: - return {8 * sizeof(std::string), false, false, false, "string", "string"}; - case ov::element::Type_t::f4e2m1: - return {4, true, true, true, "f4e2m1", "f4e2m1"}; - case ov::element::Type_t::f8e8m0: - return {8, true, true, true, "f8e8m0", "f8e8m0"}; - default: - OPENVINO_THROW("ov::element::Type_t not supported: ", type); + constexpr bool is_valid() const { + return m_cname != nullptr && m_type_name != nullptr; } }; +; + +constexpr TypeInfo type_info(size_t bitwidth, + bool is_real, + bool is_signed, + bool is_quantized, + const char* cname, + const char* type_name) { + return {bitwidth, is_real, is_signed, is_quantized, cname, type_name, nullptr, 0}; +} + +template +constexpr TypeInfo type_info(size_t bitwidth, + bool is_real, + bool is_signed, + bool is_quantized, + const char* cname, + const char* type_name, + const Array& aliases) { + return {bitwidth, is_real, is_signed, is_quantized, cname, type_name, aliases.data(), aliases.size()}; +} + +constexpr auto undefined_aliases = util::make_array("UNSPECIFIED"); +constexpr auto boolean_aliases = util::make_array("BOOL"); +constexpr auto bf16_aliases = util::make_array("BF16"); +constexpr auto f16_aliases = util::make_array("FP16"); +constexpr auto f32_aliases = util::make_array("FP32"); +constexpr auto f64_aliases = util::make_array("FP64"); +constexpr auto i4_aliases = util::make_array("I4"); +constexpr auto i8_aliases = util::make_array("I8"); +constexpr auto i16_aliases = util::make_array("I16"); +constexpr auto i32_aliases = util::make_array("I32"); +constexpr auto i64_aliases = util::make_array("I64"); +constexpr auto u1_aliases = util::make_array("U1", "bin", "BIN"); +constexpr auto u2_aliases = util::make_array("U2"); +constexpr auto u3_aliases = util::make_array("U3"); +constexpr auto u4_aliases = util::make_array("U4"); +constexpr auto u6_aliases = util::make_array("U6"); +constexpr auto u8_aliases = util::make_array("U8"); +constexpr auto u16_aliases = util::make_array("U16"); +constexpr auto u32_aliases = util::make_array("U32"); +constexpr auto u64_aliases = util::make_array("U64"); +constexpr auto nf4_aliases = util::make_array("NF4"); +constexpr auto f8e4m3_aliases = util::make_array("F8E4M3"); +constexpr auto f8e5m2_aliases = util::make_array("F8E5M2"); +constexpr auto string_aliases = util::make_array("STRING"); +constexpr auto f4e2m1_aliases = util::make_array("F4E2M1"); +constexpr auto f8e8m0_aliases = util::make_array("F8E8M0"); + +static constexpr std::array types_info = { + type_info(std::numeric_limits::max(), + false, + false, + false, + "undefined", + "undefined", + undefined_aliases), // undefined + type_info(0, false, false, false, "dynamic", "dynamic"), // dynamic + type_info(8, false, true, false, "char", "boolean", boolean_aliases), // boolean + type_info(16, true, true, false, "bfloat16", "bf16", bf16_aliases), // bf16 + type_info(16, true, true, false, "float16", "f16", f16_aliases), // f16 + type_info(32, true, true, false, "float", "f32", f32_aliases), // f32 + type_info(64, true, true, false, "double", "f64", f64_aliases), // f64 + type_info(4, false, true, true, "int4_t", "i4", i4_aliases), // i4 + type_info(8, false, true, true, "int8_t", "i8", i8_aliases), // i8 + type_info(16, false, true, false, "int16_t", "i16", i16_aliases), // i16 + type_info(32, false, true, true, "int32_t", "i32", i32_aliases), // i32 + type_info(64, false, true, false, "int64_t", "i64", i64_aliases), // i64 + type_info(1, false, false, false, "uint1_t", "u1", u1_aliases), // u1 + type_info(2, false, false, false, "uint2_t", "u2", u2_aliases), // u2 + type_info(3, false, false, false, "uint3_t", "u3", u3_aliases), // u3 + type_info(4, false, false, false, "uint4_t", "u4", u4_aliases), // u4 + type_info(6, false, false, false, "uint6_t", "u6", u6_aliases), // u6 + type_info(8, false, false, true, "uint8_t", "u8", u8_aliases), // u8 + type_info(16, false, false, false, "uint16_t", "u16", u16_aliases), // u16 + type_info(32, false, false, false, "uint32_t", "u32", u32_aliases), // u32 + type_info(64, false, false, false, "uint64_t", "u64", u64_aliases), // u64 + type_info(4, false, false, true, "nfloat4", "nf4", nf4_aliases), // nf4 + type_info(8, true, true, true, "f8e4m3", "f8e4m3", f8e4m3_aliases), // f8e4m3 + type_info(8, true, true, true, "f8e5m2", "f8e5m2", f8e5m2_aliases), // f8e5m2 + type_info(8 * sizeof(std::string), false, false, false, "string", "string", string_aliases), // string + type_info(4, true, true, true, "f4e2m1", "f4e2m1", f4e2m1_aliases), // f4e2m1 + type_info(8, true, true, true, "f8e8m0", "f8e8m0", f8e8m0_aliases) // f8e8m0 +}; -ov::element::Type type_from_string(const std::string& type) { - if (type == "f16" || type == "FP16") { - return ::ov::element::Type(::ov::element::Type_t::f16); - } else if (type == "f32" || type == "FP32") { - return ::ov::element::Type(::ov::element::Type_t::f32); - } else if (type == "bf16" || type == "BF16") { - return ::ov::element::Type(::ov::element::Type_t::bf16); - } else if (type == "f64" || type == "FP64") { - return ::ov::element::Type(::ov::element::Type_t::f64); - } else if (type == "i4" || type == "I4") { - return ::ov::element::Type(::ov::element::Type_t::i4); - } else if (type == "i8" || type == "I8") { - return ::ov::element::Type(::ov::element::Type_t::i8); - } else if (type == "i16" || type == "I16") { - return ::ov::element::Type(::ov::element::Type_t::i16); - } else if (type == "i32" || type == "I32") { - return ::ov::element::Type(::ov::element::Type_t::i32); - } else if (type == "i64" || type == "I64") { - return ::ov::element::Type(::ov::element::Type_t::i64); - } else if (type == "u1" || type == "U1" || type == "BIN" || type == "bin") { - return ::ov::element::Type(::ov::element::Type_t::u1); - } else if (type == "u2" || type == "U2") { - return ::ov::element::Type(::ov::element::Type_t::u2); - } else if (type == "u3" || type == "U3") { - return ::ov::element::Type(::ov::element::Type_t::u3); - } else if (type == "u4" || type == "U4") { - return ::ov::element::Type(::ov::element::Type_t::u4); - } else if (type == "u6" || type == "U6") { - return ::ov::element::Type(::ov::element::Type_t::u6); - } else if (type == "u8" || type == "U8") { - return ::ov::element::Type(::ov::element::Type_t::u8); - } else if (type == "u16" || type == "U16") { - return ::ov::element::Type(::ov::element::Type_t::u16); - } else if (type == "u32" || type == "U32") { - return ::ov::element::Type(::ov::element::Type_t::u32); - } else if (type == "u64" || type == "U64") { - return ::ov::element::Type(::ov::element::Type_t::u64); - } else if (type == "boolean" || type == "BOOL") { - return ::ov::element::Type(::ov::element::Type_t::boolean); - } else if (type == "string" || type == "STRING") { - return ::ov::element::Type(::ov::element::Type_t::string); - } else if (type == "undefined" || type == "UNSPECIFIED") { - return ::ov::element::Type(::ov::element::Type_t::undefined); - } else if (type == "dynamic") { - return ::ov::element::Type(::ov::element::Type_t::dynamic); - } else if (type == "nf4" || type == "NF4") { - return ::ov::element::Type(::ov::element::Type_t::nf4); - } else if (type == "f8e4m3" || type == "F8E4M3") { - return ::ov::element::Type(::ov::element::Type_t::f8e4m3); - } else if (type == "f8e5m2" || type == "F8E5M2") { - return ::ov::element::Type(::ov::element::Type_t::f8e5m2); - } else if (type == "f4e2m1" || type == "F4E2M1") { - return ::ov::element::Type(::ov::element::Type_t::f4e2m1); - } else if (type == "f8e8m0" || type == "F8E8M0") { - return ::ov::element::Type(::ov::element::Type_t::f8e8m0); - } else { - OPENVINO_THROW("Incorrect type: ", type); +constexpr bool validate_types_info(decltype(types_info)& info, size_t i = 0) { + return i >= info.size() ? true : info[i].is_valid() ? validate_types_info(info, i + 1) : false; +} + +static_assert(validate_types_info(types_info), "Some entries of type_info are invalid."); + +constexpr bool is_valid_type_idx(size_t idx) { + return idx < types_info.size(); +} + +size_t type_idx_for(const std::string& type_name) { + size_t type_idx = 0; + for (; is_valid_type_idx(type_idx); ++type_idx) { + if (types_info[type_idx].has_name(type_name)) { + break; + } } + return type_idx; } + +const TypeInfo& get_type_info(Type_t type) { + const auto type_idx = idx(type); + OPENVINO_ASSERT(is_valid_type_idx(type_idx), "Type_t not supported: ", type_idx); + return types_info[type_idx]; +} + +Type type_from_string(const std::string& type) { + const auto type_idx = type_idx_for(type); + OPENVINO_ASSERT(is_valid_type_idx(type_idx), "Unsupported element type: ", type); + return {static_cast(type_idx)}; +} + +// generate known types automatically +static constexpr auto known_types = [] { + std::array types; + for (size_t idx = 1, i = 0; i < types.size(); ++idx, ++i) { + types[i] = Type{static_cast(idx)}; + } + return types; +}(); } // namespace -std::vector ov::element::Type::get_known_types() { - std::vector rc = { - &ov::element::dynamic, &ov::element::boolean, &ov::element::bf16, &ov::element::f16, &ov::element::f32, - &ov::element::f64, &ov::element::i4, &ov::element::i8, &ov::element::i16, &ov::element::i32, - &ov::element::i64, &ov::element::u1, &ov::element::u2, &ov::element::u3, &ov::element::u4, - &ov::element::u6, &ov::element::u8, &ov::element::u16, &ov::element::u32, &ov::element::u64, - &ov::element::nf4, &ov::element::f8e4m3, &ov::element::f8e5m2, &ov::element::string, &ov::element::f4e2m1, - &ov::element::f8e8m0}; - return rc; +std::vector Type::get_known_types() { + std::vector result(known_types.size()); + for (size_t i = 0; i < known_types.size(); ++i) { + result[i] = &known_types[i]; + } + return result; } -ov::element::Type::Type(const std::string& type) : Type(type_from_string(type)) {} +Type::Type(const std::string& type) : Type(type_from_string(type)) {} -std::string ov::element::Type::c_type_string() const { +std::string Type::c_type_string() const { return get_type_info(m_type).m_cname; } -size_t ov::element::Type::size() const { +size_t Type::size() const { return (bitwidth() + 7) >> 3; } -size_t ov::element::Type::hash() const { +size_t Type::hash() const { return static_cast(m_type); } -std::string ov::element::Type::get_type_name() const { +std::string Type::get_type_name() const { return to_string(); } -std::string ov::element::Type::to_string() const { +std::string Type::to_string() const { return get_type_info(m_type).m_type_name; } -namespace ov { -namespace element { -template <> -Type from() { - return Type_t::boolean; -} -template <> -Type from() { - return Type_t::boolean; -} -template <> -Type from() { - return Type_t::f16; -} -template <> -Type from() { - return Type_t::f32; -} -template <> -Type from() { - return Type_t::f64; -} -template <> -Type from() { - return Type_t::i8; -} -template <> -Type from() { - return Type_t::i16; -} -template <> -Type from() { - return Type_t::i32; -} -template <> -Type from() { - return Type_t::i64; -} -template <> -Type from() { - return Type_t::u8; -} -template <> -Type from() { - return Type_t::u16; -} -template <> -Type from() { - return Type_t::u32; -} -template <> -Type from() { - return Type_t::u64; -} -template <> -Type from() { - return Type_t::bf16; -} -template <> -Type from() { - return Type_t::f8e4m3; -} -template <> -Type from() { - return Type_t::f8e5m2; -} -template <> -Type from() { - return Type_t::string; -} -template <> -Type from() { - return Type_t::f4e2m1; -} -template <> -Type from() { - return Type_t::f8e8m0; -} - Type fundamental_type_for(const Type& type) { switch (type) { case Type_t::boolean: @@ -322,44 +259,24 @@ Type fundamental_type_for(const Type& type) { } } -} // namespace element -} // namespace ov - -std::ostream& ov::element::operator<<(std::ostream& out, const ov::element::Type& obj) { +std::ostream& operator<<(std::ostream& out, const Type& obj) { return out << obj.to_string(); } -std::istream& ov::element::operator>>(std::istream& in, ov::element::Type& obj) { - const std::unordered_map legacy = { - {"BOOL", ov::element::boolean}, {"BF16", ov::element::bf16}, {"I4", ov::element::i4}, - {"I8", ov::element::i8}, {"I16", ov::element::i16}, {"I32", ov::element::i32}, - {"I64", ov::element::i64}, {"U4", ov::element::u4}, {"U8", ov::element::u8}, - {"U16", ov::element::u16}, {"U32", ov::element::u32}, {"U64", ov::element::u64}, - {"FP32", ov::element::f32}, {"FP64", ov::element::f64}, {"FP16", ov::element::f16}, - {"BIN", ov::element::u1}, {"NF4", ov::element::nf4}, {"F8E4M3", ov::element::f8e4m3}, - {"F8E5M2", ov::element::f8e5m2}, {"STRING", ov::element::string}, {"F4E2M1", ov::element::f4e2m1}, - {"F8E8M0", ov::element::f8e8m0}}; +std::istream& operator>>(std::istream& in, Type& obj) { std::string str; in >> str; - auto it_legacy = legacy.find(str); - if (it_legacy != legacy.end()) { - obj = it_legacy->second; - return in; - } - for (auto&& type : Type::get_known_types()) { - if (type->to_string() == str) { - obj = *type; - break; - } + if (const auto type_idx = type_idx_for(str); is_valid_type_idx(type_idx)) { + obj = {static_cast(type_idx)}; } return in; } -bool ov::element::Type::compatible(const ov::element::Type& t) const { +bool Type::compatible(const Type& t) const { return (is_dynamic() || t.is_dynamic() || *this == t); } -bool ov::element::Type::merge(ov::element::Type& dst, const ov::element::Type& t1, const ov::element::Type& t2) { +bool Type::merge(Type& dst, const Type& t1, const Type& t2) { if (t1.is_dynamic()) { dst = t2; return true; @@ -374,69 +291,30 @@ bool ov::element::Type::merge(ov::element::Type& dst, const ov::element::Type& t } } -bool ov::element::Type::is_static() const { +bool Type::is_static() const { return get_type_info(m_type).m_bitwidth != 0; } -bool ov::element::Type::is_real() const { +bool Type::is_real() const { return get_type_info(m_type).m_is_real; } -bool ov::element::Type::is_integral_number() const { - return is_integral() && (m_type != ov::element::boolean); +bool Type::is_integral_number() const { + return is_integral() && (m_type != boolean); } -bool ov::element::Type::is_signed() const { +bool Type::is_signed() const { return get_type_info(m_type).m_is_signed; } -bool ov::element::Type::is_quantized() const { +bool Type::is_quantized() const { return get_type_info(m_type).m_is_quantized; } -size_t ov::element::Type::bitwidth() const { +size_t Type::bitwidth() const { return get_type_info(m_type).m_bitwidth; } - -inline size_t compiler_byte_size(ov::element::Type_t et) { - switch (et) { -#define ET_CASE(et) \ - case ov::element::Type_t::et: \ - return sizeof(ov::element_type_traits::value_type); - ET_CASE(boolean); - ET_CASE(bf16); - ET_CASE(f16); - ET_CASE(f32); - ET_CASE(f64); - ET_CASE(i4); - ET_CASE(i8); - ET_CASE(i16); - ET_CASE(i32); - ET_CASE(i64); - ET_CASE(u1); - ET_CASE(u2); - ET_CASE(u3); - ET_CASE(u4); - ET_CASE(u6); - ET_CASE(u8); - ET_CASE(u16); - ET_CASE(u32); - ET_CASE(u64); - ET_CASE(nf4); - ET_CASE(f8e4m3); - ET_CASE(f8e5m2); - ET_CASE(string); - ET_CASE(f4e2m1); - ET_CASE(f8e8m0); -#undef ET_CASE - case ov::element::Type_t::undefined: - return 0; - case ov::element::Type_t::dynamic: - return 0; - } - - OPENVINO_THROW("compiler_byte_size: Unsupported value of ov::element::Type_t: ", static_cast(et)); -} +} // namespace ov::element namespace ov { template <> diff --git a/src/inference/src/dev/core_impl.cpp b/src/inference/src/dev/core_impl.cpp index 0cad1840e5d1a8..7e2a0a8b4be441 100644 --- a/src/inference/src/dev/core_impl.cpp +++ b/src/inference/src/dev/core_impl.cpp @@ -38,18 +38,6 @@ ov::ICore::~ICore() = default; -namespace ov { -namespace util { -template -constexpr std::array< - typename std::conditional::value, typename std::common_type::type, T>::type, - sizeof...(Args)> -make_array(Args&&... args) { - return {std::forward(args)...}; -} -} // namespace util -} // namespace ov - namespace { #ifdef PROXY_PLUGIN_ENABLED diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/base/utils/ranges.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/base/utils/ranges.hpp index 3805fde5ce9bfb..a383fc2b7df220 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/base/utils/ranges.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/base/utils/ranges.hpp @@ -34,16 +34,16 @@ struct Range { max_known_port = std::max(static_cast(max_known_port), 1); for (size_t port = 0; port < max_known_port; port++) { std::map type_map; - for (auto& type : ov::element::Type::get_known_types()) { - ov::test::utils::InputGenerateData new_range = rangeByType.get_range(*type); - if (type->is_real() && port < real_port_ranges.size()) { + for (const auto& type : get_known_types()) { + ov::test::utils::InputGenerateData new_range = rangeByType.get_range(type); + if (type.is_real() && port < real_port_ranges.size()) { new_range.correct_range(real_port_ranges.at(port)); new_range.input_attribute = real_port_ranges.at(port).input_attribute; - } else if (type->is_integral() && port < int_port_ranges.size()) { + } else if (type.is_integral() && port < int_port_ranges.size()) { new_range.correct_range(int_port_ranges.at(port)); new_range.input_attribute = int_port_ranges.at(port).input_attribute; } - type_map[*type] = new_range; + type_map[type] = new_range; } data.push_back(type_map); } diff --git a/src/tests/test_utils/common_test_utils/include/common_test_utils/type_ranges.hpp b/src/tests/test_utils/common_test_utils/include/common_test_utils/type_ranges.hpp index c84b58066387f3..7dc5841869a493 100644 --- a/src/tests/test_utils/common_test_utils/include/common_test_utils/type_ranges.hpp +++ b/src/tests/test_utils/common_test_utils/include/common_test_utils/type_ranges.hpp @@ -15,6 +15,21 @@ namespace ov { namespace test { namespace utils { +static const std::vector& get_known_types() { + static const auto known_types = [] { + using namespace ov::element; + constexpr size_t enum_count = static_cast>(Type_t::f8e8m0) - 1; + + std::vector types(enum_count); + for (size_t idx = 1, i = 0; i < types.size(); ++idx, ++i) { + types[i] = Type{static_cast(idx)}; + } + return types; + }(); + + return known_types; +} + static ov::test::utils::InputGenerateData get_range_by_type( ov::element::Type elemType, uint32_t max_range_limit = testing::internal::Random::kMaxRange) { @@ -110,8 +125,8 @@ struct RangeByType { std::map data; RangeByType() { - for (auto& type : ov::element::Type::get_known_types()) { - data[*type] = get_range_by_type(*type); + for (const auto& type : get_known_types()) { + data[type] = get_range_by_type(type); } }