Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

x64: brgemm kernel: update Vmm usage #2579

Merged
merged 1 commit into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 53 additions & 26 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,47 +209,74 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
// TODO: Calculating the number of available registers should be re-factored
// to use one code here and in brgemm kernel generator on
// "max_effective_vregs" calculation
int max_isa_regs = isa_num_vregs(brg->isa_impl);
const int max_bcst_regs = brg->n_bcast_1_load ? 0 : 1;
const bool req_compensation = brg->req_s8s8_compensation
|| brg->zp_type_a != brgemm_broadcast_t::none;
const int load_regs = brg->n_bcast_1_load ? 1 : adj_ld_block2;
const bool req_zp_a_comp_pads
= (brg->req_cal_comp_pads || brg->brgattr.max_top_vpad > 0
|| brg->brgattr.max_bottom_vpad > 0)
&& brg->zp_type_a != brgemm_broadcast_t::none;
const int beta_regs = !one_of(brg->beta, 1.f, 0.f);

// -------------- whole kernel --------------
// To support the f16 vnni B matrix on non-AMX we need to use two Vmm
// registers for permutation in brgemm kernel
// registers for permutation in brgemm kernel:
// see f16_perm_even_vreg_ and f16_perm_odd_vreg_ in brgemm kernel
const int b_vnni_regs = brg->is_f16_b_non_amx_vnni() ? 2 : 0;

const int max_isa_regs = isa_num_vregs(brg->isa_impl);
// non-VNNI INT8 dot product required 2 temp vectors:
// see int8_ones_words() and int8_dot_product_temp() in brgemm kernel
const int non_int8_vnni_regs
= (brg->is_int8 && !brg->has_int8_vnni) ? 2 : 0;

max_isa_regs -= b_vnni_regs + non_int8_vnni_regs;

// --------------- microkernel ---------------

// see vmm_inp_shift() in brgemm kernel
const int compensation_regs = brg->req_s8s8_compensation
|| brg->zp_type_a != brgemm_broadcast_t::none
? 1
: 0;

// see vmm_zp_a_shift(), vmm_one_bytes() in brgemm kernel
const int zp_a_comp_pads_regs = req_zp_a_comp_pads ? 2 : 0;

const int microkernel_regs = zp_a_comp_pads_regs + compensation_regs;

const auto microkernel_max_reg_count
= max_isa_regs - microkernel_regs - load_regs - max_bcst_regs;

auto microkernel_max_bcast_block
= microkernel_max_reg_count / (adj_ld_block2 + brg->n_bcast_1_load);

// ----- post-ops and store accumulators -----

const int beta_regs = !one_of(brg->beta, 1.f, 0.f);

const int postops_regs = brg->attr()
? injector::aux_vec_count(
brg->attr()->post_ops_, brg->isa_impl, true)
: 0;

// note: the 'adj_ld_block2' already removes the necessary registers
// for 'embd_bcst'
auto max_reg_count = max_isa_regs - max_bcst_regs - beta_regs
- req_compensation - req_zp_a_comp_pads - b_vnni_regs;
if (req_zp_a_comp_pads)
max_reg_count
= nstl::min(max_reg_count, max_isa_regs - max_bcst_regs - 5);

int max_bcast_block = max_reg_count
- nstl::max(brg->n_bcast_1_load ? 1 : adj_ld_block2, postops_regs);

if (brg->is_bf16_emu) {
// in theory, vmm bf16_emu register indices overlap with other vmm
// registers related to 'max_bcast_block'
assert(is_superset(brg->isa_impl, avx512_core));
constexpr int bf16_emu_reg_count = 28;
max_bcast_block = nstl::min(max_bcast_block, bf16_emu_reg_count);
}
// Emulators: fp8 emulation are supported for amx only
// In theory, vmm bf16_emu register indices overlap with other vmm
// registers related to 'max_bcast_block'
assert(IMPLICATION(
brg->is_bf16_emu, is_superset(brg->isa_impl, avx512_core)));
const int bf16_emu_regs = brg->is_bf16_emu ? 4 : 0;

const auto store_regs = nstl::max(beta_regs,
nstl::max(
postops_regs, nstl::max(compensation_regs, bf16_emu_regs)));

const auto store_max_reg_count = max_isa_regs - store_regs;

auto store_max_bcast_block = store_max_reg_count / adj_ld_block2;

// non-VNNI INT8 dot product required 2 temp vectors
if (brg->is_int8 && !brg->has_int8_vnni) max_bcast_block -= 2;
// ------------ final calculation ------------

max_bcast_block /= (adj_ld_block2 + brg->n_bcast_1_load);
const auto max_bcast_block
= nstl::min(microkernel_max_bcast_block, store_max_bcast_block);

