diff --git a/src/cpu/x64/jit_brgemm_post_ops.cpp b/src/cpu/x64/jit_brgemm_post_ops.cpp
index 1ea3558c533..f5e5c07f26f 100644
--- a/src/cpu/x64/jit_brgemm_post_ops.cpp
+++ b/src/cpu/x64/jit_brgemm_post_ops.cpp
@@ -70,7 +70,7 @@ dnnl::impl::cpu::x64::jit_brgemm_kernel_diff_bias_t::
     // Only reduction for `src` is supported.
     assert(reduce_kind_ == matmul_reduce_kind::src);
     // `src` matrix is assumed to have a row major layout.
-    assert(bgmmc.treat_transposed_A_as_plain || bgmmc.use_buffer_a);
+    assert(bgmmc.treat_A_as_plain || bgmmc.use_buffer_a);
 }
 
 template
diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
index 9e5b6021c01..c6d6710860f 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -741,7 +741,7 @@ struct matmul_avx512_blocking_params_t {
         bgmmc.use_buffer_c = is_buffer_c_required(
                 bgmmc.acc_dt, bgmmc.dst_dt, bgmmc.with_sum);
         bgmmc.LDA = bgmmc.adjust_a_strides || bgmmc.use_buffer_a
-                        || bgmmc.treat_transposed_A_as_plain
+                        || bgmmc.treat_A_as_plain
                 ? get_actual_lda(bgmmc.use_buffer_a, bgmmc.tr_a_dt_sz)
                 : bgmmc.A_strides[1] / bgmmc.a_dt_sz;
     }
@@ -1498,16 +1498,16 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
             VERBOSE_UNSUPPORTED_TAG);
 
     const bool transposed_A = bm_conf_utils.check_is_transposed(bgmmc.src_tag);
-    // if M == 1 we can still treat formally transposed A as plain
-    // and avoid copy routine creation/execution
-    bgmmc.treat_transposed_A_as_plain = transposed_A && bgmmc.M == 1;
-    bgmmc.transposed_A = ((transposed_A && !bgmmc.treat_transposed_A_as_plain)
+    // When M == 1 MatMul always considers A to be non-transposed even if A md
+    // was created using "ba" tag.
+    bgmmc.treat_A_as_plain = bgmmc.M == 1;
+    bgmmc.transposed_A = ((transposed_A && !bgmmc.treat_A_as_plain)
             || bgmmc.src_tag == adbc);
     // For batched problems with plain A and C and fully broadcasted across B
     // we can merge all the batch dimensions into M if broadcast strategies
     // set is limited for binary post-ops
     const bool plain_A_layout = bm_conf_utils.check_is_plain(bgmmc.src_tag)
-            || bgmmc.treat_transposed_A_as_plain;
+            || bgmmc.treat_A_as_plain;
     const bool merge_batch_dims_into_M = bgmmc.batch > 1
             && bgmmc.bcast_B_desc.bcast_across_all_batch_dims && plain_A_layout
             && helper.is_src_dst_layout_batch_fusable()
@@ -1615,7 +1615,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     // We need to correct A_strides if batched dimensions are merged in M and
     // A layout is formally transposed but could be treated as plain
     bgmmc.adjust_a_strides = merge_batch_dims_into_M
-            && (src_d.matches_tag(acbd) || bgmmc.treat_transposed_A_as_plain);
+            && (src_d.matches_tag(acbd) || bgmmc.treat_A_as_plain);
     if (bgmmc.adjust_a_strides) bgmmc.A_strides[1] = bgmmc.A_strides[2];
 
     // We need to correct C_strides if batched dimensions are merged in M and
@@ -2199,9 +2199,7 @@ void matmul_amx_blocking_params_t::update_configuration(
 
 dim_t matmul_amx_blocking_params_t::get_actual_lda() {
     if (!use_buffer_a_)
-        return treat_transposed_A_as_plain
-                ? K
-                : A_strides[1 - transposed_A] / a_dt_sz;
+        return treat_A_as_plain ? K : A_strides[1 - transposed_A] / a_dt_sz;
 
     constexpr int bytes_in_cacheline = 64;
     const int elems_in_cacheline = bytes_in_cacheline / a_dt_sz;
diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp
index 006e7f1c1ff..d8fd035159b 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp
@@ -180,7 +180,7 @@ struct brgemm_matmul_conf_t {
     bool transposed_A;
     bool transposed_B;
     bool blocked_B;
-    bool treat_transposed_A_as_plain;
+    bool treat_A_as_plain;
 
     // A_strides could be changed during
     // Matmul conf initialization in case when batches merged into M.
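
Illustrative sketch (standalone, not part of the patch above): the new comment says that when M == 1 MatMul always treats A as non-transposed, even if the A memory descriptor was created with the "ba" tag. The reason is that a 1 x K matrix stored row-major ("ab", lda == K) and the same matrix stored transposed ("ba", lda == M == 1) address exactly the same bytes, so no copy routine is needed. The helpers below (plain_offset, transposed_offset) are hypothetical, for demonstration only; they are not oneDNN APIs.

#include <cassert>
#include <cstddef>

// Offset of element (m, k) in a plain M x K matrix ("ab" tag, lda == K).
static size_t plain_offset(size_t m, size_t k, size_t K) {
    return m * K + k;
}

// Offset of element (m, k) when A is stored transposed as K x M
// ("ba" tag, lda == M).
static size_t transposed_offset(size_t m, size_t k, size_t M) {
    return k * M + m;
}

int main() {
    const size_t M = 1, K = 8;
    // With M == 1 both layouts collapse to the same linear addressing,
    // which is what lets the config unconditionally set
    // treat_A_as_plain = (M == 1) instead of requiring a transposed tag.
    for (size_t k = 0; k < K; ++k)
        assert(plain_offset(0, k, K) == transposed_offset(0, k, M));
    return 0;
}

This equivalence is also why get_actual_lda() can return K for the treat_A_as_plain case: the single row is contiguous regardless of which tag described it.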