Skip to content

Commit

Permalink
Use 'malloc_device' instead of 'malloc_shared'; achieves 307 TFLOPS (#42)
Browse files (browse the repository at this point in the history)
  • Loading branch information
taozha2 authored Apr 24, 2024
1 parent 17db8d0 commit 303daa6
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 53 deletions.
5 changes: 3 additions & 2 deletions build.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
sycl_compiler_path=/opt/cutlass_compiler/
target=./examples/cute/tutorial/pvc_sycl
cuda_path=/usr/local/cuda-12.3/
mkl_path=/opt/intel/oneapi/mkl/2024.1
rm -rf $target
export ZE_AFFINITY_MASK=0
export CPATH=$sycl_compiler_path:$sycl_compiler_path/include/:$sycl_compiler_path/include/sycl/:/opt/intel/oneapi/mkl/2024.1/include/
export CPATH=$sycl_compiler_path:$sycl_compiler_path/include/:$sycl_compiler_path/include/sycl/:$mkl_path/include/
export LIBRARY_PATH=/opt/intel/oneapi/mkl/2024.1/lib/
export LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/2024.1/lib/:${sycl_compiler_path}/lib/
export LD_LIBRARY_PATH=$mkl_path/lib/:${sycl_compiler_path}/lib/
export IGC_EnableVISANoSchedule=1
export IGC_ShaderDumpEnable=1
export IGC_DumpToCustomDir=./mm_dumps_prefetch_coop
Expand Down
29 changes: 14 additions & 15 deletions examples/cute/tutorial/pvc_sycl/pvc_prefetch_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ void HELPER_NAME(btile_prefetch_vnni, MM, NN)(global ushort *B, int tN, int N,

void HELPER_NAME(atile_block_prefetch_rowmajor, MM,
NN)(global ushort *A, int tM, int M, int K, int m, int k) {
if (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) {
const int sg_index_x =
if constexpr (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) {
const unsigned int sg_index_x =
get_sub_group_id() % SGS_PER_WG_X; // index in [0, SGS_PER_WG_X)
const int kk = 0;
const unsigned int kk = 0;
const int mm = sg_index_x % 4;
// if (get_sub_group_local_id() == 0) {
// printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k =
Expand All @@ -49,33 +49,32 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM,
// + kk * tK, m + mm * tM);
// }
intel_subgroup_block_prefetch_u16_m8k16v2(
A, K * sizeof(ushort), M, K * sizeof(ushort),
(int2_){k + kk * tK, m + mm * tM});
} else if (KK % 2 == 0 & MM % 4 == 0) {
A, K * sizeof(ushort), M, K * sizeof(ushort), (int2_){k, m + mm * tM});
} else if constexpr (KK % 2 == 0 & MM % 4 == 0) {
for (int kk = 0; kk < KK; kk += 2) {
for (int mm = 0; mm < MM; mm += 4) {
intel_subgroup_block_prefetch_u16_m32k16v2(
A, K * sizeof(ushort), M, K * sizeof(ushort),
(int2_){k + kk * tK, m + mm * tM});
}
}
} else if (KK % 2 == 0 & MM % 2 == 0) {
} else if constexpr (KK % 2 == 0 & MM % 2 == 0) {
for (int kk = 0; kk < KK; kk += 2) {
for (int mm = 0; mm < MM; mm += 2) {
intel_subgroup_block_prefetch_u16_m16k16v2(
A, K * sizeof(ushort), M, K * sizeof(ushort),
(int2_){k + kk * tK, m + mm * tM});
}
}
} else if (KK % 2 == 0) {
} else if constexpr (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk += 2) {
for (int mm = 0; mm < MM; mm++) {
intel_subgroup_block_prefetch_u16_m8k16v2(
A, K * sizeof(ushort), M, K * sizeof(ushort),
(int2_){k + kk * tK, m + mm * tM});
}
}
} else if (MM % 4 == 0) {
} else if constexpr (MM % 4 == 0) {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm += 4) {
intel_subgroup_block_prefetch_u16_m32k16(
Expand Down Expand Up @@ -149,15 +148,15 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM,

void HELPER_NAME(btile_block_prefetch_vnni, MM,
NN)(global ushort *B, int tN, int K, int N, int k, int n) {
if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) {
const int sg_index_y =
if constexpr (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) {
const unsigned int sg_index_y =
get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y)
const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3
const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0
// static const unsigned int kk = 0; // kk(sg_index_y) == 0, 0,
// 0, 0, 0, 0, 0, 0
intel_subgroup_block_prefetch_u32_m16k16(
B, N * sizeof(uint), K, N * sizeof(uint),
(int2_){n + nn * tN, (k + kk * tK) / 2});
} else if (KK % 2 == 0) {
B, N * sizeof(uint), K, N * sizeof(uint), (int2_){n + nn * tN, k / 2});
} else if constexpr (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk += 2) {
for (int nn = 0; nn < NN; nn++) {
intel_subgroup_block_prefetch_u32_m16k16(
Expand Down
79 changes: 43 additions & 36 deletions examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,9 @@ void check_results(size_t M, size_t N, const T *C, const T *C_ref) {
}
}

auto fail_rate = (float)error_cnt * 100 / (M * N);
auto pass_rate = (1.f - ((float)error_cnt / (M * N))) * 100; // %

std::cout << "\n\n==== fail points %d is: " << fail_rate << "% !!!\n"
<< std::endl;
std::cout << "\n\n==== Pass rate is: " << pass_rate << "% !!!\n" << std::endl;
}

inline size_t time_event(sycl::event &e) {
Expand Down Expand Up @@ -136,8 +135,6 @@ static void go_dpas_blockread_vnni_tiled(sycl::queue queue, dtype_acc *C,
ev = queue.submit([&](sycl::handler &cgh) {
cgh.parallel_for(nd_range, [=](sycl::nd_item<2>
id) [[sycl::reqd_sub_group_size(16)]] {
const int M = id.get_global_range(0) * ITEM_SIZE_Y;
const int N = id.get_global_range(1) * ITEM_SIZE_X;
const int m = id.get_group(0) * WG_SIZE_Y +
(get_sub_group_id() / SGS_PER_WG_X) * SG_SIZE_Y;
const int n = id.get_group(1) * WG_SIZE_X +
Expand Down Expand Up @@ -238,63 +235,73 @@ static void go_dpas_blockread_vnni_tiled(sycl::queue queue, dtype_acc *C,

printf("Checking results... ");
fflush(stdout);
check_results(M, N, C, C_ref);

dtype_c *C_host = (dtype_c *)syclcompat::malloc_host(sizeof(float) * M * N);
queue.memcpy(C_host, C, M * N * sizeof(dtype_c));
check_results(M, N, C_host, C_ref);

free(C_host, queue);
printf(" done!\n");
}

int main(int argc, char **argv) {
sycl::queue queue{{sycl::property::queue::enable_profiling()}};
auto queue = sycl::queue{{sycl::property::queue::enable_profiling()}};
auto context = queue.get_info<sycl::info::queue::context>();
auto device = queue.get_info<sycl::info::queue::device>();

const auto M = matrixSize;
const auto N = matrixSize;
const auto K = matrixSize;

dtype_a *A_vec =
(dtype_a *)syclcompat::malloc_shared(sizeof(dtype_a) * M * K);
dtype_b *B_vec =
(dtype_b *)syclcompat::malloc_shared(sizeof(dtype_b) * N * K);
dtype_b *Bvnni_vec =
(dtype_b *)syclcompat::malloc_shared(sizeof(dtype_b) * N * K);
dtype_acc *C_vec =
(dtype_acc *)syclcompat::malloc_shared(sizeof(dtype_acc) * M * N);
dtype_acc *C_ref =
(dtype_acc *)syclcompat::malloc_shared(sizeof(dtype_acc) * M * N);
dtype_a *A_host = (dtype_a *)syclcompat::malloc_host(sizeof(dtype_a) * M * K);
dtype_b *B_host = (dtype_b *)syclcompat::malloc_host(sizeof(dtype_b) * N * K);
dtype_b *Bvnni_host =
(dtype_b *)syclcompat::malloc_host(sizeof(dtype_b) * N * K);
dtype_acc *C_host =
(dtype_acc *)syclcompat::malloc_host(sizeof(dtype_c) * M * N);
dtype_acc *C_ref_host =
(dtype_acc *)syclcompat::malloc_host(sizeof(dtype_acc) * M * N);

dtype_a *A_dev =
(dtype_a *)sycl::malloc_device(sizeof(dtype_a) * M * K, device, context);
dtype_b *B_dev =
(dtype_b *)sycl::malloc_device(sizeof(dtype_b) * N * K, device, context);
dtype_acc *C_dev = (dtype_acc *)sycl::malloc_device(sizeof(dtype_c) * M * N,
device, context);

printf("Initializing source matrices...\n");
fill_matrix(A_vec, M, K);
fill_matrix(B_vec, K, N);
fill_matrix(A_host, M, K);
fill_matrix(B_host, K, N);
vnni_matrix(Bvnni_host, B_host, K, N, 2);

vnni_matrix(Bvnni_vec, B_vec, K, N, 2);
queue.memcpy(A_dev, A_host, sizeof(dtype_a) * M * K);
queue.memcpy(B_dev, Bvnni_host, sizeof(dtype_b) * N * K);
queue.memcpy(C_dev, C_host, sizeof(dtype_c) * M * N);

printf("Computing reference...\n");
get_gemm_gold<dtype_a, dtype_b, dtype_acc>(
M, N, K, mem_layout::row_major, mem_layout::row_major, (dtype_a *)A_vec,
(dtype_b *)B_vec, (dtype_acc *)C_ref);

printf("Creating source buffers...\n");
auto A = A_vec;
auto B = B_vec;
auto Bvnni = Bvnni_vec;
M, N, K, mem_layout::row_major, mem_layout::row_major, (dtype_a *)A_host,
(dtype_b *)B_host, (dtype_c *)C_ref_host);

printf("Running gemm tests, MKN: (%d, %d, %d)...\n", M, K, N);

#ifdef B_VNNI
go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 4>(queue, C_vec, A, Bvnni, M, N, K,
C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 4>(queue, C_dev, A_dev, B_dev, M,
N, K, C_ref_host);
#else
go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 4>(queue, C_vec, A, B, M, N, K,
C_ref);
// TODO:
#endif

printf("Done.\n");

free(A_vec, queue);
free(B_vec, queue);
free(C_vec, queue);
free(Bvnni_vec, queue);
free(C_ref, queue);
free(A_host, queue);
free(B_host, queue);
free(C_host, queue);
free(Bvnni_host, queue);
free(C_ref_host, queue);
free(A_dev, queue);
free(B_dev, queue);
free(C_dev, queue);

return 0;
}

0 comments on commit 303daa6

Please sign in to comment.