diff --git a/src/gpu/intel/ocl/micro_sdpa.cl b/src/gpu/intel/ocl/micro_sdpa.cl
index 34d52331d43..350d5fef592 100644
--- a/src/gpu/intel/ocl/micro_sdpa.cl
+++ b/src/gpu/intel/ocl/micro_sdpa.cl
@@ -223,6 +223,9 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
     A += DST_OFF(b1, b0, 0, 0, 0);
 #if WITH_ATTN_MASK
     msk += MSK_OFF(b1 % MSK_D0, b0 % MSK_D1, 0, 0);
+#ifndef BLOCK_MSK
+    int mask_aligned = (((size_t)msk) % 4) == 0;
+#endif
 #endif
 
 #if KEY_SCALES
@@ -320,9 +323,17 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
         /* Load mask. No remainder handling needed assuming k block size is a
            power of 2. */
         mask_tile_type mask_tile;
 #if BROADCAST_MASK_Q
+#if BLOCK_MSK
         tile_load_block(&mask_tile, msk, 0, k0 + sg_i0_kq, 0);
 #else
-        tile_load_t(&mask_tile, msk, q, k, q, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
+        if (mask_aligned) {
+            tile_load_block(&mask_tile, msk, 0, k0 + sg_i0_kq, 0);
+        } else {
+            tile_load_full(&mask_tile, msk, 0, k0 + sg_i0_kq, 0);
+        }
+#endif
+#else
+        tile_load_t(&mask_tile, msk, q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
 #endif
 #endif
diff --git a/src/gpu/intel/ocl/micro_sdpa.cpp b/src/gpu/intel/ocl/micro_sdpa.cpp
index 0555ee2d1cc..a02f45706a3 100644
--- a/src/gpu/intel/ocl/micro_sdpa.cpp
+++ b/src/gpu/intel/ocl/micro_sdpa.cpp
@@ -264,7 +264,10 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
     problem_kq.B.layout = MatrixLayout::Pr;
     problem_kq.C.layout = MatrixLayout::T;
-    problem_kq.A.setAlignment(alignmentForLD(d->head_size() * problem.Ta));
+    const memory_desc_wrapper key_mdw(key_md());
+    auto ldk = static_cast<int>(
+            gemm_desc_t::get_ld(*key_md()) * key_mdw.data_type_size());
+    problem_kq.A.setAlignment(alignmentForLD(ldk));
     problem_kq.B.setAlignment(64); // Q is packed in VNNI format in SLM
     problem_kq.B.crosspack = 2;
     problem_kq.B.tileR = into<uint16_t>(d_max());
 
@@ -331,7 +334,10 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
     problem_vs.B.layout = MatrixLayout::Pr;
     problem_vs.C.layout = MatrixLayout::N;
-    problem_vs.A.setAlignment(alignmentForLD(d->head_size() * problem.Ta));
+    const memory_desc_wrapper val_mdw(val_md());
+    auto ldv = static_cast<int>(
+            gemm_desc_t::get_ld(*val_md()) * val_mdw.data_type_size());
+    problem_vs.A.setAlignment(alignmentForLD(ldv));
     problem_vs.B.setAlignment(64); // S is packed in SLM
     problem_vs.B.crosspack = 16;
 
     sizes.m = d->values();
@@ -407,6 +413,7 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
     auto ldk = gemm_desc_t::get_ld(*pd()->key_md()) * key_mdw.data_type_size();
     auto ldv = gemm_desc_t::get_ld(*pd()->val_md()) * val_mdw.data_type_size();
     auto lda = gemm_desc_t::get_ld(*pd()->dst_md()) * dst_mdw.data_type_size();
+    auto ldmsk = pd()->attn_mask_md()->dims[3] * msk_mdw.data_type_size();
     kernel_ctx.define_int("Q_ALIGN", jit::alignmentForLD(int(ldq)));
     kernel_ctx.define_int("K_ALIGN", jit::alignmentForLD(int(ldk)));
     kernel_ctx.define_int("V_ALIGN", jit::alignmentForLD(int(ldv)));
@@ -477,6 +484,7 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
     if (d_full) {
         if (ldq % 4 == 0) kernel_ctx.define_int("BLOCK_Q", 1);
         if (lda % 4 == 0 && v_full) kernel_ctx.define_int("BLOCK_A", 1);
+        if (ldmsk % 4 == 0) kernel_ctx.define_int("BLOCK_MSK", 1);
         kernel_ctx.define_int("REMAINDER_Q", (d->queries() % tile_q) != 0);
     } else if (pd()->arch() >= compute::gpu_arch_t::xe_hpc) {
         auto vbytes = d->values() * val_mdw.data_type_size();
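
Note on the alignment logic this patch introduces: BLOCK_MSK is defined at build time only when the mask's leading dimension in bytes (ldmsk) is a multiple of 4, which guarantees every mask row starts on a 4-byte boundary and lets the kernel use tile_load_block unconditionally. When BLOCK_MSK is not defined, the kernel instead tests the actual msk pointer at run time (mask_aligned), after MSK_OFF has been applied, and falls back to tile_load_full for unaligned rows. The sketch below is a standalone illustration of that two-level check, not oneDNN code; alignment_for_ld is a hypothetical stand-in for gemmstone's alignmentForLD, and the stride and address values are made-up examples.

    // Standalone C++ sketch (hypothetical, not oneDNN code) of the
    // build-time vs. run-time alignment decisions made by this patch.
    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for alignmentForLD(): the largest power-of-two
    // byte alignment (capped at 64) that divides the leading dimension.
    static int alignment_for_ld(int ld_bytes) {
        int align = 1;
        while (align < 64 && ld_bytes % (align * 2) == 0)
            align *= 2;
        return align;
    }

    int main() {
        // Build-time decision: an f16 mask with 77 elements per row has a
        // 154-byte stride, so BLOCK_MSK would not be defined (154 % 4 != 0).
        int ldmsk = 77 * 2;
        bool block_msk = (ldmsk % 4 == 0);

        // Run-time decision mirroring `mask_aligned` in micro_sdpa.cl: with
        // an odd row stride, successive rows alternate between 4-byte-aligned
        // and unaligned start addresses, so each pointer must be checked.
        const std::uintptr_t msk_addr = 0x1002; // example address
        bool mask_aligned = (msk_addr % 4) == 0;

        std::printf("ld alignment=%d block_msk=%d mask_aligned=%d\n",
                alignment_for_ld(ldmsk), block_msk, mask_aligned);
        return 0;
    }

In this example both checks fail, so a kernel built this way would take the tile_load_full path; with ldmsk a multiple of 4, BLOCK_MSK would be defined and the run-time check compiled out entirely.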