Skip to content

Commit

Permalink
Add code to run sdpa on MPSGraph, this is for benchmark only.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 21, 2024
1 parent 85efcf8 commit 5a56df2
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions bin/nnc/sdpa_bench.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ static double get_current_time(void)
int main(int argc, char** argv)
{
ccv_nnc_init();
ccv_nnc_enable_flag(CCV_NNC_DISABLE_METAL_FLASH_ATTENTION);
// Bypass error: variable-sized object may not be initialized
#define num_trials 18
int B_candidates[num_trials] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,60 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
@autoreleasepool {
ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
if (!ccv_nnc_mfa_context_supported(context) || (ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION)) {
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
// Use MPSGraph.
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
// Key will be consumed by the next method, therefore, no need to free.
int indices[3];
int qstride[CCV_NNC_MAX_DIM_ALLOC];
int kstride[CCV_NNC_MAX_DIM_ALLOC];
int vstride[CCV_NNC_MAX_DIM_ALLOC];
int ostride[CCV_NNC_MAX_DIM_ALLOC];
ccv_nnc_tensor_view_get_stride(q, qstride);
ccv_nnc_tensor_view_get_stride(k, kstride);
ccv_nnc_tensor_view_get_stride(v, vstride);
ccv_nnc_tensor_view_get_stride(o, ostride);
int* qdim_r = qdim;
int* qstride_r = qstride;
int* kdim_r = kdim;
int* kstride_r = kstride;
int* vdim_r = vdim;
int* vstride_r = vstride;
const float scale = cmd.info.scaled_dot_product_attention.scale;
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_q;
MPSGraphTensor* mps_q = ccv_nnc_mps_graph_tensor_input(graph, q, qdim_r, qstride_r, &mps_input_q);
MPSGraphShapedType* mps_q_shape = ccv_nnc_mps_graph_tensor_input_shape(q, qdim_r, qstride_r);
MPSGraphTensor* mps_input_k;
MPSGraphTensor* mps_k = ccv_nnc_mps_graph_tensor_input(graph, k, kdim_r, kstride_r, &mps_input_k);
MPSGraphShapedType* mps_k_shape = ccv_nnc_mps_graph_tensor_input_shape(k, kdim_r, kstride_r);
MPSGraphTensor* mps_input_v;
MPSGraphTensor* mps_v = ccv_nnc_mps_graph_tensor_input(graph, v, vdim_r, vstride_r, &mps_input_v);
MPSGraphShapedType* mps_v_shape = ccv_nnc_mps_graph_tensor_input_shape(v, vdim_r, vstride_r);
[inputTensors addObject:mps_input_q];
[inputShapedTypes addObject:mps_q_shape];
[inputTensors addObject:mps_input_k];
[inputShapedTypes addObject:mps_k_shape];
[inputTensors addObject:mps_input_v];
[inputShapedTypes addObject:mps_v_shape];
mps_q = [graph transposeTensor:mps_q dimension:-3 withDimension:-2 name:nil];
mps_k = [graph transposeTensor:mps_k dimension:-3 withDimension:-2 name:nil];
mps_v = [graph transposeTensor:mps_v dimension:-3 withDimension:-2 name:nil];
MPSGraphTensor* mps_o = [graph scaledDotProductAttentionWithQueryTensor:mps_q keyTensor:mps_k valueTensor:mps_v scale:scale name:nil];
[resultTensors addObject:mps_o];
[graph dump];
});
MPSGraphTensorData* data_q = ccv_nnc_mps_graph_tensor_data(q, qdim, qstride);
MPSGraphTensorData* data_k = ccv_nnc_mps_graph_tensor_data(k, kdim, kstride);
MPSGraphTensorData* data_v = ccv_nnc_mps_graph_tensor_data(v, vdim, vstride);
MPSGraphTensorData* data[] = {data_q, data_k, data_v};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &o, (int*[]){ o->info.dim }, (int*[]){ o->stride }, 1, 0);
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
return CCV_NNC_EXEC_SUCCESS;
/*
assert(false); // MFA is required.
return CCV_NNC_EXEC_INVALID;
*/
}

const int is_downcast = ((cmd.info.scaled_dot_product_attention.flags & CCV_NNC_GEMM_16F) && q->info.datatype == CCV_16F);
Expand Down

0 comments on commit 5a56df2

Please sign in to comment.