diff --git a/src/cpu/x64/brgemm/brgemm_utils.cpp b/src/cpu/x64/brgemm/brgemm_utils.cpp index 5ee572a68a4..223bf03b9a0 100644 --- a/src/cpu/x64/brgemm/brgemm_utils.cpp +++ b/src/cpu/x64/brgemm/brgemm_utils.cpp @@ -234,9 +234,6 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) { if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2; max_bcast_block /= adj_ld_block2; - if (brg->with_src_dyn_quant) { - max_bcast_block /= 2; - } return max_bcast_block; } diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index aa7bce903ae..6cfdafec715 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -2298,11 +2298,6 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift()); }; - auto vmm_accm_tmp = [&](int ld_block, int bd, int ld) { - int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld_block - (bd * ld_block + ld); - return Vmm(idx); - }; - auto vmm_zero_point = [&](int ld) { int idx = isa_num_vregs(brg.isa_impl) - 3 - ld; return Vmm(idx); @@ -2368,9 +2363,14 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]); + const int vec_size = vreg_traits::vlen; + auto accums_stack_space = bd_e * ld_block2 * vec_size; + sub(rsp, accums_stack_space); for (int bd = bd_b; bd < bd_e; bd++) { for (int ld = 0; ld < ld_block2; ld++) { - auto vmm_accm = vmm_accm_tmp(ld_block2, bd, ld); + auto vmm_accm = accm(ld_block2, bd, ld); + vmovups(ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm); + uni_vxorps(vmm_accm, vmm_accm, vmm_accm); } } @@ -2409,14 +2409,14 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, + brg.LDB * brg.rd_block * brg.typesize_B]); } for (int ld = 0; ld < ld_block2; ld++) { - auto vmm = vmm_accm_tmp(ld_block2, bd, ld); + auto vmm = accm(ld_block2, bd, ld); vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding); } if (brg.with_wei_decomp_zero_points) { uni_vpxor(bcst(), bcst(), vmm_neg_one); uni_vpsubb(bcst(), bcst(), vmm_neg_one); for (int ld = 0; ld < ld_block2; ld++) { - auto vmm = vmm_accm_tmp(ld_block2, bd, ld); + auto vmm = accm(ld_block2, bd, ld); Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld); vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding); } @@ -2426,7 +2426,7 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, auto reg_local_src_scales = reg_local_wei_zp; auto vmm_src_scales = bcst(); - mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_]); + mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]); for (int bd = bd_b; bd < bd_e; bd++) { uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]); @@ -2438,15 +2438,17 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, } } for (int ld = 0; ld < ld_block2; ld++) { - auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld); auto vmm_accm = accm(ld_block2, bd, ld); - uni_vcvtdq2ps(vmm_accm_aux, vmm_accm_aux); - uni_vmulps(vmm_accm_aux, vmm_accm_aux, vmm_src_scales); - uni_vfmadd231ps(vmm_accm, vmm_accm_aux, load(ld)); + uni_vcvtdq2ps(vmm_accm, vmm_accm); + uni_vmulps(vmm_accm, vmm_accm, vmm_src_scales); + uni_vmulps(load(ld), vmm_accm, load(ld)); + uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]); + uni_vaddps(vmm_accm, vmm_accm, load(ld)); } } + add(rsp, accums_stack_space); mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]); mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]); diff --git a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp index c1a247c241c..929bf649867 100644 --- a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp @@ -1441,11 +1441,6 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, jbgp.wei_zero_points_ic_group_size = div_up(jbgp.ic, attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[1]); } - // todo: fix avx2 brgemm kernel behavior for non scalar zp - if (!is_superset(isa, avx512_core) && attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[0] != 1) { - jbgp.with_src_dynamic_quant = false; - } - jbgp.wei_decomp_zero_points_dt = attr.zero_points_.get_data_type(DNNL_ARG_WEIGHTS); if (!one_of(jbgp.wei_decomp_zero_points_dt, f32, u8)) return status::unimplemented;