Commit

fp32.
liuliu committed Dec 23, 2024
1 parent e3013d7 commit 96252e9
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions bin/nnc/sdpa_bench.c
@@ -53,17 +53,13 @@ int main(int argc, char** argv)

ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, Hq, R, D), 0);
// ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, is_causal), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor, NULL, NULL, NULL), TENSOR_LIST(o_tensor, NULL), 0);
-ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, Hq, R, D), 0);
-ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, Hk, C, D), 0);
-ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, Hk, C, D), 0);
-ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16), 0);

// Why is there 000 in the beginning of the argument list for GPU_TENSOR_NHWC?
-ccv_nnc_tensor_t* const gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, Hq, R, D), 0);
-ccv_nnc_tensor_t* const gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, Hk, C, D), 0);
-ccv_nnc_tensor_t* const gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, Hk, C, D), 0);
-ccv_nnc_tensor_t* const gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, Hq, R, D), 0);
-ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16), TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), 0);
+ccv_nnc_tensor_t* const gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hq, R, D), 0);
+ccv_nnc_tensor_t* const gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hk, C, D), 0);
+ccv_nnc_tensor_t* const gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hk, C, D), 0);
+ccv_nnc_tensor_t* const gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hq, R, D), 0);
+ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), 0);

ccv_nnc_cmd_t scaled_dot_product_attention = CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, is_causal);
scaled_dot_product_attention.info.scaled_dot_product_attention.flags = CCV_NNC_GEMM_16F;
@@ -75,23 +71,21 @@ int main(int argc, char** argv)
elapsed_time = get_current_time() - elapsed_time;
printf("%d, %d, %d, %d, %d, %d, %2.3f\n", B, R, C, Hq, Hk, D, elapsed_time);

-ccv_nnc_tensor_t* const copy_of_gpu_o_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, Hq, R, D), 0);
-ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_o_tensor), TENSOR_LIST(copy_of_gpu_o_tensor_f16), 0);
ccv_nnc_tensor_t* const copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, Hq, R, D), 0);
-ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(copy_of_gpu_o_tensor_f16), TENSOR_LIST(copy_of_gpu_o_tensor), 0);
+ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_o_tensor), TENSOR_LIST(copy_of_gpu_o_tensor), 0);

// REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, copy_of_gpu_o_tensor->data.f32, o_tensor->data.f32, B * R * Hq * D, 3e-3, "GPU computed output should be the same as CPU computed ones");

ccv_nnc_tensor_free(o_tensor);
ccv_nnc_tensor_free(gpu_o_tensor);
ccv_nnc_tensor_free(copy_of_gpu_o_tensor);
-ccv_nnc_tensor_free(copy_of_gpu_o_tensor_f16);
+// ccv_nnc_tensor_free(copy_of_gpu_o_tensor_f16);
ccv_nnc_tensor_free(q_tensor);
ccv_nnc_tensor_free(k_tensor);
ccv_nnc_tensor_free(v_tensor);
-ccv_nnc_tensor_free(q_tensor_f16);
-ccv_nnc_tensor_free(k_tensor_f16);
-ccv_nnc_tensor_free(v_tensor_f16);
+// ccv_nnc_tensor_free(q_tensor_f16);
+// ccv_nnc_tensor_free(k_tensor_f16);
+// ccv_nnc_tensor_free(v_tensor_f16);
ccv_nnc_tensor_free(gpu_q_tensor);
ccv_nnc_tensor_free(gpu_k_tensor);
ccv_nnc_tensor_free(gpu_v_tensor);
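For reference, below is a minimal standalone sketch of the fp32 path this commit switches to, assembled from the added lines above. The includes, the ccv_nnc_init() call, and the B/Hq/Hk/R/C/D values are illustrative assumptions; in the real sdpa_bench.c they come from the surrounding benchmark code, and the GPU attention invocation itself sits outside the hunks shown here.

#include <nnc/ccv_nnc.h>
#include <nnc/ccv_nnc_easy.h>

int main(void)
{
	ccv_nnc_init(); // assumed: register NNC commands/backends before use
	const int B = 1, Hq = 32, Hk = 32, R = 1024, C = 1024, D = 128; // example shapes only
	// CPU-side fp32 tensors; the benchmark fills these with data elsewhere.
	ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, Hq, R, D), 0);
	ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, Hk, C, D), 0);
	ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, Hk, C, D), 0);
	// GPU-side fp32 tensors; the leading 000 argument selects compute device 0.
	ccv_nnc_tensor_t* const gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hq, R, D), 0);
	ccv_nnc_tensor_t* const gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hk, C, D), 0);
	ccv_nnc_tensor_t* const gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hk, C, D), 0);
	ccv_nnc_tensor_t* const gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hq, R, D), 0);
	// fp32 inputs now move straight to the GPU; the fp16 conversion step is gone.
	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), 0);
	// ... run CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD on the gpu_* tensors here ...
	// The fp32 result also copies back directly, with no intermediate f16 tensor.
	ccv_nnc_tensor_t* const copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, Hq, R, D), 0);
	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_o_tensor), TENSOR_LIST(copy_of_gpu_o_tensor), 0);
	ccv_nnc_tensor_free(q_tensor);
	ccv_nnc_tensor_free(k_tensor);
	ccv_nnc_tensor_free(v_tensor);
	ccv_nnc_tensor_free(gpu_q_tensor);
	ccv_nnc_tensor_free(gpu_k_tensor);
	ccv_nnc_tensor_free(gpu_v_tensor);
	ccv_nnc_tensor_free(gpu_o_tensor);
	ccv_nnc_tensor_free(copy_of_gpu_o_tensor);
	return 0;
}

Note that the unchanged context still sets scaled_dot_product_attention.info.scaled_dot_product_attention.flags = CCV_NNC_GEMM_16F, so the commit appears to leave any fp16 handling to the kernel rather than pre-converting inputs and outputs on the CPU.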
