diff --git a/lib/nnc/ccv_nnc.h b/lib/nnc/ccv_nnc.h index 7505cb6ef..46999d53b 100644 --- a/lib/nnc/ccv_nnc.h +++ b/lib/nnc/ccv_nnc.h @@ -32,6 +32,7 @@ enum { CCV_NNC_DISABLE_MIXED_MPS_SOFTMAX = 0x2, CCV_NNC_DISABLE_MMAP_MTL_BUFFER = 0x4, CCV_NNC_DISABLE_METAL_FLASH_ATTENTION = 0x8, + CCV_NNC_DISABLE_MFA_GEMM = 0x10, /* Bit flag tested with & against ccv_nnc_flags(); must not overlap the other CCV_NNC_DISABLE_* bits. */ }; /** * Enable system-wide specific flag. diff --git a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m index 46f21bb9c..8e79407f1 100644 --- a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m +++ b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m @@ -176,7 +176,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context(); const int is_mfa_supported = - ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION); + ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM); size_t a_data_size = 0; if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)