return max_bcast_block;
}
Expand Down
90 changes: 42 additions & 48 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ struct jit_brgemm_kernel_t : public jit_base_brgemm_kernel_t {
// 'fp8_to_f16_upconvert()' param and would collision with these
// emulation vmms
f8_e5m2_emulator_ = utils::make_unique<fp8_emulation_e5m2_t>(
this, xmm_fp8_emu_aux1, xmm_fp8_emu_aux2,
xmm_fp8_emu_aux3, kmask_fp8_aux, reg64_fp8_aux);
this, vmm_fp8_emu_aux1(), vmm_fp8_emu_aux2(),
vmm_fp8_emu_aux3(), kmask_fp8_aux, reg64_fp8_aux);
if (one_of(data_type::f8_e4m3, brg.dt_a, brg.dt_b, brg.dt_c,
brg.dt_d)
|| has_f8_e4m3_binary_postops)
f8_e4m3_emulator_ = utils::make_unique<fp8_emulation_e4m3_t>(
this, xmm_fp8_emu_aux1, xmm_fp8_emu_aux2,
xmm_fp8_emu_aux3, xmm_fp8_emu_aux4, xmm_fp8_emu_aux5,
reg64_fp8_aux);
this, vmm_fp8_emu_aux1(), vmm_fp8_emu_aux2(),
vmm_fp8_emu_aux3(), vmm_fp8_emu_aux4(),
vmm_fp8_emu_aux5(), reg64_fp8_aux);
}

if (brg.with_eltwise || brg.with_binary || brg.with_sum) {
Expand Down Expand Up @@ -324,6 +324,10 @@ struct jit_brgemm_kernel_t : public jit_base_brgemm_kernel_t {
}

Vmm vmm_tail_mask() { return vmm_tmp(1); }
Vmm vmm_beta() { return vmm_tmp(1); }
Vmm vmm_lbound() { return vmm_tmp(1); }
Vmm vmm_ubound() { return vmm_tmp(0); }

Vmm vmm_one_bytes() const noexcept { return Vmm(3); }
Vmm vmm_zp_a_shift() const noexcept { return Vmm(2); }
Vmm vmm_inp_shift() const noexcept { return Vmm(1); }
Expand All @@ -336,11 +340,11 @@ struct jit_brgemm_kernel_t : public jit_base_brgemm_kernel_t {
// note: zmm reserv_5 is not necessary since it's only used for 'vdpbf16ps'

// fp8 emulation convert
Vmm xmm_fp8_emu_aux1 = Vmm(1);
Vmm xmm_fp8_emu_aux2 = Vmm(2);
Vmm xmm_fp8_emu_aux3 = Vmm(3);
Vmm xmm_fp8_emu_aux4 = Vmm(4);
Vmm xmm_fp8_emu_aux5 = Vmm(5);
Vmm vmm_fp8_emu_aux1() const noexcept { return Vmm(1); }
Vmm vmm_fp8_emu_aux2() const noexcept { return Vmm(2); }
Vmm vmm_fp8_emu_aux3() const noexcept { return Vmm(3); }
Vmm vmm_fp8_emu_aux4() const noexcept { return Vmm(4); }
Vmm vmm_fp8_emu_aux5() const noexcept { return Vmm(5); }

Zmm zmm_tmp_1() const noexcept { return Zmm(1); }

Expand All @@ -352,8 +356,12 @@ struct jit_brgemm_kernel_t : public jit_base_brgemm_kernel_t {
return Vmm(isa_num_vregs(brg.isa_impl) - 2);
}

Zmm f16_perm_even_vreg_ = Zmm(isa_num_vregs(brg.isa_impl) - 1);
Zmm f16_perm_odd_vreg_ = Zmm(isa_num_vregs(brg.isa_impl) - 2);
Vmm f16_perm_even_vreg() const noexcept {
return Vmm(isa_num_vregs(brg.isa_impl) - 1);
}
Vmm f16_perm_odd_vreg() const noexcept {
return Vmm(isa_num_vregs(brg.isa_impl) - 2);
}

Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store,
Xbyak::Opmask ktail_mask) const;
Expand Down Expand Up @@ -1094,11 +1102,10 @@ void jit_brgemm_kernel_t<Wmm>::apply_alpha_beta(
const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;
const bool need_init_beta_vmm = brg.beta != 1.f;
auto vmm_prev_dst = vmm_tmp(0);
auto vmm_beta = vmm_tail_mask();
if (need_init_beta_vmm) {
mov(reg_tmp_gpr, float2int(static_cast<float>(brg.beta)));
uni_vmovq(Xmm(vmm_beta.getIdx()), reg_tmp_gpr);
uni_vbroadcastss(vmm_beta, Xmm(vmm_beta.getIdx()));
uni_vmovq(Xmm(vmm_beta().getIdx()), reg_tmp_gpr);
uni_vbroadcastss(vmm_beta(), Xmm(vmm_beta().getIdx()));
}

if (brg.is_runtime_ldc && bd_block > 1)
Expand Down Expand Up @@ -1133,7 +1140,7 @@ void jit_brgemm_kernel_t<Wmm>::apply_alpha_beta(
if (brg.beta == 1.f)
uni_vaddps(vmm, vmm, vmm_prev_dst);
else
uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_beta);
uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_beta());
}
if (brg.is_runtime_ldc && bd_block > 1 && ld == ld_block2 - 1)
add(reg_aux_C, ptr[rsp + reg_C_shift_bytes_offs_]);
Expand Down Expand Up @@ -1383,16 +1390,14 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(

const bool dt_requires_saturation
= one_of(brg.dt_d, data_type::u8, data_type::s8, data_type::s32);
auto vmm_lbound = vmm_tail_mask();
auto vmm_ubound = vmm_tmp(0);
assert(vmm_lbound.getIdx() != vmm_ubound.getIdx());
assert(vmm_lbound().getIdx() != vmm_ubound().getIdx());
if (dt_requires_saturation) {
init_saturate_f32(
vmm_lbound, vmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d);
init_saturate_f32(vmm_lbound(), vmm_ubound(), reg_tmp_gpr,
data_type::f32, brg.dt_d);
for (int bd = 0; bd < bd_block; bd++) {
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm = accm(ld_block2, bd, ld);
saturate_cvt_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d);
saturate_cvt_f32(vmm, vmm_lbound(), vmm_ubound(), brg.dt_d);
}
}
// below call is not required as s32 doesn't use vmm_lbound
Expand Down Expand Up @@ -1563,14 +1568,12 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_without_post_ops(
&& !IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd);

