Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] bf16 constant saturation during model compilation stage #28542

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions src/plugins/intel_cpu/src/cpu_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "memory_desc/cpu_memory_desc_utils.h"
#include "nodes/common/cpu_memcpy.h"
#include "nodes/reorder.h"
#include "utils/bfloat16.hpp"
#include "utils/debug_capabilities.h"
#if defined(__linux__)
# include <sys/syscall.h> /* Definition of SYS_* constants */
Expand All @@ -31,19 +32,26 @@ BlockedMemoryDescPtr IMemory::getDescWithType<BlockedMemoryDesc, 0, 0>() const {
}

namespace {
inline void setSubnormalsToZero(float* data, size_t size) {
inline void setSubnormalsToZeroAndbf16Saturation(float* data, size_t size, bool ftz, bool bf16saturation) {
uint32_t* u32data = reinterpret_cast<uint32_t*>(data);
float* floatdata = reinterpret_cast<float*>(data);
for (size_t i = 0; i < size; ++i) {
if ((u32data[i] & (0xFF << 23)) == 0) {
if (ftz && ((u32data[i] & (0xFF << 23)) == 0)) {
u32data[i] = 0;
} else if (bf16saturation && !std::isnan(floatdata[i]) && !std::isinf(floatdata[i])) {
floatdata[i] = (floatdata[i] < static_cast<float>(std::numeric_limits<ov::bfloat16>::lowest()))
? static_cast<float>(std::numeric_limits<ov::bfloat16>::lowest())
: (floatdata[i] > static_cast<float>(std::numeric_limits<ov::bfloat16>::max()))
? static_cast<float>(std::numeric_limits<ov::bfloat16>::max())
: floatdata[i];
}
}
}

void transferData(const IMemory& src, const IMemory& dst, bool ftz) {
void transferData(const IMemory& src, const IMemory& dst, bool ftz, bool bf16saturation) {
node::Reorder::reorderData(src, dst);

if (!ftz) {
if (!ftz && !bf16saturation) {
return;
}
if (src.getDesc().getPrecision() != ov::element::f32 || dst.getDesc().getPrecision() != ov::element::f32) {
Expand All @@ -63,7 +71,7 @@ void transferData(const IMemory& src, const IMemory& dst, bool ftz) {
// actual FTZ
auto* memData = static_cast<float*>(dst.getData());
memData += offset;
setSubnormalsToZero(memData, dst.getSize() / sizeof(float));
setSubnormalsToZeroAndbf16Saturation(memData, dst.getSize() / sizeof(float), ftz, bf16saturation);
}

} // namespace
Expand Down Expand Up @@ -126,11 +134,11 @@ void Memory::create(MemoryDescPtr desc, const void* data, bool pads_zeroing) {
}
}

void Memory::load(const IMemory& src, bool ftz) const {
void Memory::load(const IMemory& src, bool ftz, bool bf16saturation) const {
if (src.getDesc().getPrecision() == element::string) {
OPENVINO_THROW("[CPU] Memory object cannot load string data.");
}
transferData(src, *this, ftz);
transferData(src, *this, ftz, bf16saturation);
}

void Memory::nullify() {
Expand Down Expand Up @@ -272,12 +280,12 @@ StringMemory::StringMemory(dnnl::engine engine, MemoryDescPtr desc, const void*
}
}

void StringMemory::load(const IMemory& src, bool ftz) const {
void StringMemory::load(const IMemory& src, bool ftz, bool bf16saturation) const {
if (src.getDesc().getPrecision() != element::string) {
OPENVINO_THROW("[CPU] String memory cannot load a non-string object.");
}

transferData(src, *this, false);
transferData(src, *this, false, false);
}

void* StringMemory::getData() const {
Expand Down Expand Up @@ -471,11 +479,11 @@ void StaticMemory::redefineDesc(MemoryDescPtr desc) {
OPENVINO_THROW("Unexpected: Memory descriptor may not be modified in StaticMemory object");
}

void StaticMemory::load(const IMemory& src, bool ftz) const {
void StaticMemory::load(const IMemory& src, bool ftz, bool bf16saturation) const {
if (src.getDesc().getPrecision() == element::string) {
OPENVINO_THROW("[CPU] StaticMemory cannot load string data.");
}
transferData(src, *this, ftz);
transferData(src, *this, ftz, bf16saturation);
}

MemoryBlockPtr StaticMemory::getMemoryBlock() const {
Expand Down
8 changes: 4 additions & 4 deletions src/plugins/intel_cpu/src/cpu_memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class IMemory {
// Caution!!! This action invalidates the previous data layout. The old data may become unreachable.
virtual void redefineDesc(MemoryDescPtr desc) = 0;

virtual void load(const IMemory& src, bool ftz = true) const = 0;
virtual void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const = 0;

virtual MemoryBlockPtr getMemoryBlock() const = 0;

Expand Down Expand Up @@ -260,7 +260,7 @@ class StaticMemory final : public IMemory {
// Always throws since a static memory descriptor should not be modified
void redefineDesc(MemoryDescPtr desc) override;

void load(const IMemory& src, bool ftz = true) const override;
void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const override;

MemoryBlockPtr getMemoryBlock() const override;

Expand Down Expand Up @@ -315,7 +315,7 @@ class Memory : public IMemory {

void redefineDesc(MemoryDescPtr desc) override;

void load(const IMemory& src, bool ftz = true) const override;
void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const override;
void nullify() override;

dnnl::engine getEngine() const {
Expand Down Expand Up @@ -421,7 +421,7 @@ class StringMemory : public IMemory {

void redefineDesc(MemoryDescPtr desc) override;

void load(const IMemory& src, bool ftz = false) const override;
void load(const IMemory& src, bool ftz = false, bool bf16saturation = false) const override;

MemoryBlockPtr getMemoryBlock() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter {
conversion_mode mode = conversion_mode::default_mode)
: jit_emitter(host, host_isa, exec_prc),
mode_(mode) {
prepare_table();
if ((!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) &&
!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2)) ||
mode_ == conversion_mode::saturation_mode) {
prepare_table();
}
}

size_t get_inputs_num() const override {
Expand Down
10 changes: 2 additions & 8 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2895,15 +2895,9 @@ void Eltwise::prepareParams() {

// FP32 constant inputs may contain values out of BF16 representable range. In case output precision is BF16 we
// choose "saturation" mode for fp32->bf16 conversion procedure to prevent getting -Inf/+Inf values in the
// outputs. Since "saturation" conversion is more time consuming, better solution would be to clamp constants on
// compilation stage (ticket: 159589).
// outputs. Since "saturation" conversion during kernel runtime is more time consuming, current solution is
// clamp constants on compilation stage.
key.doOutputSaturation = false;
for (size_t i = 0; i < getParentEdges().size(); i++) {
if (getParentEdgeAt(i)->getParent()->isConstant()) {
key.doOutputSaturation = true;
break;
}
}

auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, buildExecutor);
Expand Down
Loading
Loading