Skip to content

Commit

Permalink
Add flag to disable MFA GEMM.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Nov 9, 2023
1 parent 8a11643 commit a2f75ed
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions lib/nnc/ccv_nnc.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum {
CCV_NNC_DISABLE_MIXED_MPS_SOFTMAX = 0x2,
CCV_NNC_DISABLE_MMAP_MTL_BUFFER = 0x4,
CCV_NNC_DISABLE_METAL_FLASH_ATTENTION = 0x8,
CCV_NNC_DISABLE_MFA_GEMM = 0x16,
};
/**
* Enable system-wide specific flag.
Expand Down
2 changes: 1 addition & 1 deletion lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint

ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_mfa_supported =
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION);
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM);

size_t a_data_size = 0;
if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
Expand Down

0 comments on commit a2f75ed

Please sign in to comment.