if (dt_requires_saturation) {
auto vmm_ubound = vmm_tmp(0);
auto vmm_lbound = vmm_tmp(1);
init_saturate_f32(
vmm_lbound, vmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d);
init_saturate_f32(vmm_lbound(), vmm_ubound(), reg_tmp_gpr,
data_type::f32, brg.dt_d);
for (int bd = 0; bd < bd_block; bd++) {
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm = accm(ld_block2, bd, ld);
saturate_cvt_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d);
saturate_cvt_f32(vmm, vmm_lbound(), vmm_ubound(), brg.dt_d);
}
}
// below call is not required as s32 doesn't use vmm_lbound
Expand Down Expand Up @@ -2157,7 +2160,7 @@ void jit_brgemm_kernel_t<Wmm>::compute_int8_compensation(int rd_loop, int bd_b,
}
};

if (brg.n_bcast_1_load && brg.zp_type_a != brgemm_broadcast_t::none) {
if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
const auto reg32_scratch = reg_zp_a_input_shift.cvt32();
mov(reg32_scratch, 0x1010101);
Expand Down Expand Up @@ -2216,6 +2219,13 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
} else
rd_loop = brg.rd_block;

if (brg.req_s8s8_compensation) {
mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
mov(reg_s8_input_shift, 128);
uni_vpbroadcastb(vmm_inp_shift(), reg_s8_input_shift.cvt8());
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
}

auto broadcast_A = [this, rd_tail_size, is_rd_tail, rd_loop,
rows_for_rd_tail,
bd_e](Vmm vmm_bcast, int bd, int rd) {
Expand Down Expand Up @@ -2279,9 +2289,9 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
const auto vnni_addr = ptr[reg_aux_B + actual_B_offset];
vmovups(vmm_load, vnni_addr);
if (rd % 2 == 0)
vpermw(vmm_load, f16_perm_even_vreg_, vmm_load);
vpermw(vmm_load, f16_perm_even_vreg(), vmm_load);
else
vpermw(vmm_load, f16_perm_odd_vreg_, vmm_load);
vpermw(vmm_load, f16_perm_odd_vreg(), vmm_load);
vcvtph2psx(vmm_load, Vmm_lower_t(vmm_load.getIdx()));
} else if (is_ld_tail && !is_superset(brg.isa_impl, avx512_core)) {
load_bytes(vmm_load, addr, ldb_B_offset(0, true));
Expand Down Expand Up @@ -2443,22 +2453,6 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
mov(reg_stride_ldb, brg.rd_step * brg.typesize_B * brg.LDB);
}

if (brg.req_s8s8_compensation) {
mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
mov(reg_s8_input_shift, 128);
uni_vpbroadcastb(vmm_inp_shift(), reg_s8_input_shift.cvt8());
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
}
if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
const auto reg32_scratch = reg_zp_a_input_shift.cvt32();
mov(reg32_scratch, 0x1010101);
uni_vpbroadcastd(vmm_one_bytes(), reg32_scratch);
mov(reg32_scratch, ptr[rsp + reg_zp_a_val_offs_]);
uni_vpbroadcastd(vmm_zp_a_shift(), reg32_scratch);
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
}

if (brg.brgattr.max_bs > 1) mov(reg_BS_loop, reg_BS);
L_aligned(BS_loop_label, 64);
{
Expand Down Expand Up @@ -2844,9 +2838,9 @@ void jit_brgemm_kernel_t<Wmm>::generate() {

if (brg.is_f16_b_non_amx_vnni()) {
mov(reg_tmp_gpr, f16_perm_even_table_);
vmovups(f16_perm_even_vreg_, ptr[reg_tmp_gpr]);
vmovups(f16_perm_even_vreg(), ptr[reg_tmp_gpr]);
mov(reg_tmp_gpr, f16_perm_odd_table_);
vmovups(f16_perm_odd_vreg_, ptr[reg_tmp_gpr]);
vmovups(f16_perm_odd_vreg(), ptr[reg_tmp_gpr]);
}

if (brg.is_tmm && brg.amx_wary_k_tail()) {
Expand Down