Skip to content

Commit

Permalink
x64: jit_generator: replace encoding conditionals with get_encoding()
Browse files Browse the repository at this point in the history
  • Loading branch information
tczeszun committed Jan 29, 2025
1 parent bb2ecca commit b94b35b
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 47 deletions.
15 changes: 3 additions & 12 deletions src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,10 +513,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::store_accumulators_apply_post_ops(
if (brg.is_bf16_emu)
bf16_emu_->vcvtneps2bf16(vmm_low, vmm);
else
vcvtneps2bf16(vmm_low, vmm,
brg.isa_impl == avx2_vnni_2
? Xbyak::VexEncoding
: Xbyak::EvexEncoding);
vcvtneps2bf16(vmm_low, vmm, get_encoding());
if (mask_flag)
vmovdqu16(addr, r_vmm_low);
else
Expand Down Expand Up @@ -827,10 +824,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::comp_dot_product(
bool is_tail_block) {
switch (kernel_type) {
case compute_pad_kernel_t::s8s8_kernel:
vpdpbusd(vmm_acc, vmm_shift(), vmmb,
is_superset(brg.isa_impl, avx512_core)
? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
vpdpbusd(vmm_acc, vmm_shift(), vmmb, get_encoding());
break;
case compute_pad_kernel_t::zero_point_kernel: {
const Vmm vmm_zp = isa_has_masks(brg.isa_impl)
Expand Down Expand Up @@ -1002,10 +996,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::brdgmm_microkernel(int m_blocks,
if (brg.dt_a == data_type::s8 && isa_has_s8s8(brg.isa_impl))
vpdpbssd(vmm_acc, vmma, vmmb);
else
vpdpbusd(vmm_acc, vmma, vmmb,
is_superset(brg.isa_impl, avx512_core)
? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
vpdpbusd(vmm_acc, vmma, vmmb, get_encoding());
}
};

Expand Down
4 changes: 1 addition & 3 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2102,9 +2102,7 @@ void jit_brgemm_kernel_t<Wmm>::dot_product(Vmm v1, Vmm v2, Vmm v3) {
if (brg.dt_a == data_type::s8 && isa_has_s8s8(brg.isa_impl))
vpdpbssd(v1, v3, v2);
else if (brg.has_int8_vnni)
vpdpbusd(v1, v3, v2,
is_superset(brg.isa_impl, avx512_core) ? EvexEncoding
: VexEncoding);
vpdpbusd(v1, v3, v2, get_encoding());
else {
vpmaddubsw(int8_dot_product_temp(), v3, v2);
vpmaddwd(int8_dot_product_temp(), int8_dot_product_temp(),
Expand Down
7 changes: 2 additions & 5 deletions src/cpu/x64/jit_brgemm_conv_comp_pad_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022-2023 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -203,10 +203,7 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t<Vmm>::compute(const int ic_step,
? EVEX_compress_addr(reg_aux_in, oc_offset)
: ptr[reg_aux_in + oc_offset];
if (jcp_.has_int8_vnni) {
vpdpbusd(vmm, vmm_one_bytes, addr,
is_superset(jcp_.isa, avx512_core)
? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
vpdpbusd(vmm, vmm_one_bytes, addr, get_encoding());
} else {
vpmaddubsw(zmm_int8_temp, vmm_one_bytes, addr);
vpmaddwd(zmm_int8_temp, zmm_int8_temp, zmm_one_words);
Expand Down
14 changes: 11 additions & 3 deletions src/cpu/x64/jit_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,16 @@ class jit_generator : public Xbyak::MmapAllocator,
}
}

// The function returns type of encoding (Evex or Vex) depending on the
// system ISA. It's designed to be used with instructions that require
// specific encoding when both encodings are supported on the system.
// Evex would be preferred over Vex when possible.
// The assumption is that both encoding mnemonics are supported by the
// hardware for `avx512_core+` systems.
Xbyak::PreferredEncoding get_encoding() {
return mayiuse(avx512_core) ? Xbyak::EvexEncoding : Xbyak::VexEncoding;
}

// Disallow char-based labels completely
void L(const char *label) = delete;
void L(Xbyak::Label &label) { Xbyak::CodeGenerator::L(label); }
Expand Down Expand Up @@ -2574,9 +2584,7 @@ class jit_generator : public Xbyak::MmapAllocator,
store_bytes(vmm, reg, offset, store_size);
break;
case data_type::bf16:
vcvtneps2bf16(xmm, vmm,
is_valid_isa(avx512_core_bf16) ? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
vcvtneps2bf16(xmm, vmm, get_encoding());
store_bytes(vmm, reg, offset, sizeof(bfloat16_t) * store_size);
break;
case data_type::f16:
Expand Down
6 changes: 2 additions & 4 deletions src/cpu/x64/jit_uni_batch_normalization.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2017-2024 Intel Corporation
* Copyright 2017-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -623,9 +623,7 @@ struct jit_bnorm_t : public jit_generator {

// convert f32 output to bf16
if (!use_bf16_emulation())
vcvtneps2bf16(dst_reg, src_reg,
mayiuse(avx512_core) ? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
vcvtneps2bf16(dst_reg, src_reg, get_encoding());
else
bf16_emu_->vcvtneps2bf16(dst_reg, src_reg);

Expand Down
6 changes: 2 additions & 4 deletions src/cpu/x64/jit_uni_deconv_zp_pad_str_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -144,9 +144,7 @@ void jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::compute_step(
if (jcp_.is_depthwise)
uni_vpaddd(result_acc_, result_acc_, wei_vmm);
else if (jcp_.has_vnni)
vpdpbusd(result_acc_, vmm_one_bytes_, wei_vmm,
is_superset(isa, avx512_core) ? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
vpdpbusd(result_acc_, vmm_one_bytes_, wei_vmm, get_encoding());
else {
uni_vpmaddubsw(vmm_tmp_, vmm_one_bytes_, wei_vmm);
uni_vpmaddwd(vmm_tmp_, vmm_tmp_, vmm_one_words_);
Expand Down
6 changes: 2 additions & 4 deletions src/cpu/x64/jit_uni_tbb_batch_normalization.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -429,9 +429,7 @@ struct helper_vmovups_data_t {

// convert f32 output to bf16
if (!bf16_emu_)
h_->vcvtneps2bf16(dst_reg, src_reg,
mayiuse(avx512_core) ? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
h_->vcvtneps2bf16(dst_reg, src_reg, h_->get_encoding());
else
bf16_emu_->vcvtneps2bf16(dst_reg, src_reg);

Expand Down
9 changes: 3 additions & 6 deletions src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,7 @@ struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t,
void copy_M_loop(bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter);
inline void dot_product(Vmm v1, Vmm v2, Vmm v3) {
if (!avx512_core_dot_product_)
vpdpbusd(v1, v2, v3,
mayiuse(avx512_core) ? EvexEncoding : VexEncoding);
vpdpbusd(v1, v2, v3, get_encoding());
else {
vpmaddubsw(vmm_dot_product_temp, v2, v3);
vpmaddwd(
Expand Down Expand Up @@ -2168,8 +2167,7 @@ struct jit_brgemm_matmul_copy_b_int8_t : public jit_brgemm_matmul_copy_b_t,
virtual void copy_4x64(int nrows, int ncolumns, bool zeropad) {}
inline void dot_product(Vmm v1, Vmm v2, Vmm v3) {
if (!avx512_core_dot_product_)
vpdpbusd(v1, v2, v3,
mayiuse(avx512_core) ? EvexEncoding : VexEncoding);
vpdpbusd(v1, v2, v3, get_encoding());
else {
vpmaddubsw(vmm_dot_product_temp, v2, v3);
vpmaddwd(
Expand Down Expand Up @@ -3810,8 +3808,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t

inline void dot_product(Vmm v1, Vmm v2, Vmm v3) {
if (!avx512_core_dot_product_)
vpdpbusd(v1, v2, v3,
mayiuse(avx512_core) ? EvexEncoding : VexEncoding);
vpdpbusd(v1, v2, v3, get_encoding());
else {
vpmaddubsw(vmm_dot_product_temp, v2, v3);
vpmaddwd(
Expand Down
8 changes: 2 additions & 6 deletions src/cpu/x64/utils/jit_io_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,7 @@ void jit_io_helper_t<Vmm>::prepare_xf16_data_to_store(const Vmm &vmm) {
typename vreg_traits<Vmm>::Vmm_lower_t(vmm.getIdx());

if (data_type_ == data_type::bf16)
host_->vcvtneps2bf16(cvt_lower_vmm, vmm,
mayiuse(avx512_core) ? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
host_->vcvtneps2bf16(cvt_lower_vmm, vmm, host_->get_encoding());
else
host_->uni_vcvtps2phx(cvt_lower_vmm, vmm);
}
Expand Down Expand Up @@ -852,9 +850,7 @@ void jit_io_helper_t<Vmm>::store_bf16(
if (bf16_emu_)
bf16_emu_->vcvtneps2bf16(cvt_lower_vmm, src_vmm);
else
host_->vcvtneps2bf16(cvt_lower_vmm, src_vmm,
mayiuse(avx512_core) ? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
host_->vcvtneps2bf16(cvt_lower_vmm, src_vmm, host_->get_encoding());

if (io_conf_.nt_stores_enabled_)
host_->uni_vmovntps(dst_addr, cvt_lower_vmm);
Expand Down

0 comments on commit b94b35b

Please sign in to comment.