Add a final fix for MFAv2 where we compute the gid ourselves.
liuliu committed Dec 20, 2024
1 parent 05c73e6 commit 1cc009c
Showing 2 changed files with 16 additions and 10 deletions.
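In short, the change flattens each attention dispatch's 3-D threadgroup grid (row blocks × Hq × batch) into a 1-D grid on the host side, and has the generated Metal kernel reconstruct the 3-D gid from the scalar threadgroup index, parallelizing over R for the forward and backwardQuery kernels and over C for backwardKeyValue.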
12 changes: 3 additions & 9 deletions lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -148,9 +148,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
 
   MTL::Size gridSize
-    (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]),
-     hash.Hq,
-     attentionDesc.batchDimension);
+    (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
   MTL::Size groupSize
     (int64_t(kernel->threadgroupSize), 1, 1);

@@ -239,9 +237,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
 
   MTL::Size backwardQueryGridSize
-    (ceilDivide(int64_t(hash.R), backwardQueryKernel->blockDimensions[0]),
-     hash.Hq,
-     attentionDesc.batchDimension);
+    (ceilDivide(int64_t(hash.R), backwardQueryKernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
   MTL::Size backwardQueryGroupSize
     (int64_t(backwardQueryKernel->threadgroupSize), 1, 1);

@@ -286,9 +282,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
 
   MTL::Size backwardKeyValueGridSize
-    (ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]),
-     hash.Hq,
-     attentionDesc.batchDimension);
+    (ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
   MTL::Size backwardKeyValueGroupSize
     (int64_t(backwardKeyValueKernel->threadgroupSize), 1, 1);

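All three hunks above apply the same transformation: the former 3-D grid of (row blocks, Hq, batch) threadgroups becomes a 1-D grid whose width is the product of the three extents. A minimal standalone sketch of that arithmetic, assuming ceilDivide is the usual round-up integer division (the free-standing helper below is illustrative, not code from the repository):

```cpp
#include <cstdint>

// Assumed behavior of the repository's ceilDivide: round-up integer division.
static int64_t ceilDivide(int64_t value, uint16_t divisor) {
  return (value + divisor - 1) / divisor;
}

// Width of the flattened 1-D grid: one threadgroup per
// (row block, head, batch) combination; dispatched as MTL::Size(width, 1, 1).
int64_t flattenedGridWidth(int64_t R, uint16_t blockDim, int64_t Hq,
                           int64_t batch) {
  return ceilDivide(R, blockDim) * Hq * batch;
}
```

The backwardKeyValue dispatch uses the same formula with hash.C in place of hash.R, matching the C dispatch dimension selected in AttentionKernel.cpp below.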
14 changes: 13 additions & 1 deletion lib/nnc/mfa/v2/AttentionKernel.cpp
@@ -427,17 +427,29 @@ std::string AttentionKernel::createSource() const noexcept {
   kernel void attention(
 )";
   source += createBufferBindings() + "\n";
+  switch (type.value) {
+  case AttentionKernelType::forward:
+    source.SetValue("DISPATCH_DIMENSION", "R");
+    break;
+  case AttentionKernelType::backwardQuery:
+    source.SetValue("DISPATCH_DIMENSION", "R");
+    break;
+  case AttentionKernelType::backwardKeyValue:
+    source.SetValue("DISPATCH_DIMENSION", "C");
+    break;
+  }
   source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0]));
   source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue());
   source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue());
   source += R"(
   threadgroup uchar *threadgroup_block [[threadgroup(0)]],
-  uint3 gid [[threadgroup_position_in_grid]],
+  uint gidx [[threadgroup_position_in_grid]],
   ushort sidx [[simdgroup_index_in_threadgroup]],
   ushort lane_id [[thread_index_in_simdgroup]]
 ) {
   ushort2 morton_offset = morton_order(lane_id);
+  uint3 gid = { gidx % (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}}), (gidx / (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}})) % Hq, gidx / (Hq * (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}}))};
   uint parallelization_group_offset = gid.x;
   parallelization_group_offset *= {{BLOCK_DIMENSIONS_PARALLELIZATION}};
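With the dispatch flattened, the generated kernel receives only the scalar gidx and rebuilds gid with mod/div arithmetic over the number of blocks along the dispatch dimension. A hypothetical host-side mirror of that decode (not part of the commit; the names and the round-trip check are illustrative):

```cpp
#include <cassert>
#include <cstdint>

struct GID { uint32_t x, y, z; };

// Mirrors the shader's reconstruction: blocks = ceil(dispatchDim / blockDim),
// gid = (gidx % blocks, (gidx / blocks) % Hq, gidx / (Hq * blocks)).
GID decodeGID(uint32_t gidx, uint32_t dispatchDim, uint32_t blockDim,
              uint32_t Hq) {
  const uint32_t blocks = (dispatchDim + blockDim - 1) / blockDim;
  return GID{gidx % blocks,          // block index along R (or C)
             (gidx / blocks) % Hq,   // query-head index
             gidx / (Hq * blocks)};  // batch index
}

int main() {
  // Round trip: linearize a (block, head, batch) coordinate the way a 1-D
  // grid enumerates threadgroups, decode it as the shader does, and expect
  // the original coordinate back.
  const uint32_t dispatchDim = 1000, blockDim = 64, Hq = 8, batch = 4;
  const uint32_t blocks = (dispatchDim + blockDim - 1) / blockDim;
  for (uint32_t z = 0; z < batch; z++)
    for (uint32_t y = 0; y < Hq; y++)
      for (uint32_t x = 0; x < blocks; x++) {
        const uint32_t gidx = (z * Hq + y) * blocks + x;
        const GID gid = decodeGID(gidx, dispatchDim, blockDim, Hq);
        assert(gid.x == x && gid.y == y && gid.z == z);
      }
  return 0;
}
```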
