diff --git a/src/plugins/intel_cpu/src/cpu_memory.cpp b/src/plugins/intel_cpu/src/cpu_memory.cpp index 5e749121ecda51..5dfc0659353845 100644 --- a/src/plugins/intel_cpu/src/cpu_memory.cpp +++ b/src/plugins/intel_cpu/src/cpu_memory.cpp @@ -9,6 +9,7 @@ #include "memory_desc/cpu_memory_desc_utils.h" #include "nodes/common/cpu_memcpy.h" #include "nodes/reorder.h" +#include "utils/bfloat16.hpp" #include "utils/debug_capabilities.h" #if defined(__linux__) # include <sys/syscall.h> /* Definition of SYS_* constants */ @@ -31,19 +32,26 @@ BlockedMemoryDescPtr IMemory::getDescWithType() const { } namespace { -inline void setSubnormalsToZero(float* data, size_t size) { +inline void setSubnormalsToZeroAndbf16Saturation(float* data, size_t size, bool ftz, bool bf16saturation) { uint32_t* u32data = reinterpret_cast<uint32_t*>(data); + float* floatdata = reinterpret_cast<float*>(data); for (size_t i = 0; i < size; ++i) { - if ((u32data[i] & (0xFF << 23)) == 0) { + if (ftz && ((u32data[i] & (0xFF << 23)) == 0)) { u32data[i] = 0; + } else if (bf16saturation && !std::isnan(floatdata[i]) && !std::isinf(floatdata[i])) { + floatdata[i] = (floatdata[i] < static_cast<float>(std::numeric_limits<bfloat16>::lowest())) + ? static_cast<float>(std::numeric_limits<bfloat16>::lowest()) + : (floatdata[i] > static_cast<float>(std::numeric_limits<bfloat16>::max())) + ? static_cast<float>(std::numeric_limits<bfloat16>::max()) + : floatdata[i]; } } } -void transferData(const IMemory& src, const IMemory& dst, bool ftz) { +void transferData(const IMemory& src, const IMemory& dst, bool ftz, bool bf16saturation) { node::Reorder::reorderData(src, dst); - if (!ftz) { + if (!ftz && !bf16saturation) { return; } if (src.getDesc().getPrecision() != ov::element::f32 || dst.getDesc().getPrecision() != ov::element::f32) { @@ -63,7 +71,7 @@ void transferData(const IMemory& src, const IMemory& dst, bool ftz) { // actual FTZ auto* memData = static_cast<float*>(dst.getData()); memData += offset; - setSubnormalsToZero(memData, dst.getSize() / sizeof(float)); + setSubnormalsToZeroAndbf16Saturation(memData, dst.getSize() / sizeof(float), ftz, bf16saturation); } } // namespace @@ -126,11 +134,11 @@ void Memory::create(MemoryDescPtr desc, const void* data, bool pads_zeroing) { } } -void Memory::load(const IMemory& src, bool ftz) const { +void Memory::load(const IMemory& src, bool ftz, bool bf16saturation) const { if (src.getDesc().getPrecision() == element::string) { OPENVINO_THROW("[CPU] Memory object cannot load string data."); } - transferData(src, *this, ftz); + transferData(src, *this, ftz, bf16saturation); } void Memory::nullify() { @@ -272,12 +280,12 @@ StringMemory::StringMemory(dnnl::engine engine, MemoryDescPtr desc, const void* } } -void StringMemory::load(const IMemory& src, bool ftz) const { +void StringMemory::load(const IMemory& src, bool ftz, bool bf16saturation) const { if (src.getDesc().getPrecision() != element::string) { OPENVINO_THROW("[CPU] String memory cannot load a non-string object."); } - transferData(src, *this, false); + transferData(src, *this, false, false); } void* StringMemory::getData() const { @@ -471,11 +479,11 @@ void StaticMemory::redefineDesc(MemoryDescPtr desc) { OPENVINO_THROW("Unexpected: Memory descriptor may not be modified in StaticMemory object"); } -void StaticMemory::load(const IMemory& src, bool ftz) const { +void StaticMemory::load(const IMemory& src, bool ftz, bool bf16saturation) const { if (src.getDesc().getPrecision() == element::string) { OPENVINO_THROW("[CPU] StaticMemory cannot load string data."); } - transferData(src, *this, ftz); + transferData(src, *this, ftz, bf16saturation); } MemoryBlockPtr
StaticMemory::getMemoryBlock() const { diff --git a/src/plugins/intel_cpu/src/cpu_memory.h b/src/plugins/intel_cpu/src/cpu_memory.h index 1b1b4debe4fcc4..f22a87cd157ef0 100644 --- a/src/plugins/intel_cpu/src/cpu_memory.h +++ b/src/plugins/intel_cpu/src/cpu_memory.h @@ -188,7 +188,7 @@ class IMemory { // Caution!!! This action invalidates the previous data layout. The old data may become unreachable. virtual void redefineDesc(MemoryDescPtr desc) = 0; - virtual void load(const IMemory& src, bool ftz = true) const = 0; + virtual void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const = 0; virtual MemoryBlockPtr getMemoryBlock() const = 0; @@ -260,7 +260,7 @@ class StaticMemory final : public IMemory { // Always throws since a static memory descriptor should not be modified void redefineDesc(MemoryDescPtr desc) override; - void load(const IMemory& src, bool ftz = true) const override; + void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const override; MemoryBlockPtr getMemoryBlock() const override; @@ -315,7 +315,7 @@ class Memory : public IMemory { void redefineDesc(MemoryDescPtr desc) override; - void load(const IMemory& src, bool ftz = true) const override; + void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const override; void nullify() override; dnnl::engine getEngine() const { @@ -421,7 +421,7 @@ class StringMemory : public IMemory { void redefineDesc(MemoryDescPtr desc) override; - void load(const IMemory& src, bool ftz = false) const override; + void load(const IMemory& src, bool ftz = false, bool bf16saturation = false) const override; MemoryBlockPtr getMemoryBlock() const override; diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp index 78ad3b04aa06b1..c47a5772a2f46d 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp @@ -18,7 +18,11 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter { conversion_mode mode = conversion_mode::default_mode) : jit_emitter(host, host_isa, exec_prc), mode_(mode) { - prepare_table(); + if ((!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) && + !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2)) || + mode_ == conversion_mode::saturation_mode) { + prepare_table(); + } } size_t get_inputs_num() const override { diff --git a/src/plugins/intel_cpu/src/nodes/eltwise.cpp b/src/plugins/intel_cpu/src/nodes/eltwise.cpp index c13f22b0d9b76a..9e6b8ab890adda 100644 --- a/src/plugins/intel_cpu/src/nodes/eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/eltwise.cpp @@ -2895,15 +2895,9 @@ void Eltwise::prepareParams() { // FP32 constant inputs may contain values out of BF16 representable range. In case output precision is BF16 we // choose "saturation" mode for fp32->bf16 conversion procedure to prevent getting -Inf/+Inf values in the - // outputs. Since "saturation" conversion is more time consuming, better solution would be to clamp constants on - // compilation stage (ticket: 159589). + // outputs. Since "saturation" conversion during kernel runtime is more time consuming, the current solution is to + // clamp the constants at the compilation stage.
key.doOutputSaturation = false; - for (size_t i = 0; i < getParentEdges().size(); i++) { - if (getParentEdgeAt(i)->getParent()->isConstant()) { - key.doOutputSaturation = true; - break; - } - } auto cache = context->getParamsCache(); auto result = cache->getOrCreate(key, buildExecutor); diff --git a/src/plugins/intel_cpu/src/nodes/input.cpp b/src/plugins/intel_cpu/src/nodes/input.cpp index f812da7ca01159..eaaf84677ca891 100644 --- a/src/plugins/intel_cpu/src/nodes/input.cpp +++ b/src/plugins/intel_cpu/src/nodes/input.cpp @@ -23,18 +23,18 @@ namespace node { #if defined(OPENVINO_ARCH_X86_64) namespace { -struct jit_has_subnormals_base : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_has_subnormals_base) +struct jit_has_special_value_base : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_has_special_value_base) typedef struct { const float* src; const size_t count; - bool hasSubnormals; + bool hasTargetValues; } args_t; typedef void (*fn_t)(const args_t*); - jit_has_subnormals_base() : jit_generator(jit_name()) { + jit_has_special_value_base() : jit_generator(jit_name()) { jit_ker_ = nullptr; } @@ -110,8 +110,35 @@ struct jit_has_subnormals_base : public jit_generator { uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 } + void check_bf16_saturations(const Xbyak::Reg64& src, + const Xbyak::Ymm& bf16_max_mask, + const Xbyak::Ymm& bf16_min_mask) { + auto a = ymm1; + auto b = ymm2; + auto c = ymm3; + vmovdqu(a, yword[src]); // load 8 floats + vcmpps(b, a, bf16_max_mask, 0x1e); // b = (a > bf16_max) ? 1 : 0 + vcmpps(c, a, bf16_min_mask, 0x11); // c = (a < bf16_min) ? 1 : 0 + vorps(b, b, c); // b = b | c + vptest(b, b); // if (b != 0) ZF = 0 else ZF = 1 + } + + void check_bf16_saturations(const Xbyak::Reg64& src, + const Xbyak::Xmm& bf16_max_mask, + const Xbyak::Xmm& bf16_min_mask) { + auto a = xmm1; + auto b = xmm2; + auto c = xmm3; + + uni_vmovdqu(a, xword[src]); // load 4 floats + uni_vcmpps(b, a, bf16_max_mask, 0x1e); // b = (a > bf16_max) ? 1 : 0 + uni_vcmpps(c, a, bf16_min_mask, 0x11); // c = (a < bf16_min) ?
1 : 0 + uni_vorps(b, b, c); // b = b | c + uni_vtestps(b, b); // if (b != 0) ZF = 0 else ZF = 1 + } + protected: - Label exit, has_subnormals, no_subnormals; + Label exit, has_target_values, no_target_values; const Reg64& reg_src = rax; const Reg64& reg_dst = rbx; @@ -121,16 +148,35 @@ struct jit_has_subnormals_base : public jit_generator { static const uint32_t exponent_mask_data[8]; static const uint32_t mantissa_mask_data[8]; + static const float bf16_max_mask_data[8]; + static const float bf16_min_mask_data[8]; }; -const uint32_t jit_has_subnormals_base::exponent_mask_data[8] = +const uint32_t jit_has_special_value_base::exponent_mask_data[8] = {0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000}; -const uint32_t jit_has_subnormals_base::mantissa_mask_data[8] = +const uint32_t jit_has_special_value_base::mantissa_mask_data[8] = {0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff}; +const float jit_has_special_value_base::bf16_max_mask_data[8] = {std::numeric_limits<bfloat16>::max(), + std::numeric_limits<bfloat16>::max(), + std::numeric_limits<bfloat16>::max(), + std::numeric_limits<bfloat16>::max(), + std::numeric_limits<bfloat16>::max(), + std::numeric_limits<bfloat16>::max(), + std::numeric_limits<bfloat16>::max(), + std::numeric_limits<bfloat16>::max()}; + +const float jit_has_special_value_base::bf16_min_mask_data[8] = {std::numeric_limits<bfloat16>::lowest(), + std::numeric_limits<bfloat16>::lowest(), + std::numeric_limits<bfloat16>::lowest(), + std::numeric_limits<bfloat16>::lowest(), + std::numeric_limits<bfloat16>::lowest(), + std::numeric_limits<bfloat16>::lowest(), + std::numeric_limits<bfloat16>::lowest(), + std::numeric_limits<bfloat16>::lowest()}; template <cpu_isa_t isa> -struct jit_has_subnormals : public jit_has_subnormals_base { +struct jit_has_subnormals : public jit_has_special_value_base { using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type; const Vmm rmm4 = Vmm(4); @@ -150,7 +196,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base { // Get arguments addresses mov(reg_src, ptr[param1 + offsetof(args_t, src)]); - lea(reg_dst, ptr[param1 + offsetof(args_t, hasSubnormals)]); + lea(reg_dst, ptr[param1 + offsetof(args_t, hasTargetValues)]); mov(reg_sz, ptr[param1 + offsetof(args_t, count)]); // Initialize necessary consts @@ -167,7 +213,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base { foreach (reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) { check_subnormals(reg_src, exponent_mask, mantissa_mask, zero); - jnc(has_subnormals); + jnc(has_target_values); add(reg_src, sizeof(float) * vlen); }) ; @@ -186,16 +232,16 @@ struct jit_has_subnormals : public jit_has_subnormals_base { copy_floats(r8, reg_src, reg_sz); check_subnormals(r8, exponent_mask, mantissa_mask, zero); - jc(no_subnormals); + jc(no_target_values); add(rsp, vlen * sizeof(float)); - L(has_subnormals); + L(has_target_values); mov(rax, 1); mov(byte[reg_dst], al); jmp(exit); - L(no_subnormals); + L(no_target_values); add(rsp, vlen * sizeof(float)); L(exit); @@ -203,8 +249,81 @@ struct jit_has_subnormals : public jit_has_subnormals_base { postamble(); } }; +template <cpu_isa_t isa> +struct jit_has_bf16_overflows : public jit_has_special_value_base { + using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type; + + const Vmm rmm4 = Vmm(4); + const Vmm rmm5 = Vmm(5); + const Vmm rmm6 = Vmm(6); + const int length = isa == sse41 ?
4 : 8; + + void generate() override final { // NOLINT + size_t const vlen = length; + const int sh_bits = std::ilogb(vlen); + + auto zero = rmm4; + auto bf16_max_mask = rmm5; + auto bf16_min_mask = rmm6; + + preamble(); + + // Get arguments addresses + mov(reg_src, ptr[param1 + offsetof(args_t, src)]); + lea(reg_dst, ptr[param1 + offsetof(args_t, hasTargetValues)]); + mov(reg_sz, ptr[param1 + offsetof(args_t, count)]); + + // Initialize necessary consts + uni_vpxor(zero, zero, zero); + mov(reg_mask_addr, (size_t)bf16_max_mask_data); + uni_vmovdqu(bf16_max_mask, ptr[reg_mask_addr]); + mov(reg_mask_addr, (size_t)bf16_min_mask_data); + uni_vmovdqu(bf16_min_mask, ptr[reg_mask_addr]); + + // Main loop + xor_(reg_idx, reg_idx); + mov(r8, reg_sz); + shr(r8, sh_bits); + + foreach (reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) { + check_bf16_saturations(reg_src, bf16_max_mask, bf16_min_mask); + jnz(has_target_values, T_NEAR); + add(reg_src, sizeof(float) * vlen); + }) + ; + + // Tail + shl(reg_idx, sh_bits); + sub(reg_sz, reg_idx); + test(reg_sz, reg_sz); + jz(exit); + + // use space on stack for 4 or 8 floats + sub(rsp, vlen * sizeof(float)); + mov(r8, rsp); -jit_has_subnormals_base::fn_t jit_has_subnormals_function() { + uni_vmovdqu(ptr[r8], zero); + + copy_floats(r8, reg_src, reg_sz); + check_bf16_saturations(r8, bf16_max_mask, bf16_min_mask); + jz(no_target_values, T_NEAR); + add(rsp, vlen * sizeof(float)); + + L(has_target_values); + + mov(rax, 1); + mov(byte[reg_dst], al); + jmp(exit); + + L(no_target_values); + add(rsp, vlen * sizeof(float)); + + L(exit); + + postamble(); + } +}; +jit_has_special_value_base::fn_t jit_has_subnormals_function() { if (mayiuse(cpu_isa_t::avx2)) { static jit_has_subnormals<cpu_isa_t::avx2> generator; static auto fn = generator.get(); return fn; @@ -216,6 +335,18 @@ jit_has_subnormals_base::fn_t jit_has_subnormals_function() { } return nullptr; } +jit_has_special_value_base::fn_t jit_has_bf16_overflows_function() { + if (mayiuse(cpu_isa_t::avx2)) { + static jit_has_bf16_overflows<cpu_isa_t::avx2> generator; + static auto fn = generator.get(); + return fn; + } else if (mayiuse(cpu_isa_t::sse41)) { + static jit_has_bf16_overflows<cpu_isa_t::sse41> generator; + static auto fn = generator.get(); + return fn; + } + return nullptr; +} } // namespace #endif @@ -262,6 +393,90 @@ void Input::cloneBlobIfRequired() { needFlushDenormalsToZero = false; } + // The presence of subnormals is better determined at IR read time. + auto checkSubnormalsAndBF16Overflows = [&](bool& has_subnormals, bool& has_bf16_overflows) { + if (prec == ov::element::f32) { + uint32_t const* u32data = m_constOp->get_data_ptr<uint32_t>(); + float const* f32data = m_constOp->get_data_ptr<float>(); + + if (!size) + return; + + // Only LLM scalar constant nodes with bf16 inferencePrecision need to be checked for saturation + const bool do_bf16_saturation_check = + (context->getConfig().inferencePrecision == ov::element::bf16 && size == 1) ?
true : false; + +#if defined(OPENVINO_ARCH_X86_64) + auto fn = jit_has_subnormals_function(); + auto fn_bf16_check = jit_has_bf16_overflows_function(); + if (fn && fn_bf16_check) { + static const size_t batch_size = 2048; + const size_t iterations_num = size / batch_size + 1; + + volatile bool has_subnormals_local = false; + volatile bool has_bf16_overflows_local = false; + if (needFlushDenormalsToZero) { + parallel_for(iterations_num, [&](int n) { + auto ptr = u32data + n * batch_size; + const jit_has_special_value_base::args_t args1 = { + reinterpret_cast<const float*>(ptr), + std::min(batch_size, (size_t)(u32data + size - ptr)), + false}; + + fn(&args1); + + if (args1.hasTargetValues) + has_subnormals_local = true; + }); + } + + if (do_bf16_saturation_check) { + parallel_for(iterations_num, [&](int n) { + auto ptr2 = f32data + n * batch_size; + const jit_has_special_value_base::args_t args2 = { + reinterpret_cast<const float*>(ptr2), + std::min(batch_size, (size_t)(f32data + size - ptr2)), + false}; + + fn_bf16_check(&args2); + + if (args2.hasTargetValues) + has_bf16_overflows_local = true; + }); + } + + has_subnormals = has_subnormals_local; + has_bf16_overflows = has_bf16_overflows_local; + + return; + } +#endif + + uint32_t mantissaMask = 0x007fffff; + uint32_t exponentMask = 0x7f800000; + const float bf16_max = std::numeric_limits<bfloat16>::max(); + for (size_t i = 0; i < size; ++i) { + if (needFlushDenormalsToZero && (u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) { + has_subnormals = true; + } + + if (do_bf16_saturation_check && (f32data[i] < -bf16_max || f32data[i] > bf16_max)) { + has_bf16_overflows = true; + } + + if ((!needFlushDenormalsToZero || has_subnormals) && + (!do_bf16_saturation_check || has_bf16_overflows)) { + return; + } + } + } + }; + + bool has_subnormals = false; + bool has_bf16_overflows = false; + + checkSubnormalsAndBF16Overflows(has_subnormals, has_bf16_overflows); + auto cloneBlob = [&, this]() { MemoryPtr memory; @@ -294,7 +509,7 @@ void Input::cloneBlobIfRequired() { } else { ptr = std::make_shared<Memory>(getEngine(), memDesc); } - ptr->load(*memory.get(), needFlushDenormalsToZero); + ptr->load(*memory.get(), has_subnormals, has_bf16_overflows); return ptr; }; @@ -311,48 +526,6 @@ void Input::cloneBlobIfRequired() { #endif }; - // The presence of subnormals is better to determined at IR read time.
- auto hasSubnormals = [&]() { - if (prec == ov::element::f32) { - uint32_t const* u32data = m_constOp->get_data_ptr<uint32_t>(); - - if (!size) - return false; - -#if defined(OPENVINO_ARCH_X86_64) - if (auto fn = jit_has_subnormals_function()) { - static const size_t batch_size = 2048; - const size_t iterations_num = size / batch_size + 1; - - volatile bool has_subnormals = false; - - parallel_for(iterations_num, [&](int n) { - auto ptr = u32data + n * batch_size; - const jit_has_subnormals_base::args_t args = {reinterpret_cast<const float*>(ptr), - std::min(batch_size, (size_t)(u32data + size - ptr)), - false}; - - fn(&args); - - if (args.hasSubnormals) - has_subnormals = true; - }); - - return has_subnormals; - } -#endif - - uint32_t mantissaMask = 0x007fffff; - uint32_t exponentMask = 0x7f800000; - for (size_t i = 0; i < size; ++i) { - if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) { - return true; - } - } - } - return false; - }; - auto blobKey = [&]() { char ptr[32]; snprintf(ptr, sizeof ptr, "%p", m_constOp->get_data_ptr()); @@ -364,7 +537,7 @@ void Input::cloneBlobIfRequired() { prec != element::string && // IRs already have all subnormals flushed to zero, but in // read_model scenario with directly loaded original model still can have subnormals - isBlobAligned(m_constOp) && (!needFlushDenormalsToZero || !hasSubnormals()) && + isBlobAligned(m_constOp) && !has_subnormals && !has_bf16_overflows && // Blob should be cloned in cache only if original weights are stored on other numa node. // This is possible only in multistream case on multisocket machine. // TODO: don't clone blob for multisocket + multistream case if current stream is run on the numa node where diff --git a/src/plugins/intel_cpu/src/nodes/memory.cpp b/src/plugins/intel_cpu/src/nodes/memory.cpp index 8b29ac8cbfbadb..ba57f33e878ebd 100644 --- a/src/plugins/intel_cpu/src/nodes/memory.cpp +++ b/src/plugins/intel_cpu/src/nodes/memory.cpp @@ -84,7 +84,7 @@ class MemoryStub : public IMemory { m_pMemDesc = desc; } - void load(const IMemory& src, bool ftz = true) const override { + void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const override { OPENVINO_THROW("Unexpected call MemoryStub::load()"); } diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/bf16_convert_saturation.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/bf16_convert_saturation.cpp index 5f43f863c78d9f..17b1f4d1350359 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/bf16_convert_saturation.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/bf16_convert_saturation.cpp @@ -58,7 +58,7 @@ class BF16ConvertSaturation : public testing::WithParamInterface, in_data.resolution = 1; auto thenTensor = ov::test::utils::create_and_fill_tensor(precision, ov::Shape{1}, in_data); - in_data.start_from = 3.40282e+38; + in_data.start_from = 1; in_data.range = 10; in_data.resolution = 2; auto elseTensor = ov::test::utils::create_and_fill_tensor(precision, ov::Shape{2, 1, 32, 32}, in_data); diff --git a/src/plugins/intel_cpu/tests/unit/cpu_tensor_test.cpp b/src/plugins/intel_cpu/tests/unit/cpu_tensor_test.cpp index b8f9634ddb6270..d0ce6e0d47c749 100644 --- a/src/plugins/intel_cpu/tests/unit/cpu_tensor_test.cpp +++ b/src/plugins/intel_cpu/tests/unit/cpu_tensor_test.cpp @@ -71,7 +71,7 @@ class MockIMemory : public IMemory { MOCK_METHOD(const VectorDims&, getStaticDims, (), (const, override)); MOCK_METHOD(void, redefineDesc,
(MemoryDescPtr), (override)); - MOCK_METHOD(void, load, (const IMemory&, bool), (const, override)); + MOCK_METHOD(void, load, (const IMemory&, bool, bool), (const, override)); MOCK_METHOD(MemoryBlockPtr, getMemoryBlock, (), (const, override)); MOCK_METHOD(dnnl::memory, getPrimitive, (), (const, override));
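For reference, a minimal standalone sketch of the clamp that setSubnormalsToZeroAndbf16Saturation() applies to out-of-range f32 constant values (illustrative only, not part of the patch: the helper name clamp_to_bf16_range and the hard-coded bf16 limit of ~3.3895314e38 are assumptions here; the patch itself reads the limits from std::numeric_limits<bfloat16> and leaves NaN/Inf untouched):

// Sketch: clamp an f32 value into the finite bf16 range, as the patch does for
// f32 constants before they are converted to bf16, to avoid producing +/-Inf.
#include <cmath>
#include <cstdio>

static float clamp_to_bf16_range(float v) {
    constexpr float bf16_max = 3.3895314e38f;  // largest finite bf16 value (assumed constant)
    if (std::isnan(v) || std::isinf(v)) {
        return v;  // NaN and Inf are passed through unchanged
    }
    return v < -bf16_max ? -bf16_max : (v > bf16_max ? bf16_max : v);
}

int main() {
    const float flt_max = 3.4028235e38f;  // representable in f32 but above the bf16 range
    std::printf("%g -> %g\n", flt_max, clamp_to_bf16_range(flt_max));
    return 0;
}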