Skip to content

Commit

Permalink
cpu: x64: matmul: fix handling cases when M == 1
Browse files Browse the repository at this point in the history
  • Loading branch information
densamoilov committed Feb 1, 2025
1 parent b20dcfc commit d891dd6
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_brgemm_post_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ dnnl::impl::cpu::x64::jit_brgemm_kernel_diff_bias_t<Vmm>::
// Only reduction for `src` is supported.
assert(reduce_kind_ == matmul_reduce_kind::src);
// `src` matrix is assumed to have a row major layout.
assert(bgmmc.treat_transposed_A_as_plain || bgmmc.use_buffer_a);
assert(bgmmc.treat_A_as_plain || bgmmc.use_buffer_a);
}

template <typename Vmm>
Expand Down
18 changes: 8 additions & 10 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ struct matmul_avx512_blocking_params_t {
bgmmc.use_buffer_c = is_buffer_c_required(
bgmmc.acc_dt, bgmmc.dst_dt, bgmmc.with_sum);
bgmmc.LDA = bgmmc.adjust_a_strides || bgmmc.use_buffer_a
|| bgmmc.treat_transposed_A_as_plain
|| bgmmc.treat_A_as_plain
? get_actual_lda(bgmmc.use_buffer_a, bgmmc.tr_a_dt_sz)
: bgmmc.A_strides[1] / bgmmc.a_dt_sz;
}
Expand Down Expand Up @@ -1498,16 +1498,16 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
VERBOSE_UNSUPPORTED_TAG);

const bool transposed_A = bm_conf_utils.check_is_transposed(bgmmc.src_tag);
// if M == 1 we can still treat formally transposed A as plain
// and avoid copy routine creation/execution
bgmmc.treat_transposed_A_as_plain = transposed_A && bgmmc.M == 1;
bgmmc.transposed_A = ((transposed_A && !bgmmc.treat_transposed_A_as_plain)
// When M == 1 MatMul always considers A to be non-transposed even if A md
// was created using "ba" tag.
bgmmc.treat_A_as_plain = bgmmc.M == 1;
bgmmc.transposed_A = ((transposed_A && !bgmmc.treat_A_as_plain)
|| bgmmc.src_tag == adbc);
// For batched problems with plain A and C and fully broadcasted across B
// we can merge all the batch dimensions into M if broadcast strategies
// set is limited for binary post-ops
const bool plain_A_layout = bm_conf_utils.check_is_plain(bgmmc.src_tag)
|| bgmmc.treat_transposed_A_as_plain;
|| bgmmc.treat_A_as_plain;
const bool merge_batch_dims_into_M = bgmmc.batch > 1
&& bgmmc.bcast_B_desc.bcast_across_all_batch_dims && plain_A_layout
&& helper.is_src_dst_layout_batch_fusable()
Expand Down Expand Up @@ -1615,7 +1615,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
// We need to correct A_strides if batched dimensions are merged in M and
// A layout is formally transposed but could be treated as plain
bgmmc.adjust_a_strides = merge_batch_dims_into_M
&& (src_d.matches_tag(acbd) || bgmmc.treat_transposed_A_as_plain);
&& (src_d.matches_tag(acbd) || bgmmc.treat_A_as_plain);
if (bgmmc.adjust_a_strides) bgmmc.A_strides[1] = bgmmc.A_strides[2];

// We need to correct C_strides if batched dimensions are merged in M and
Expand Down Expand Up @@ -2199,9 +2199,7 @@ void matmul_amx_blocking_params_t::update_configuration(

dim_t matmul_amx_blocking_params_t::get_actual_lda() {
if (!use_buffer_a_)
return treat_transposed_A_as_plain
? K
: A_strides[1 - transposed_A] / a_dt_sz;
return treat_A_as_plain ? K : A_strides[1 - transposed_A] / a_dt_sz;

constexpr int bytes_in_cacheline = 64;
const int elems_in_cacheline = bytes_in_cacheline / a_dt_sz;
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/matmul/brgemm_matmul_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ struct brgemm_matmul_conf_t {
bool transposed_A;
bool transposed_B;
bool blocked_B;
bool treat_transposed_A_as_plain;
bool treat_A_as_plain;

// A_strides could be changed during
// Matmul conf initialization in case when batches merged into M.
Expand Down

0 comments on commit d891dd6

Please sign in to comment.