From 5a56df218fe06768c1385513721686dfe512c7ad Mon Sep 17 00:00:00 2001
From: Liu Liu
Date: Sat, 21 Dec 2024 01:04:30 -0500
Subject: [PATCH] Add code to run sdpa on MPSGraph, this is for benchmark only.

---
 bin/nnc/sdpa_bench.c                          |  1 +
 ...ccv_nnc_scaled_dot_product_attention_mps.m | 52 +++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/bin/nnc/sdpa_bench.c b/bin/nnc/sdpa_bench.c
index d7a4c02f2..5ca64a25f 100644
--- a/bin/nnc/sdpa_bench.c
+++ b/bin/nnc/sdpa_bench.c
@@ -16,6 +16,7 @@ static double get_current_time(void)
 int main(int argc, char** argv)
 {
 	ccv_nnc_init();
+	ccv_nnc_enable_flag(CCV_NNC_DISABLE_METAL_FLASH_ATTENTION);
 	// Bypass error: variable-sized object may not be initialized
 #define num_trials 18
 	int B_candidates[num_trials] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
diff --git a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
index 47c30888b..5d726ab7e 100644
--- a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
+++ b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
@@ -148,8 +148,60 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 	@autoreleasepool {
 		ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
 		if (!ccv_nnc_mfa_context_supported(context) || (ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION)) {
+			MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
+			// Use MPSGraph.
+			ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
+			// Key will be consumed by the next method, therefore, no need to free.
+			int indices[3];
+			int qstride[CCV_NNC_MAX_DIM_ALLOC];
+			int kstride[CCV_NNC_MAX_DIM_ALLOC];
+			int vstride[CCV_NNC_MAX_DIM_ALLOC];
+			int ostride[CCV_NNC_MAX_DIM_ALLOC];
+			ccv_nnc_tensor_view_get_stride(q, qstride);
+			ccv_nnc_tensor_view_get_stride(k, kstride);
+			ccv_nnc_tensor_view_get_stride(v, vstride);
+			ccv_nnc_tensor_view_get_stride(o, ostride);
+			int* qdim_r = qdim;
+			int* qstride_r = qstride;
+			int* kdim_r = kdim;
+			int* kstride_r = kstride;
+			int* vdim_r = vdim;
+			int* vstride_r = vstride;
+			const float scale = cmd.info.scaled_dot_product_attention.scale;
+			MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray* inputTensors, NSMutableArray* inputShapedTypes, NSMutableArray* resultTensors) {
+				MPSGraphTensor* mps_input_q;
+				MPSGraphTensor* mps_q = ccv_nnc_mps_graph_tensor_input(graph, q, qdim_r, qstride_r, &mps_input_q);
+				MPSGraphShapedType* mps_q_shape = ccv_nnc_mps_graph_tensor_input_shape(q, qdim_r, qstride_r);
+				MPSGraphTensor* mps_input_k;
+				MPSGraphTensor* mps_k = ccv_nnc_mps_graph_tensor_input(graph, k, kdim_r, kstride_r, &mps_input_k);
+				MPSGraphShapedType* mps_k_shape = ccv_nnc_mps_graph_tensor_input_shape(k, kdim_r, kstride_r);
+				MPSGraphTensor* mps_input_v;
+				MPSGraphTensor* mps_v = ccv_nnc_mps_graph_tensor_input(graph, v, vdim_r, vstride_r, &mps_input_v);
+				MPSGraphShapedType* mps_v_shape = ccv_nnc_mps_graph_tensor_input_shape(v, vdim_r, vstride_r);
+				[inputTensors addObject:mps_input_q];
+				[inputShapedTypes addObject:mps_q_shape];
+				[inputTensors addObject:mps_input_k];
+				[inputShapedTypes addObject:mps_k_shape];
+				[inputTensors addObject:mps_input_v];
+				[inputShapedTypes addObject:mps_v_shape];
+				mps_q = [graph transposeTensor:mps_q dimension:-3 withDimension:-2 name:nil];
+				mps_k = [graph transposeTensor:mps_k dimension:-3 withDimension:-2 name:nil];
+				mps_v = [graph transposeTensor:mps_v dimension:-3 withDimension:-2 name:nil];
+				MPSGraphTensor* mps_o = [graph scaledDotProductAttentionWithQueryTensor:mps_q keyTensor:mps_k valueTensor:mps_v scale:scale name:nil];
+				[resultTensors addObject:mps_o];
+				[graph dump];
+			});
+			MPSGraphTensorData* data_q = ccv_nnc_mps_graph_tensor_data(q, qdim, qstride);
+			MPSGraphTensorData* data_k = ccv_nnc_mps_graph_tensor_data(k, kdim, kstride);
+			MPSGraphTensorData* data_v = ccv_nnc_mps_graph_tensor_data(v, vdim, vstride);
+			MPSGraphTensorData* data[] = {data_q, data_k, data_v};
+			ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]], data[indices[2]]], &o, (int*[]){ o->info.dim }, (int*[]){ o->stride }, 1, 0);
+			ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
+			return CCV_NNC_EXEC_SUCCESS;
+			/*
 			assert(false); // MFA is required.
 			return CCV_NNC_EXEC_INVALID;
+			*/
 		}
 		const int is_downcast = ((cmd.info.scaled_dot_product_attention.flags & CCV_NNC_GEMM_16F) && q->info.datatype == CCV_16F);
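
Note: the MPSGraph fallback in the hunk above boils down to the standalone sketch below. It builds Q/K/V placeholders in the [batch, heads, sequence, head_dim] layout MPSGraph expects (which is why the patch transposes dimension -3 with -2) and calls scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:scale:name:. This is a minimal sketch only, not the nnc code path: the file name, shapes, zero-filled inputs, 0.125 scale, and the eager runWithFeeds: execution are illustrative assumptions; the patch itself goes through nnc's cached MPSGraphExecutable and MPSCommandBuffer helpers.

// sdpa_graph_sketch.m (hypothetical file name): standalone MPSGraph SDPA smoke test.
// Assumed shapes: batch=1, heads=8, sequence=1024, head_dim=64, FP16.
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

int main(void)
{
	@autoreleasepool {
		id<MTLDevice> device = MTLCreateSystemDefaultDevice();
		MPSGraphDevice* graph_device = [MPSGraphDevice deviceWithMTLDevice:device];
		MPSGraph* graph = [[MPSGraph alloc] init];
		// MPSGraph's SDPA expects [batch, heads, sequence, head_dim].
		MPSShape* shape = @[@1, @8, @1024, @64];
		MPSGraphTensor* q = [graph placeholderWithShape:shape dataType:MPSDataTypeFloat16 name:@"q"];
		MPSGraphTensor* k = [graph placeholderWithShape:shape dataType:MPSDataTypeFloat16 name:@"k"];
		MPSGraphTensor* v = [graph placeholderWithShape:shape dataType:MPSDataTypeFloat16 name:@"v"];
		// scale = 1 / sqrt(head_dim); the patch takes it from cmd.info.scaled_dot_product_attention.scale.
		MPSGraphTensor* o = [graph scaledDotProductAttentionWithQueryTensor:q keyTensor:k valueTensor:v scale:0.125f name:nil];
		// Zero-filled FP16 inputs (2 bytes per element) are enough for a shape / timing smoke test.
		NSData* zeros = [NSMutableData dataWithLength:1 * 8 * 1024 * 64 * 2];
		MPSGraphTensorData* q_data = [[MPSGraphTensorData alloc] initWithDevice:graph_device data:zeros shape:shape dataType:MPSDataTypeFloat16];
		MPSGraphTensorData* k_data = [[MPSGraphTensorData alloc] initWithDevice:graph_device data:zeros shape:shape dataType:MPSDataTypeFloat16];
		MPSGraphTensorData* v_data = [[MPSGraphTensorData alloc] initWithDevice:graph_device data:zeros shape:shape dataType:MPSDataTypeFloat16];
		NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [graph runWithFeeds:@{ q: q_data, k: k_data, v: v_data } targetTensors:@[o] targetOperations:nil];
		NSLog(@"output shape: %@", results[o].shape); // expect (1, 8, 1024, 64)
	}
	return 0;
}

A file like this should build with clang -fobjc-arc sdpa_graph_sketch.m -framework Foundation -framework Metal -framework MetalPerformanceShadersGraph (the scaledDotProductAttention* methods are a recent MPSGraph addition, so a current SDK is needed). With the patch applied, sdpa_bench exercises the equivalent path inside nnc: CCV_NNC_DISABLE_METAL_FLASH_ATTENTION is enabled at startup, so the benchmark times MPSGraph's SDPA instead of the Metal FlashAttention kernels.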