diff --git a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
index 87ebb9c4e..d345e893b 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -148,9 +148,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
 
   MTL::Size gridSize
-  (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]),
-   hash.Hq,
-   attentionDesc.batchDimension);
+  (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
   MTL::Size groupSize
   (int64_t(kernel->threadgroupSize), 1, 1);
 
@@ -239,9 +237,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
 
   MTL::Size backwardQueryGridSize
-  (ceilDivide(int64_t(hash.R), backwardQueryKernel->blockDimensions[0]),
-   hash.Hq,
-   attentionDesc.batchDimension);
+  (ceilDivide(int64_t(hash.R), backwardQueryKernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
   MTL::Size backwardQueryGroupSize
   (int64_t(backwardQueryKernel->threadgroupSize), 1, 1);
 
@@ -286,9 +282,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
 
   MTL::Size backwardKeyValueGridSize
-  (ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]),
-   hash.Hq,
-   attentionDesc.batchDimension);
+  (ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
   MTL::Size backwardKeyValueGroupSize
   (int64_t(backwardKeyValueKernel->threadgroupSize), 1, 1);
 
diff --git a/lib/nnc/mfa/v2/AttentionKernel.cpp b/lib/nnc/mfa/v2/AttentionKernel.cpp
index fe749b509..8297c6fb1 100644
--- a/lib/nnc/mfa/v2/AttentionKernel.cpp
+++ b/lib/nnc/mfa/v2/AttentionKernel.cpp
@@ -427,17 +427,29 @@ std::string AttentionKernel::createSource() const noexcept {
 kernel void attention(
 )";
   source += createBufferBindings() + "\n";
+  switch (type.value) {
+  case AttentionKernelType::forward:
+    source.SetValue("DISPATCH_DIMENSION", "R");
+    break;
+  case AttentionKernelType::backwardQuery:
+    source.SetValue("DISPATCH_DIMENSION", "R");
+    break;
+  case AttentionKernelType::backwardKeyValue:
+    source.SetValue("DISPATCH_DIMENSION", "C");
+    break;
+  }
   source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0]));
   source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue());
   source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue());
   source += R"(
   threadgroup uchar *threadgroup_block [[threadgroup(0)]],
 
-  uint3 gid [[threadgroup_position_in_grid]],
+  uint gidx [[threadgroup_position_in_grid]],
   ushort sidx [[simdgroup_index_in_threadgroup]],
   ushort lane_id [[thread_index_in_simdgroup]]
 ) {
   ushort2 morton_offset = morton_order(lane_id);
+  uint3 gid = { gidx % (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}}), (gidx / (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}})) % Hq, gidx / (Hq * (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}}))};
 
   uint parallelization_group_offset = gid.x;
   parallelization_group_offset *= {{BLOCK_DIMENSIONS_PARALLELIZATION}};
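
For context on the index math in this patch: the three encode sites now dispatch a flat 1D grid of ceilDivide(dim, blockDimensions[0]) * Hq * batch threadgroups, and the generated kernel reconstructs the former (block, head, batch) coordinates from the single gidx. The standalone C++ sketch below round-trips that mapping under assumed sample sizes; the names decodeGid, Gid, and the constants are hypothetical illustrations, not part of the patch.

// Minimal sketch (not part of the patch): checks that the flat 1D dispatch
// used above decodes back into the original (block, head, batch) coordinates.
#include <cassert>
#include <cstdint>

struct Gid { uint32_t x, y, z; };  // (parallelization block, head index, batch index)

// Mirrors the host-side ceilDivide() rounding-up division.
static uint64_t ceilDivide(uint64_t value, uint16_t divisor) {
  return (value + divisor - 1) / divisor;
}

// Hypothetical helper mirroring the generated MSL: gidx % blocks, then
// (gidx / blocks) % Hq, then gidx / (Hq * blocks).
static Gid decodeGid(uint32_t gidx, uint64_t dispatchDim, uint16_t blockDim, uint32_t Hq) {
  const uint32_t blocks = uint32_t(ceilDivide(dispatchDim, blockDim));
  return Gid{gidx % blocks, (gidx / blocks) % Hq, gidx / (blocks * Hq)};
}

int main() {
  // Example sizes (hypothetical): R = 1000, blockDimensions[0] = 64, Hq = 8, batch = 2.
  const uint64_t R = 1000;
  const uint16_t block = 64;
  const uint32_t Hq = 8, batch = 2;
  const uint32_t blocks = uint32_t(ceilDivide(R, block));  // 16 threadgroups along R
  // Enumerate the flat grid in the order the patch assumes (block fastest, then
  // head, then batch) and confirm the decode recovers each coordinate.
  for (uint32_t z = 0; z < batch; z++)
    for (uint32_t y = 0; y < Hq; y++)
      for (uint32_t x = 0; x < blocks; x++) {
        const uint32_t gidx = x + blocks * (y + Hq * z);
        const Gid gid = decodeGid(gidx, R, block, Hq);
        assert(gid.x == x && gid.y == y && gid.z == z);
      }
  return 0;
}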