Skip to content

Commit

Permalink
Have a regression on when bias is not async loaded.
Browse files Browse the repository at this point in the history
Don't include BF16 headers if it is not used.
  • Loading branch information
liuliu committed Aug 9, 2024
1 parent e0c693c commit 3ef7843
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
3 changes: 3 additions & 0 deletions lib/nnc/mfa/v2/GEMMDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
// WARNING: The owner must explicitly retain the compute pipeline.
auto createPipeline =
[=](MTL::Library* library) -> MTL::ComputePipelineState* {
std::cout << "Pipeline cache miss." << std::endl;
// Set the function constants.
auto constants = NS::TransferPtr
(MTL::FunctionConstantValues::alloc()->init());
Expand All @@ -71,6 +72,7 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
GEMMOperandPrecision registerPrecisionA = memoryPrecisions.A;
GEMMOperandPrecision registerPrecisionB = memoryPrecisions.B;
GEMMOperandPrecision registerPrecisionC = GEMMOperandPrecision::FP32;
GEMMOperandPrecision registerPrecisionBias = memoryPrecisions.bias;
if (memoryPrecisions.A == GEMMOperandPrecision::FP16 &&
memoryPrecisions.B == GEMMOperandPrecision::FP16 &&
memoryPrecisions.C == GEMMOperandPrecision::FP16) {
Expand Down Expand Up @@ -125,6 +127,7 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
.A = registerPrecisionA,
.B = registerPrecisionB,
.C = registerPrecisionC,
.bias = registerPrecisionBias,
};

// Run a combinatorial search to find the correct value for
Expand Down
5 changes: 4 additions & 1 deletion lib/nnc/mfa/v2/GEMMHeaders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ namespace metal
)";
}

std::string createMetalSimdgroupMatrixStorage() {
std::string createMetalSimdgroupMatrixStorage(bool BF16) {
// How this header spawning code was designed.
//
// Find the patterns between the load/store functions:
Expand Down Expand Up @@ -639,6 +639,9 @@ namespace metal
for (auto action : actions) {
for (auto addressSpace : addressSpaces) {
for (auto decodingBF16 : decodingBF16s) {
if (!BF16 && decodingBF16) { // Don't need to output BF16 related methods.
continue;
}
desc.action = action;
desc.addressSpace = addressSpace;

Expand Down
2 changes: 1 addition & 1 deletion lib/nnc/mfa/v2/GEMMHeaders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
std::string createMetalSimdgroupEvent();

/// Create the source code for the 'metal\_simdgroup\_matrix\_storage' header.
std::string createMetalSimdgroupMatrixStorage();
std::string createMetalSimdgroupMatrixStorage(bool BF16);

#endif /* GEMMHeaders_hpp */
5 changes: 3 additions & 2 deletions lib/nnc/mfa/v2/GEMMKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ GEMMKernel::GEMMKernel(GEMMKernelDescriptor descriptor, MTL::Device *const devic
// down execution speed on both M1/M2 and M3+.
CCV_NNC_MFA_PRECONDITION(false);
}
bool anyBF16 = (memoryPrecisions.A == GEMMOperandPrecision::BF16) || (memoryPrecisions.B == GEMMOperandPrecision::BF16) || (memoryPrecisions.C == GEMMOperandPrecision::BF16) || (memoryPrecisions.bias == GEMMOperandPrecision::BF16);

// Inject the contents of the headers.
source += createMetalSimdgroupEvent() + "\n";
source += createMetalSimdgroupMatrixStorage() + "\n";
source += createMetalSimdgroupMatrixStorage(anyBF16) + "\n";
source += "using namespace metal;\n";
source += "\n";

Expand Down Expand Up @@ -447,7 +448,7 @@ kernel void gemm(device MEMORY_NAME_A *A [[buffer(0)]],
)";

if (useBias) {
if (descriptor.preferAsyncStore) {
if (descriptor.preferAsyncLoad) {
source += "\n";
source += "#define USE_BIAS_ASYNC_COND false\n";
} else {
Expand Down

0 comments on commit 3ef7843

Please sign in to comment.