diff --git a/examples/cute/tutorial/pvc_sycl.cpp b/examples/cute/tutorial/pvc_sycl.cpp deleted file mode 100644 index ff1b0a670..000000000 --- a/examples/cute/tutorial/pvc_sycl.cpp +++ /dev/null @@ -1,295 +0,0 @@ -/* -// Copyright (c) 2019-2024 Ben Ashbaugh -// -// SPDX-License-Identifier: MIT -*/ - -#include - -#include -#include -#include -#include -#include -#include - -#include -#include - -using test_clock = std::chrono::high_resolution_clock; - -using sycl::ext::oneapi::bfloat16; - -bool identityData = false; -bool fixedData = false; -bool validate = true; -int testIterations = 16; -float threshold = 0.01f; -size_t matrixSize = 512; - -#define WARMUP_ITERATIONS 100 - -std::string makeTestName( - const std::string &func, - int tM, int tN, int tK, - int MM, int NN, - size_t M, size_t N, size_t K) -{ - std::ostringstream ret; - ret << func; - ret << ""; - ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; - return ret.str(); -} - -template -static void fill_matrix(std::vector &M, size_t numRows, size_t numCols) -{ - if (identityData) - { - std::generate(std::begin(M), std::end(M), [&] - { return 1.0f; }); - } - else if (fixedData) - { - for (size_t r = 0; r < numRows; r++) - { - for (size_t c = 0; c < numCols; c++) - { - M[r * numCols + c] = static_cast(r + c); - } - } - } - else - { - std::random_device dev; - std::mt19937 rng(dev()); - std::uniform_real_distribution dist(-1.0, 1.0); - std::generate(std::begin(M), std::end(M), [&] - { return dist(rng); }); - } -} - -template -static void vnni_matrix( - std::vector &dst, const std::vector &src, - size_t numRows, size_t numCols, size_t factor) -{ - for (size_t r = 0; r < numRows / factor; r++) - { - for (size_t c = 0; c < numCols; c++) - { - for (size_t k = 0; k < factor; k++) - { - dst[r * numCols * factor + c * factor + k] = - src[(r * factor + k) * numCols + c]; - } - } - } -} - -template -static void compute_reference( - std::vector &C, - const std::vector &A, const std::vector &B, - size_t M, size_t N, size_t K) -{ - for (size_t m = 0; m < M; m++) - { - for (size_t n = 0; n < N; n++) - { - DstT sum = 0; - for (size_t k = 0; k < K; k++) - { - sum = std::fma(static_cast(A[m * K + k]), - static_cast(B[k * N + n]), sum); - } - C[m * N + n] = sum; - } - } -} - -template -void check_results( - size_t M, - size_t N, - const std::vector &C, - const std::vector &C_ref) -{ - float err = 0.f; - for (size_t m = 0; m < M; m++) - { - for (size_t n = 0; n < N; n++) - { - auto index = m * N + n; - auto localErr = std::fabs(C[index] - C_ref[index]) / - std::max(std::fabs(C[index]), - std::fabs(C_ref[index])); - err = std::max(localErr, err); - if (localErr >= threshold) - { - std::cerr << "Error at m = " << m << ", n = " << n - << ": (local error " << localErr << "): Wanted " - << C_ref[index] << ", got " << C[index] << std::endl; - // return; - } - } - } -} - -inline size_t time_event(sycl::event &e) -{ - // get start and end times - cl_ulong start_time = - e.template get_profiling_info(); - - cl_ulong end_time = - e.template get_profiling_info(); - - // return the delta - return static_cast(end_time - start_time); -} - -template -static void go_dpas_blockread_vnni_tiled( - sycl::queue queue, - std::vector &c_vec, sycl::buffer a, sycl::buffer b, - size_t M, size_t N, size_t K, - const std::vector &C_ref) -{ - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); - fflush(stdout); - - int total_iterations = WARMUP_ITERATIONS + testIterations; - if (tM * MM > M) - { - printf("M is too small.\n"); - } - else if (tN * 
NN > N) - { - printf("N is too small.\n"); - } - else - { - float best = 999.0f; - std::vector event_times(total_iterations); - for (int test = 0; test < total_iterations; test++) - { - sycl::buffer c{c_vec}; - sycl::event ev; - ev = queue.submit([&](sycl::handler &cgh) - { - sycl::accessor accA { a, cgh, sycl::read_only }; - sycl::accessor accB { b, cgh, sycl::read_only }; - sycl::accessor accC { c, cgh, sycl::write_only }; - cgh.parallel_for/*>*/(sycl::nd_range<2>{{ M/tM/MM, N/NN }, { 1, 16}}, - [=](sycl::nd_item<2> id) [[sycl::reqd_sub_group_size(16)]] { - const int M = id.get_global_range(0) * tM * MM; - const int N = id.get_global_range(1) * NN; - const int m = id.get_group().get_group_id(0) * tM * MM; - const int n = id.get_group().get_group_id(1) * tN * NN; - - auto A = accA.get_multi_ptr().get(); - auto B = accB.get_multi_ptr().get(); - auto C = accC.get_multi_ptr().get(); - - - using namespace cute; - - Tensor tAr = make_tensor(Shape<_8, Int>{}); - Tensor tBr = make_tensor(Shape<_8, Int>{}); - Tensor tCr = make_tensor(Shape<_8, Int, Int>{}); - - auto A_copy = make_xe_2d_copy(make_tensor(make_gmem_ptr(A), make_shape(M, K))); - auto B_copy = make_xe_2d_copy(make_tensor(make_gmem_ptr(B), make_shape(K, N))); - auto C_copy = make_xe_2d_copy(make_tensor(make_gmem_ptr(C), make_shape(M, N))); - //TODO: - decide on how to deal with vector types - // - create layouts with tiling/partitioning - - Tensor tAi = make_tensor(make_inttuple_iter(m, 0), make_layout(make_shape(_1{}, Int{}, K), make_stride(_1{}, tM*E<0>{}, E<1>{}))); - Tensor tBi = make_tensor(make_inttuple_iter(0, n), make_layout(make_shape(_1{}, K, Int{}), make_stride(_1{}, E<0>{}, tN*E<1>{}))); - Tensor tCi = make_tensor(make_inttuple_iter(m, n), make_layout(Shape<_1, Int, Int>{}, make_stride(_1{}, tM*E<0>{}, tN*E<1>{}))); - TiledMMA, Layout>> tiled_mma; - - for (int k = 0; k < K; k += tK) { - copy(A_copy, tAi(_, _, k), tAr); - copy(B_copy, tBi(_, k/2, _), tBr); - gemm(tiled_mma, tAr, tBr, tCr); - } - copy(C_copy, tCr, tCi); - -}); }); - - ev.wait_and_throw(); - event_times[test] = time_event(ev); - } - - double average_event_time = 0.f; - for (int i = WARMUP_ITERATIONS; i < total_iterations; i++) - { - average_event_time += event_times[i]; - } - average_event_time /= (testIterations * 1e3); - auto gops = 2.0 * M * N * K; - printf("Average is %f microseconds (%f gops)\n", average_event_time, gops / (1e3 * average_event_time)); - - if (validate) - { - printf("Checking results... "); - fflush(stdout); - check_results(M, N, c_vec, C_ref); - printf(" done!\n"); - } - } -} - -int main(int argc, char **argv) -{ - printf("Config:\n"); - printf("\tTest Iterations: %d\n", testIterations); - printf("\tValidating data?: %s\n", validate ? "true" : "false"); - printf("\tFixed data?: %s\n", fixedData ? 
"true" : "false"); - - sycl::queue queue{{sycl::property::queue::enable_profiling()}}; - - const auto M = matrixSize; - const auto N = matrixSize; - const auto K = matrixSize; - - std::vector A_vec(M * K); - std::vector B_vec(K * N); - std::vector Bvnni_vec(K * N); - std::vector C_vec(M * N); - std::vector C_ref(M * N); - - printf("Initializing source matrices...\n"); - fill_matrix(A_vec, M, K); - fill_matrix(B_vec, K, N); - - vnni_matrix(Bvnni_vec, B_vec, K, N, 2); - - if (validate) - { - printf("Computing reference...\n"); - compute_reference(C_ref, A_vec, B_vec, M, N, K); - } - - printf("Creating source buffers...\n"); - sycl::buffer A{A_vec}; - sycl::buffer B{B_vec}; - sycl::buffer Bvnni{Bvnni_vec}; - - printf("Running tests...\n"); - - go_dpas_blockread_vnni_tiled<8, 16, 16, 1, 1>(queue, C_vec, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 1>(queue, C_vec, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 1, 2>(queue, C_vec, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 2>(queue, C_vec, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 2>(queue, C_vec, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 4>(queue, C_vec, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 4>(queue, C_vec, A, Bvnni, M, N, K, C_ref); - - printf("Done.\n"); - - return 0; -} diff --git a/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp b/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp index c25ab28a2..79e06bc8a 100644 --- a/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp +++ b/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp @@ -52,6 +52,19 @@ static void fill_matrix(T *M, size_t numRows, size_t numCols) { } } +template +static void fill_matrix_B(T *M, size_t numRows, size_t numCols) { + for (size_t r = 0; r < numRows; r++) { + for (size_t c = 0; c < numCols; c++) { + M[r * numCols + c] = bfloat16_t(0.0f); + } + }; + + for (size_t r = 0; r < numRows; r++) { + M[r * numCols + r] = bfloat16_t(1.0f); + } +} + template static void vnni_matrix(T *dst, const T *src, size_t numRows, size_t numCols, size_t factor) { @@ -77,8 +90,8 @@ void check_results(size_t M, size_t N, const T *C, const T *C_ref) { err = std::max(localErr, err); if (localErr >= threshold) { error_cnt++; - // std::cerr << "Error at m = " << m << ", n = " << n << ": (local error - // " + // std::cerr << "Error at m = " << m << ", n = " << n << ": (local + // error" // << localErr << "): Wanted " << C_ref[index] << ", got " // << C[index] << std::endl; // return; @@ -121,93 +134,93 @@ static void go_dpas_blockread_vnni_tiled(sycl::queue queue, dtype_acc *C, for (int test = 0; test < total_iterations; test++) { sycl::event ev; ev = queue.submit([&](sycl::handler &cgh) { - cgh.parallel_for( - nd_range, [=](sycl::nd_item<2> id) [[sycl::reqd_sub_group_size(16)]] { - const int m = id.get_group(0) * WG_SIZE_Y + - (get_sub_group_id() / SGS_PER_WG_X) * SG_SIZE_Y; - const int n = id.get_group(1) * WG_SIZE_X + - (get_sub_group_id() % SGS_PER_WG_X) * SG_SIZE_X; - - float8 sum[NN][MM]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = 0; - } - } - - int prefetch_k = 0; + cgh.parallel_for(nd_range, [=](sycl::nd_item<2> + id) [[sycl::reqd_sub_group_size(16)]] { + const int M = id.get_global_range(0) * ITEM_SIZE_Y; + const int N = id.get_global_range(1) * ITEM_SIZE_X; + const int m = id.get_group(0) * WG_SIZE_Y + + (get_sub_group_id() / SGS_PER_WG_X) * SG_SIZE_Y; + const int n = id.get_group(1) * WG_SIZE_X + 
+ (get_sub_group_id() % SGS_PER_WG_X) * SG_SIZE_X; + + Tensor tAr = make_tensor(Shape, Int<1>>{}); + Tensor tBr = make_tensor(Shape, Int>{}); + Tensor tCr = make_tensor(Shape<_8, Int, Int>{}); + + auto A_copy = + make_xe_2d_A_copy(make_tensor(make_gmem_ptr(A), make_shape(M, K))); + auto B_copy = + make_xe_2d_B_copy(make_tensor(make_gmem_ptr(B), make_shape(K, N))); + auto C_copy = + make_xe_2d_copy(make_tensor(make_gmem_ptr(C), make_shape(M, N))); + // TODO: - decide on how to deal with vector types + // - create layouts with tiling/partitioning + + Tensor tAi = make_tensor( + make_inttuple_iter(m, 0), + make_layout(make_shape(_1{}, _1{}, K), + make_stride(_1{}, MM * tM * E<0>{}, E<1>{}))); + Tensor tBi = + make_tensor(make_inttuple_iter(0, n), + make_layout(make_shape(_1{}, K, Int{}), + make_stride(_1{}, E<0>{}, tN * E<1>{}))); + Tensor tCi = make_tensor( + make_inttuple_iter(m, n), + make_layout(Shape<_1, Int, Int>{}, + make_stride(_1{}, tM * E<0>{}, tN * E<1>{}))); + TiledMMA, + Layout>> + tiled_mma; + + int prefetch_k = 0; #ifdef PREFETCH_DEFAULT - for (int p = 0; p < PREFETCH_DISTANCE; p++) { + for (int p = 0; p < PREFETCH_DISTANCE; p++) { #ifdef B_VNNI - HELPER_NAME(btile_block_prefetch_vnni, 4, 4) - ((ushort *)B, tN, K, N, prefetch_k, n); + HELPER_NAME(btile_block_prefetch_vnni, 4, 4) + ((ushort *)B, tN, K, N, prefetch_k, n); #else HELPER_NAME(btile_block_prefetch_rowmajor, 4, 4) ((ushort *)B, tN, K, N, prefetch_k, n); #endif - HELPER_NAME(atile_block_prefetch_rowmajor, 4, 4) - ((ushort *)A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; - } + HELPER_NAME(atile_block_prefetch_rowmajor, 4, 4) + ((ushort *)A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } #endif - split_barrier_arrive(); - - for (int k = 0; k < K; k += tK * KK) { - short8 aData[2][4]; - int8 bData[4][2]; + for (int k = 0; k < K; k += tK * KK) { + copy(A_copy, tAi(_, _, k), tAr); + copy(B_copy, tBi(_, k / 2, _), tBr); #ifdef PREFETCH_DEFAULT + for (int p = 0; p < PREFETCH_DISTANCE; p++) { #ifdef B_VNNI - HELPER_NAME(btile_block_prefetch_vnni, 4, 4) - ((ushort *)B, tN, K, N, prefetch_k, n); + HELPER_NAME(btile_block_prefetch_vnni, 4, 4) + ((ushort *)B, tN, K, N, prefetch_k, n); #else HELPER_NAME(btile_block_prefetch_rowmajor, 4, 4) ((ushort *)B, tN, K, N, prefetch_k, n); #endif - HELPER_NAME(atile_block_prefetch_rowmajor, 4, 4) - ((ushort *)A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; + HELPER_NAME(atile_block_prefetch_rowmajor, 4, 4) + ((ushort *)A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } #endif - - *(ushort64 *)(&aData) = - __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( - (long)A, K * sizeof(ushort) - 1, M - 1, - K * sizeof(ushort) - 1, int2_{k, m}); - - for (int i = 0; i < NN; i++) { - *(uint16 *)(&bData[i][0]) = - __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( - (long)B, N * sizeof(uint) - 1, K - 1, - N * sizeof(uint) - 1, int2_{n + i * tN, k / 2}); - } - - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[nn][mm] = intel_sub_group_bf16_bf16_matrix_mad_k16( - aData[kk][mm], bData[nn][kk], sum[nn][mm]); - } - } - } - split_barrier_wait(); - split_barrier_arrive(); - } - split_barrier_wait(); - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( - (long)C, N * sizeof(float) - 1, M - 1, - N * sizeof(float) - 1, int2_{n + nn * tN, m + mm * tM}, - sycl::bit_cast(sum[nn][mm])); - } - } - }); + auto tAr_view = 
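+          // The block loads above fill tAr and tBr as flat fragments; the
+          // views below re-index them as (8, MM, KK) for A and (16, KK, NN)
+          // for B, so each kl step feeds one (tM x tN x tK) slice to gemm().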
make_tensor(static_cast(tAr).data(), + Shape<_8, Int, Int>{}); + auto tBr_view = make_tensor(static_cast(tBr).data(), + Shape<_16, Int, Int>{}); + for (int kl = 0; kl < KK; kl++) { + gemm(tiled_mma, tAr_view(_, _, kl), tBr_view(_, kl, _), tCr); + } + } + + copy(C_copy, tCr, tCi); + }); }); ev.wait_and_throw(); - event_times[test] = time_event(ev) / 1e6; // ms + event_times[test] = time_event(ev) / 1e6; } double average_event_time = 0.f; diff --git a/examples/cute/tutorial/pvc_sycl/pvc_sycl_builtins.hpp b/examples/cute/tutorial/pvc_sycl/pvc_sycl_builtins.hpp index e6294c990..c814d397e 100644 --- a/examples/cute/tutorial/pvc_sycl/pvc_sycl_builtins.hpp +++ b/examples/cute/tutorial/pvc_sycl/pvc_sycl_builtins.hpp @@ -70,9 +70,9 @@ enum LSC_LDCC { LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached }; -typedef ushort __attribute__((ext_vector_type(32))) ushort32; -typedef uint __attribute__((ext_vector_type(32))) uint32; -typedef ushort __attribute__((ext_vector_type(64))) ushort64; +// typedef ushort __attribute__((ext_vector_type(32))) ushort32; +// typedef uint __attribute__((ext_vector_type(32))) uint32; +// typedef ushort __attribute__((ext_vector_type(64))) ushort64; typedef uint __attribute__((ext_vector_type(16))) uint16; typedef uint __attribute__((ext_vector_type(8))) uint8; typedef int __attribute__((ext_vector_type(8))) int8; @@ -114,12 +114,6 @@ SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1( SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1( long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord, enum LSC_LDCC cache_control)); -SYCL_DEVICE_BUILTIN(ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, int2_ coord)); -SYCL_DEVICE_BUILTIN(uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, int2_ coord)); SYCL_DEVICE_BUILTIN(void __builtin_IB_work_group_barrier_arrive(uint flags)); SYCL_DEVICE_BUILTIN(void __builtin_IB_work_group_barrier_wait(uint flags)); diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index aaf956e50..98e38e2ae 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -1,45 +1,188 @@ #pragma once -#include +#include #include +#include #include #ifdef __SYCL_DEVICE_ONLY__ #define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x #else -#define SYCL_DEVICE_BUILTIN(x) \ - inline x { assert(false); } +#define SYCL_DEVICE_BUILTIN(x) \ + inline x { assert(false); } #endif -SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord, uint8 data)); -SYCL_DEVICE_BUILTIN(ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord)); -SYCL_DEVICE_BUILTIN(uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, int2_ coord, uint8 data)); +SYCL_DEVICE_BUILTIN(ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1( + long baseoffset, int width_minus_one, int 
height_minus_one, + int pitch_minus_one, int2_ coord)); +SYCL_DEVICE_BUILTIN(uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, int2_ coord)); + +/// Load A +SYCL_DEVICE_BUILTIN(ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, int2_ coord)); +SYCL_DEVICE_BUILTIN(ushort32 __builtin_IB_subgroup_block_read_flat_u16_m16k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, int2_ coord)); +SYCL_DEVICE_BUILTIN(ushort16 intel_subgroup_block_read_u16_m8k16v2( + __global void *baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, int2_ coord)); +SYCL_DEVICE_BUILTIN(ushort32 __builtin_IB_subgroup_block_read_flat_u16_m32k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, int2_ coord)); + +/// Load B +SYCL_DEVICE_BUILTIN(uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, int2_ coord)); + #undef SYCL_DEVICE_BUILTIN -struct XE_2D_LOAD //m8k16 +struct XE_2D_LOAD // m8k16 { - template - CUTE_HOST_DEVICE static void copy(const void* baseoffset, int width, int height, int pitch, int2_ coord, T* dst) - { - if constexpr(sizeof(T)==sizeof(ushort)) { - *(ushort8*)dst = __builtin_IB_subgroup_block_read_flat_u16_m8k16v1((long)baseoffset, width - 1, height - 1, pitch - 1, coord); - } else if constexpr(sizeof(T)==sizeof(uint)) { - *(uint8*)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1((long)baseoffset, width - 1, height - 1, pitch - 1, coord); - } else { - static_assert(false); - } + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, int2_ coord, + T *dst) { + if constexpr (sizeof(T) == sizeof(ushort)) { + *(ushort8 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m8k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord); + } else if constexpr (sizeof(T) == sizeof(uint)) { + *(uint8 *)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord); + } else { + static_assert(false); + } + } +}; + +/// 4X2 Block m8k16 +struct XE_2D_U16X4X2_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, int2_ coord, + T *dst) { + if constexpr (sizeof(T) == 2) { + *(ushort64 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( + long(baseoffset), width - 1, height - 1, pitch - 1, coord); + + // ((ushort8_t*)dst)[0] = sycl::bit_cast(tmp.lo.lo.lo); + // ((ushort8_t*)dst)[1] = sycl::bit_cast(tmp.lo.lo.hi); + // ((ushort8_t*)dst)[2] = sycl::bit_cast(tmp.lo.hi.lo); + // ((ushort8_t*)dst)[3] = sycl::bit_cast(tmp.lo.hi.hi); + // ((ushort8_t*)dst)[4] = sycl::bit_cast(tmp.hi.lo.lo); + // ((ushort8_t*)dst)[5] = sycl::bit_cast(tmp.hi.lo.hi); + // ((ushort8_t*)dst)[6] = sycl::bit_cast(tmp.hi.hi.lo); + // ((ushort8_t*)dst)[7] = sycl::bit_cast(tmp.hi.hi.hi); + } else { + static_assert(false); + } + } +}; + +/// 2X2 Block m8k16 +struct XE_2D_U16X2X2_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, int2_ coord, + T *dst) { + if constexpr (sizeof(T) == 2) { + ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2( + long(baseoffset), width - 1, height - 1, pitch - 1, coord); + ((ushort8_t *)dst)[0] = 
(ushort8_t)(tmp.lo.lo);
+      ((ushort8_t *)dst)[1] = (ushort8_t)(tmp.lo.hi);
+      ((ushort8_t *)dst)[2] = (ushort8_t)(tmp.hi.lo);
+      ((ushort8_t *)dst)[3] = (ushort8_t)(tmp.hi.hi);
+    } else {
+      static_assert(false);
+    }
+  }
+};
+
+/// 1X2 Block m8k16
+struct XE_2D_U16X1X2_N {
+  template <class T>
+  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
+                                    int height, int pitch, int2_ coord,
+                                    T *dst) {
+    if constexpr (sizeof(T) == 2) {
+      ushort16 tmp = (intel_subgroup_block_read_u16_m8k16v2(
+          (__global void *)baseoffset, width, height, pitch, coord));
+      *(ushort16 *)dst = *reinterpret_cast<ushort16 *>(&tmp);
+    } else {
+      static_assert(false);
+    }
+  }
+};
+
+/// 4X1 Block m8k16
+struct XE_2D_U16X4X1_N {
+  template <class T>
+  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
+                                    int height, int pitch, int2_ coord,
+                                    T *dst) {
+    if constexpr (sizeof(T) == 2) {
+      ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(
+          long(baseoffset), width - 1, height - 1, pitch - 1, coord);
+      ((ushort8_t *)dst)[0] = (ushort8_t)(tmp.lo.lo);
+      ((ushort8_t *)dst)[1] = (ushort8_t)(tmp.lo.hi);
+      ((ushort8_t *)dst)[2] = (ushort8_t)(tmp.hi.lo);
+      ((ushort8_t *)dst)[3] = (ushort8_t)(tmp.hi.hi);
+    } else {
+      static_assert(false);
+    }
+  }
+};
+
+/// 2X1 Block U32 k8n16
+struct XE_2D_U32X2X1_N {
+  template <class T>
+  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
+                                    int height, int pitch, int2_ coord,
+                                    T *dst) {
+    if constexpr (sizeof(T) == 4) {
+      uint16 tmp = __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(
+          long(baseoffset), width - 1, height - 1, pitch - 1, coord);
+      *(uint16 *)dst = *reinterpret_cast<uint16 *>(&tmp);
+    } else {
+      static_assert(false);
+    }
+  }
+};
+
+/// 2X1 Block U16 k16n16
+struct XE_2D_U16X2X1_N {
+  template <class T>
+  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
+                                    int height, int pitch, int2_ coord,
+                                    T *dst) {
+    if constexpr (sizeof(T) == 2) {
+      uint16 tmp = __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(
+          long(baseoffset), width - 1, height - 1, pitch - 1, coord);
+      *(uint16 *)dst = *reinterpret_cast<uint16 *>(&tmp);
+    } else {
+      static_assert(false);
+    }
+  }
+};
 
-struct XE_2D_SAVE //m8k16
+struct XE_2D_SAVE // m8k16
 {
-    template <class T>
-    CUTE_HOST_DEVICE static void copy(void* baseoffset, int width, int height, int pitch, int2_ coord, const T* src)
-    {
-        if constexpr(sizeof(T)==sizeof(uint)) {
-            __builtin_IB_subgroup_block_write_flat_u32_m8k16v1((long)baseoffset, width - 1, height - 1, pitch - 1, coord, *(uint8*)src);
-        } else {
-            static_assert(false);
-        }
+  template <class T>
+  CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height,
+                                    int pitch, int2_ coord, const T *src) {
+    if constexpr (sizeof(T) == sizeof(uint)) {
+      __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(
+          (long)baseoffset, width - 1, height - 1, pitch - 1, coord,
+          *(uint8 *)src);
+    } else {
+      static_assert(false);
     }
+  }
 };
diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp
index 5627b722a..f2230b3a8 100644
--- a/include/cute/atom/copy_atom.hpp
+++ b/include/cute/atom/copy_atom.hpp
@@ -1,12 +1,12 @@
 /***************************************************************************************************
- * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: BSD-3-Clause
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved.
SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once @@ -41,47 +42,48 @@ #include -namespace cute -{ +namespace cute { -template -struct Copy_Atom; +template struct Copy_Atom; template -struct Copy_Atom : Copy_Atom, CopyInternalType> -{}; +struct Copy_Atom + : Copy_Atom, CopyInternalType> {}; template struct Copy_Atom, CopyInternalType> - : Copy_Traits -{ + : Copy_Traits { using Traits = Copy_Traits; // Bit and Thr layouts from the Copy_Traits - using ThrID = typename Traits::ThrID; + using ThrID = typename Traits::ThrID; using BitLayoutSrc = typename Traits::SrcLayout; using BitLayoutDst = typename Traits::DstLayout; using BitLayoutRef = typename Traits::RefLayout; using ValType = CopyInternalType; - using ValLayoutSrc = decltype(recast_layout(BitLayoutSrc{})); - using ValLayoutDst = decltype(recast_layout(BitLayoutDst{})); - using ValLayoutRef = decltype(recast_layout(BitLayoutRef{})); + using ValLayoutSrc = + decltype(recast_layout(BitLayoutSrc{})); + using ValLayoutDst = + decltype(recast_layout(BitLayoutDst{})); + using ValLayoutRef = + decltype(recast_layout(BitLayoutRef{})); - CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); - CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); - CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), + "CopyOperation is not valid for Src of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), + "CopyOperation is not valid for Dst of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), + "CopyOperation is not valid for Ref of ValType."); static constexpr int NumValSrc = size<1>(ValLayoutSrc{}); static constexpr int NumValDst = size<1>(ValLayoutDst{}); // Additional Trait parameters/transformations template - CUTE_HOST_DEVICE - auto - with(TraitsArgs&&... args) const { - auto traits = Traits::with(static_cast(args)...); + CUTE_HOST_DEVICE auto with(TraitsArgs &&...args) const { + auto traits = Traits::with(static_cast(args)...); return Copy_Atom{traits}; } @@ -90,40 +92,34 @@ struct Copy_Atom, CopyInternalType> // // Check and call instruction, or recurse - template - CUTE_HOST_DEVICE - void - call(Tensor const& src, - Tensor & dst) const - { + template + CUTE_HOST_DEVICE void call(Tensor const &src, + Tensor &dst) const { static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); + // static_assert(is_constant::value); + // static_assert(is_constant::value); if constexpr (is_constant::value || is_constant::value) { // Dispatch to unpack to execute instruction return copy_unpack(*this, src, dst); - } else - if constexpr (is_tuple::value && - is_tuple::value) { + } else if constexpr (is_tuple::value && + is_tuple::value) { // If the size of the src/dst doesn't match the instruction, // recurse this rank-1 layout by peeling off the mode // ((A,B,C,...)) -> (A,B,C,...) 
return copy(*this, tensor<0>(src), tensor<0>(dst)); } else { - static_assert(dependent_false, "No instruction match and no recursion possible."); + static_assert(dependent_false, + "No instruction match and no recursion possible."); } } // Accept mutable temporaries - template - CUTE_HOST_DEVICE - void - call(Tensor const& src, - Tensor && dst) const - { + template + CUTE_HOST_DEVICE void call(Tensor const &src, + Tensor &&dst) const { return call(src, dst); } }; @@ -132,16 +128,14 @@ struct Copy_Atom, CopyInternalType> // A tiling of copy atoms // -template -struct ThrCopy; +template struct ThrCopy; template coord [Need not be 2D...] - class ShapeTiler_MN> // coord space -struct TiledCopy : Copy_Atom -{ + class LayoutCopy_TV, // (tid,vid) -> coord [Need not be 2D...] + class ShapeTiler_MN> // coord space +struct TiledCopy : Copy_Atom { // Layout information from the CopyAtom - using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx + using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset @@ -150,13 +144,15 @@ struct TiledCopy : Copy_Atom using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); // Layout information for the TiledCopy - using Tiler_MN = ShapeTiler_MN; + using Tiler_MN = ShapeTiler_MN; using TiledLayout_TV = LayoutCopy_TV; - using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); - using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); + using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); + using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); - CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom"); - CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom"); + CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, + "TiledCopy uses too few thrs for selected CopyAtom"); + CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, + "TiledCopy uses too few vals for selected CopyAtom"); // Tile a tensor or a layout from shape // (M,N,...) @@ -169,14 +165,14 @@ struct TiledCopy : Copy_Atom // RestM: The values tiled in M. // RestN: The values tiled in N. template - CUTE_HOST_DEVICE constexpr static - auto - tidfrg_S(STensor&& stensor) - { - CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); - - // Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) layout - return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); + CUTE_HOST_DEVICE constexpr static auto tidfrg_S(STensor &&stensor) { + CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), + "Rank of tensor to be partitioned too small."); + + // Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) + // layout + return tile2thrfrg(zipped_divide(stensor, Tiler_MN{}), + right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); } // Tile a tensor or a layout from shape @@ -190,14 +186,14 @@ struct TiledCopy : Copy_Atom // RestM: The values tiled in M. // RestN: The values tiled in N. 
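  // In user code these partitioners are typically reached through ThrCopy
  // rather than called directly; a minimal sketch (the atom, the layouts, and
  // thr_idx are illustrative assumptions, not part of this change):
  //
  //   auto tiled_copy = make_tiled_copy(Copy_Atom<UniversalCopy<float>, float>{},
  //                                     Layout<Shape<_16, _1>>{},  // (m,n) -> thr_idx
  //                                     Layout<Shape< _1, _4>>{}); // (m,n) -> val_idx
  //   auto thr_copy = tiled_copy.get_slice(thr_idx);
  //   Tensor tS = thr_copy.partition_S(gS);  // uses tidfrg_S
  //   Tensor tD = thr_copy.partition_D(gD);  // uses tidfrg_D
  //   copy(tiled_copy, tS, tD);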
template - CUTE_HOST_DEVICE constexpr static - auto - tidfrg_D(DTensor&& dtensor) - { - CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); - - // Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) layout - return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); + CUTE_HOST_DEVICE constexpr static auto tidfrg_D(DTensor &&dtensor) { + CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), + "Rank of tensor to be partitioned too small."); + + // Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) + // layout + return tile2thrfrg(zipped_divide(dtensor, Tiler_MN{}), + right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); } // Tile a tensor or a layout from shape @@ -205,13 +201,13 @@ struct TiledCopy : Copy_Atom // to shape // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) template - CUTE_HOST_DEVICE constexpr static - auto - tile2thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) - { + CUTE_HOST_DEVICE constexpr static auto + tile2thrfrg(Tensor &&tensor, Ref2TrgLayout const &ref2trg) { // Take the thrs/vals that the atom is interested in - // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID - auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); + // NOTE: Assumes the AtomNumThr are contiguous and identity within + // TiledThrID + auto atom_layout_TV = + zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); // ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n) // Transform to the trg layout @@ -219,8 +215,9 @@ struct TiledCopy : Copy_Atom // ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n) // Transform the thrs mode from thrid to thr_idx - // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID - auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{}); + // NOTE: Assumes the AtomNumThr are contiguous and identity within + // TiledThrID + auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1, Shape<_1, _1>>{}); // ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n) /// ================== @@ -230,34 +227,38 @@ struct TiledCopy : Copy_Atom // ((thrid,val),(RestM,RestN,...)) // Unfold and return - return tv_tensor(make_coord(_,_), _); + return tv_tensor(make_coord(_, _), _); } - // retile_S and retile_D assume they are working with the reference layout -- they are the same + // retile_S and retile_D assume they are working with the reference layout -- + // they are the same template - CUTE_HOST_DEVICE constexpr static - auto - retile(Tensor&& tensor) - { + CUTE_HOST_DEVICE constexpr static auto retile(Tensor &&tensor) { constexpr int R = remove_cvref_t::rank; - // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation + // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref + // transformation - // Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV. - // Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV + // Assume the first size<0>(tensor) elements are the first val_ids in + // TiledLayout_TV. 
Then, we only need the shape+layout of those + // size<0>(tensor) elements in TiledLayout_TV // and that shape is what we gather from the other modes of tensor auto V = size<0>(tensor); - auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{}))); + auto frg_layout_mn = upcast( + right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{}))); // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV - auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{})); + auto frg_layout_v = zipped_divide( + logical_product(make_layout(V), right_inverse(frg_layout_mn)), + make_layout(AtomNumVal{})); // (atom_vals,rest_vals) -> (v,m,n) /// ======= // Tile the tensor for TileFrg - auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V)); + auto t_tensor = + zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V)); // ((TileV,TileM,TileN,...),(1,RestM,RestN,...)) // Transform the tile mode @@ -265,23 +266,19 @@ struct TiledCopy : Copy_Atom // ((atom_vals,rest_vals),(1,RM,RN,...)) // Unfold and return - return v_tensor(_, append(Int<0>{},_)); + return v_tensor(_, append(Int<0>{}, _)); } - CUTE_HOST_DEVICE constexpr static - auto - get_layoutS_TV() - { + CUTE_HOST_DEVICE constexpr static auto get_layoutS_TV() { // (M,N) -> (M,N) auto ref_S = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); // (thr_idx,val_idx) -> (M,N) - return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{}); + return tile2thrfrg(ref_S, + right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))( + _, _, Int<0>{}); } - CUTE_HOST_DEVICE constexpr static - auto - get_layoutS_MN() - { + CUTE_HOST_DEVICE constexpr static auto get_layoutS_MN() { // (thr_idx,val_idx) -> (M,N) auto layoutS_TV = get_layoutS_TV(); // (M,K) -> (thr_idx,val_idx) @@ -293,20 +290,16 @@ struct TiledCopy : Copy_Atom return cute::make_tuple(layoutS_MK, thrID_S); } - CUTE_HOST_DEVICE constexpr static - auto - get_layoutD_TV() - { + CUTE_HOST_DEVICE constexpr static auto get_layoutD_TV() { // (M,N) -> (M,N) auto ref_D = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); // (thr_idx,val_idx) -> (M,N) - return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{}); + return tile2thrfrg(ref_D, + right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))( + _, _, Int<0>{}); } - CUTE_HOST_DEVICE constexpr static - auto - get_layoutD_MN() - { + CUTE_HOST_DEVICE constexpr static auto get_layoutD_MN() { // (thr_idx,val_idx) -> (M,N) auto layoutD_TV = get_layoutD_TV(); // (M,K) -> (thr_idx,val_idx) @@ -318,82 +311,66 @@ struct TiledCopy : Copy_Atom return cute::make_tuple(layoutD_MK, thrID_D); } - template ::value)> - CUTE_HOST_DEVICE static - auto - get_slice(ThrIdx const& thr_idx) - { + template ::value)> + CUTE_HOST_DEVICE static auto get_slice(ThrIdx const &thr_idx) { return ThrCopy(thr_idx); } - template ::value)> - CUTE_HOST_DEVICE static - auto - get_thread_slice(ThrIdx const& thr_idx) - { + template ::value)> + CUTE_HOST_DEVICE static auto get_thread_slice(ThrIdx const &thr_idx) { return get_slice(thr_idx); } }; -template -struct ThrCopy -{ +template struct ThrCopy { ThrIdx thr_idx_; CUTE_HOST_DEVICE - ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} + ThrCopy(ThrIdx const &thr_idx) : thr_idx_(thr_idx) {} template - CUTE_HOST_DEVICE - auto - partition_S(STensor&& stensor) const { - //static_assert(sizeof(typename 
remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), - // "Expected ValType for tiling SrcTensor."); - auto thr_tensor = make_tensor(static_cast(stensor).data(), TiledCopy::tidfrg_S(stensor.layout())); + CUTE_HOST_DEVICE auto partition_S(STensor &&stensor) const { + // static_assert(sizeof(typename remove_cvref_t::value_type) == + // sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + auto thr_tensor = make_tensor(static_cast(stensor).data(), + TiledCopy::tidfrg_S(stensor.layout())); return thr_tensor(thr_idx_, _, repeat>(_)); } template - CUTE_HOST_DEVICE - auto - partition_D(DTensor&& dtensor) const { - //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), - // "Expected ValType for tiling DstTensor."); - auto thr_tensor = make_tensor(static_cast(dtensor).data(), TiledCopy::tidfrg_D(dtensor.layout())); + CUTE_HOST_DEVICE auto partition_D(DTensor &&dtensor) const { + // static_assert(sizeof(typename remove_cvref_t::value_type) == + // sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + auto thr_tensor = make_tensor(static_cast(dtensor).data(), + TiledCopy::tidfrg_D(dtensor.layout())); return thr_tensor(thr_idx_, _, repeat>(_)); } template - CUTE_HOST_DEVICE static - auto - retile_S(STensor&& stensor) { - // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + CUTE_HOST_DEVICE static auto retile_S(STensor &&stensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == + // sizeof(typename TiledCopy::ValType), // "Expected ValType for tiling SrcTensor."); - return make_tensor(static_cast(stensor).data(), TiledCopy::retile(stensor.layout())); + return make_tensor(static_cast(stensor).data(), + TiledCopy::retile(stensor.layout())); } template - CUTE_HOST_DEVICE static - auto - retile_D(DTensor&& dtensor) { - // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + CUTE_HOST_DEVICE static auto retile_D(DTensor &&dtensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == + // sizeof(typename TiledCopy::ValType), // "Expected ValType for tiling DstTensor."); - return make_tensor(static_cast(dtensor).data(), TiledCopy::retile(dtensor.layout())); + return make_tensor(static_cast(dtensor).data(), + TiledCopy::retile(dtensor.layout())); } }; - -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_impl(Copy_Atom const& atom, - LayoutCopy_TV const&, - Tiler const&) -{ +template +CUTE_HOST_DEVICE auto make_tiled_copy_impl(Copy_Atom const &atom, + LayoutCopy_TV const &, + Tiler const &) { return TiledCopy, LayoutCopy_TV, Tiler>{atom}; } @@ -402,55 +379,50 @@ make_tiled_copy_impl(Copy_Atom const& atom, // template -CUTE_HOST_DEVICE -auto -make_tiled_copy_A(Copy_Atom const& copy_atom, - TiledMMA const& mma) -{ - return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), make_shape(tile_size<0>(mma),tile_size<2>(mma))); +CUTE_HOST_DEVICE auto make_tiled_copy_A(Copy_Atom const ©_atom, + TiledMMA const &mma) { + return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), + make_shape(tile_size<0>(mma), tile_size<2>(mma))); } template -CUTE_HOST_DEVICE -auto -make_tiled_copy_B(Copy_Atom const& copy_atom, - TiledMMA const& mma) -{ - return make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), make_shape(tile_size<1>(mma),tile_size<2>(mma))); +CUTE_HOST_DEVICE auto make_tiled_copy_B(Copy_Atom const ©_atom, + TiledMMA const &mma) { + return 
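+      // The layout/tiler pair below selects the MMA's B operand:
+      // get_layoutB_TV() with an (N, K) tiler built from
+      // (tile_size<1>, tile_size<2>), mirroring make_tiled_copy_A's (M, K)
+      // above and make_tiled_copy_C's (M, N) below.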
make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), + make_shape(tile_size<1>(mma), tile_size<2>(mma))); } template -CUTE_HOST_DEVICE -auto -make_tiled_copy_C(Copy_Atom const& copy_atom, - TiledMMA const& mma) -{ - return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), make_shape(tile_size<0>(mma),tile_size<1>(mma))); +CUTE_HOST_DEVICE auto make_tiled_copy_C(Copy_Atom const ©_atom, + TiledMMA const &mma) { + return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), + make_shape(tile_size<0>(mma), tile_size<1>(mma))); } // returns the smallest tiled copy that can retile LayoutC_TV // for use with pipelined epilogues with subtiled stores template -CUTE_HOST_DEVICE -auto -make_tiled_copy_C_atom(Copy_Atom const& copy_atom, - TiledMMA const& mma) -{ +CUTE_HOST_DEVICE auto +make_tiled_copy_C_atom(Copy_Atom const ©_atom, + TiledMMA const &mma) { // Truncate the V-layout to just the Copy_Atom, keep the V-order auto layoutC_TV = mma.get_layoutC_TV(); - auto copy_V = Int::NumValSrc>{}; + auto copy_V = Int::NumValSrc>{}; CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV)); - auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V))); + auto layout_TV = composition( + layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V))); // Recompute tiler and restride the TV layout for the new tiler - // Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them - // Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA - auto mma_tiler = make_shape(tile_size<0>(mma),tile_size<1>(mma)); + // Tiler -- Find the active elements in the MMA tensor and generate a tiler to + // extract them Convert to the awkward by-mode tiler to preserve the modes of + // the tiled MMA + auto mma_tiler = make_shape(tile_size<0>(mma), tile_size<1>(mma)); auto mma_zeros = repeat_like(mma_tiler, Int<0>{}); auto tiler = transform(make_seq{}, [&](auto i) { - return filter(composition(make_layout(mma_tiler, replace(mma_zeros, Int<1>{})), layout_TV)); + return filter(composition( + make_layout(mma_tiler, replace(mma_zeros, Int<1>{})), layout_TV)); }); // Layout_TV -- Find the (tid,vid) -> tile coord transformation @@ -466,24 +438,22 @@ make_tiled_copy_C_atom(Copy_Atom const& copy_atom, /** Produce a TiledCopy from logical thread and values layouts. * The thread and value layouts map coordinates to thr_idx and val_idx. - * The product of these layouts is taken to produce the TV layout and the Tiler. - * Useful when threads and values need very specific mappings onto coordinates - * in the target tensors. + * The product of these layouts is taken to produce the TV layout and the + * Tiler. Useful when threads and values need very specific mappings onto + * coordinates in the target tensors. 
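+ *
+ * A minimal illustrative call (the 16x8 / 1x4 sizes are an assumed example):
+ *
+ *   auto tiled_copy = make_tiled_copy(
+ *       Copy_Atom<UniversalCopy<float>, float>{},
+ *       Layout<Shape<_16, _8>>{},   // (m,n) -> thr_idx : 128 threads
+ *       Layout<Shape< _1, _4>>{});  // (m,n) -> val_idx : 4 values per thread
+ *
+ * which yields a (16, 32) tiler in which each thread copies four values
+ * along n.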
*/ -template > -CUTE_HOST_DEVICE -auto -make_tiled_copy(Copy_Atom const& copy_atom, - ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx - ValLayout const& val_layout = {}) // (m,n) -> val_idx +template > +CUTE_HOST_DEVICE auto +make_tiled_copy(Copy_Atom const ©_atom, + ThrLayout const &thr_layout = {}, // (m,n) -> thr_idx + ValLayout const &val_layout = {}) // (m,n) -> val_idx { // Take the raked_products to compute the Layout_MN // (M,N) -> (thr_idx, val_idx) auto layout_mn = raked_product(thr_layout, val_layout); // (thr_idx, val_idx) -> (M,N) - auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); + auto layout_tv = right_inverse(layout_mn).with_shape( + make_shape(size(thr_layout), size(val_layout))); // Tiler for extracting relevant elements // (M,N) -> tensor coord auto tiler = product_each(shape(layout_mn)); @@ -502,28 +472,32 @@ make_tiled_copy(Copy_Atom const& copy_atom, /** Produce a TiledCopy from thread and value offset maps. * The TV Layout maps threads and values to the codomain of the data_layout. * It is verified that the intended codomain is valid within data_layout. - * Useful when threads and values don't care about owning specific coordinates, but - * care more about the vector-width and offsets between them. + * Useful when threads and values don't care about owning specific coordinates, + * but care more about the vector-width and offsets between them. */ template -CUTE_HOST_DEVICE constexpr -auto -make_cotiled_copy(Copy_Atom const& copy_atom, - AtomTVLayout const& atom_tv_layout, // atom (thr,val) -> data addr - DataLayout const& data_layout) // coord -> data addr The target layout +CUTE_HOST_DEVICE constexpr auto make_cotiled_copy( + Copy_Atom const ©_atom, + AtomTVLayout const &atom_tv_layout, // atom (thr,val) -> data addr + DataLayout const + &data_layout) // coord -> data addr The target layout { static_assert(is_static::value); static_assert(is_static::value); // data addr -> data coord Append 1:0 so off-the-ends get the stride-0 - auto inv_data_layout = make_layout(left_inverse(data_layout), Layout<_1,_0>{}); + auto inv_data_layout = + make_layout(left_inverse(data_layout), Layout<_1, _0>{}); // (tid,vid) -> data_coord auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); // Check validity - CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), - "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); + CUTE_STATIC_ASSERT_V( + coalesce(composition(data_layout, layout<1>(layout_tv_data))) == + coalesce(layout<1>(atom_tv_layout)), + "The memory pointed to by AtomTVLayout does not exist in the " + "DataLayout."); #if 0 if (thread0()) { @@ -534,15 +508,19 @@ make_cotiled_copy(Copy_Atom const& copy_atom, #endif // - // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them + // Tiler -- Find the active elements in the DATA tensor and generate a tiler + // to extract them // - // Convert to the awkward by-mode tiler to preserve the modes of the tiled DATA + // Convert to the awkward by-mode tiler to preserve the modes of the tiled + // DATA auto flat_data_shape = product_each(shape(data_layout)); auto flat_data_zeros = repeat(Int<0>{}); auto tiler = transform(make_seq{}, [&](auto i) { - return filter(composition(make_layout(flat_data_shape, replace(flat_data_zeros, Int<1>{})), layout_tv_data)); + return filter(composition( + make_layout(flat_data_shape, replace(flat_data_zeros, 
Int<1>{})), + layout_tv_data)); }); // @@ -567,26 +545,22 @@ make_cotiled_copy(Copy_Atom const& copy_atom, return make_tiled_copy_impl(copy_atom, layout_tv, tiler); } -// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_S(Copy_Atom const& copy_atom, - TiledCopy const& tiled_copy) -{ - return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{}); +// Make a TiledCopy out of the copy_atom that matches the Src-Layout of +// tiled_copy +template +CUTE_HOST_DEVICE auto make_tiled_copy_S(Copy_Atom const ©_atom, + TiledCopy const &tiled_copy) { + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), + typename TiledCopy::Tiler_MN{}); } -// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_D(Copy_Atom const& copy_atom, - TiledCopy const& tiled_copy) -{ - return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{}); +// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of +// tiled_copy +template +CUTE_HOST_DEVICE auto make_tiled_copy_D(Copy_Atom const ©_atom, + TiledCopy const &tiled_copy) { + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), + typename TiledCopy::Tiler_MN{}); } // @@ -595,19 +569,13 @@ make_tiled_copy_D(Copy_Atom const& copy_atom, // The logical size of a TileCopy template -CUTE_HOST_DEVICE constexpr -auto -tile_size(TiledCopy const&) -{ +CUTE_HOST_DEVICE constexpr auto tile_size(TiledCopy const &) { return size(typename TiledCopy::Tiler_MN{}); } // The number of threads involved in a TiledCopy template -CUTE_HOST_DEVICE constexpr -auto -size(TiledCopy const&) -{ +CUTE_HOST_DEVICE constexpr auto size(TiledCopy const &) { return typename TiledCopy::TiledNumThr{}; } @@ -616,60 +584,64 @@ size(TiledCopy const&) // template -CUTE_HOST_DEVICE -void -print(Copy_Atom, T> const&) -{ +CUTE_HOST_DEVICE void print(Copy_Atom, T> const &) { using Atom = Copy_Atom, T>; print("Copy_Atom\n"); - print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); - print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n"); - print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n"); - print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n"); - print(" ValueType: "); print(sizeof_bits::value); print("b\n"); + print(" ThrID: "); + print(typename Atom::ThrID{}); + print("\n"); + print(" ValLayoutSrc: "); + print(typename Atom::ValLayoutSrc{}); + print("\n"); + print(" ValLayoutDst: "); + print(typename Atom::ValLayoutDst{}); + print("\n"); + print(" ValLayoutRef: "); + print(typename Atom::ValLayoutRef{}); + print("\n"); + print(" ValueType: "); + print(sizeof_bits::value); + print("b\n"); } template -CUTE_HOST_DEVICE -void -print(TiledCopy const& copy, char const* pad = "") -{ +CUTE_HOST_DEVICE void print(TiledCopy const ©, + char const *pad = "") { using Copy = TiledCopy; print("TiledCopy\n"); - print(" Tiler_MN: "); print(typename Copy::Tiler_MN{}); print("\n"); - print(" TiledLayout_TV: "); print(typename Copy::TiledLayout_TV{}); print("\n"); - print(static_cast(copy)); + print(" Tiler_MN: "); + print(typename Copy::Tiler_MN{}); + print("\n"); + print(" TiledLayout_TV: "); + print(typename Copy::TiledLayout_TV{}); + print("\n"); + print(static_cast(copy)); } template -CUTE_HOST_DEVICE -void -print(ThrCopy const& thr_copy) -{ +CUTE_HOST_DEVICE void print(ThrCopy 
const &thr_copy) { print("ThrCopy\n"); - print(" ThrIdx: "); print(thr_copy.thr_idx_); print("\n"); + print(" ThrIdx: "); + print(thr_copy.thr_idx_); + print("\n"); print(TiledCopy{}); } template -CUTE_HOST_DEVICE -auto -print_latex(TiledCopy const& copy) -{ +CUTE_HOST_DEVICE auto print_latex(TiledCopy const ©) { auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); - print_latex_copy(layoutS_MN, thrID_S, - layoutD_MN, thrID_D); + print_latex_copy(layoutS_MN, thrID_S, layoutD_MN, thrID_D); } // MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread -template -CUTE_HOST_DEVICE -void -print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx +template +CUTE_HOST_DEVICE void +print_latex_copy(LayoutS const &S, + ThrIDS const &TS, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutD const &D, + ThrIDD const &TD) // (m,n) -> (tid,vid) and tid -> thr_idx { CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); @@ -677,59 +649,65 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and assert(size<0>(S) == size<0>(D)); assert(size<1>(S) == size<1>(D)); - char const* latex_header = + char const *latex_header = "\\documentclass{standalone}\n" "\\usepackage{tikz}\n" "\\usetikzlibrary{external}\n" "\\tikzexternalize\n" "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}",}; + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/" + ".style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; + char const *latex_footer = "\\end{tikzpicture}\n" + "\\end{document}\n"; + + char const *color_map[8] = { + "{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}", + }; // Header - printf("%% LayoutS: "); print(S); printf("\n"); - printf("%% ThrIDS : "); print(TS); printf("\n"); - printf("%% LayoutD: "); print(D); printf("\n"); - printf("%% ThrIDD : "); print(TD); printf("\n\n"); + printf("%% LayoutS: "); + print(S); + printf("\n"); + printf("%% ThrIDS : "); + print(TS); + printf("\n"); + printf("%% LayoutD: "); + print(D); + printf("\n"); + printf("%% ThrIDD : "); + print(TD); + printf("\n\n"); printf(latex_header); // S starting at 0,0 for (int i = 0; i < size<0>(S); ++i) { for (int j = 0; j < size<1>(S); ++j) { - int thrid = S(i,j) % size(TS); - int val_idx = S(i,j) / size(TS); + int thrid = S(i, j) % size(TS); + int val_idx = S(i, j) / size(TS); int thr_idx = TS(thrid); printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], - i, j, - thr_idx, val_idx); + color_map[thr_idx % 8], i, j, thr_idx, val_idx); } } // D starting 
at 0,size<1>(S)+3 for (int i = 0; i < size<0>(D); ++i) { for (int j = 0; j < size<1>(D); ++j) { - int thrid = D(i,j) % size(TD); - int val_idx = D(i,j) / size(TD); + int thrid = D(i, j) % size(TD); + int val_idx = D(i, j) / size(TD); int thr_idx = TD(thrid); printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], - i, j + size<1>(S) + 3, - thr_idx, val_idx); + color_map[thr_idx % 8], i, j + size<1>(S) + 3, thr_idx, val_idx); } } @@ -742,10 +720,12 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and } // D Labels for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, + j + size<1>(S) + 3, i); } for (int j = 0, i = -1; j < size<1>(D); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, + j + size<1>(S) + 3, j); } // Footer @@ -762,7 +742,7 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and // Config #if (__CUDACC_VER_MAJOR__ >= 12) -# define CUTE_COPY_ATOM_TMA_SM90_ENABLED +#define CUTE_COPY_ATOM_TMA_SM90_ENABLED #endif #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index 25f797306..d9e6efca5 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -1,80 +1,289 @@ #pragma once -#include #include +#include #include -namespace cute -{ - template - struct Copy_Traits - { - // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails - using ThrID = Layout<_1>; - using NumBits = Int; // hacky: does vec of 8 - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; // TODO: is _1 correct? - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - GTensor tensor; - - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const &traits, - Tensor>, SLayout> const &src, - Tensor &dst) - { - static_assert(is_rmem::value); - int H = size<0>(traits.tensor); - // int W = size<1>(traits.tensor) * sizeof(typename decltype(traits.tensor)::engine_type::value_type); - int W = size<1>(traits.tensor) * sizeof(typename TD::value_type); //TODO: inconsistent to give the size in elements but use vector for copy - auto [y, x] = src.data().coord_; - XE_2D_LOAD::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*dst.data()); - } - }; - - - template - struct Copy_Traits - { - // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails - using ThrID = Layout<_1>; - using NumBits = Int; // hacky: does vec of 8 - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; // TODO: is _1 correct? 
-    // Map from (dst-thr,dst-val) to bit
-    using DstLayout = Layout>;
-    // Reference map from (thr,val) to bit
-    using RefLayout = SrcLayout;
-
-    GTensor tensor;
-
-    template
-    CUTE_HOST_DEVICE friend constexpr void
-    copy_unpack(Copy_Traits const &traits,
-                Tensor>, SLayout> const &src,
-                Tensor &dst)
-    {
-      static_assert(is_rmem::value);
-      int H = size<0>(traits.tensor);
-      // int W = size<1>(traits.tensor) * sizeof(typename decltype(traits.tensor)::engine_type::value_type);
-      int W = size<1>(traits.tensor) * sizeof(typename TD::value_type); //TODO: inconsistent to give the size in elements but use vector for copy
-      auto [y, x] = src.data().coord_;
-      XE_2D_LOAD::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*dst.data());
-    }
-  };
-
-
-  template
-  struct Copy_Traits
-  {
-    // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails
-    using ThrID = Layout<_1>;
-    using NumBits = Int; // hacky: does vec of 8
-    // Map from (src-thr,src-val) to bit
-    using SrcLayout = Layout>; // TODO: is _1 correct?
-    // Map from (dst-thr,dst-val) to bit
-    using DstLayout = Layout>;
-    // Reference map from (thr,val) to bit
-    using RefLayout = SrcLayout;
-
-    GTensor tensor;
-
-    template
-    CUTE_HOST_DEVICE friend constexpr void
-    copy_unpack(Copy_Traits const &traits,
-                Tensor const &src,
-                Tensor>, DLayout> &dst)
-    {
-      static_assert(is_rmem::value);
-      int H = size<0>(traits.tensor);
-      int W = size<1>(traits.tensor) * sizeof(typename decltype(traits.tensor)::engine_type::value_type);
-      auto [y, x] = dst.data().coord_;
-      XE_2D_SAVE::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*src.data());
-    }
-  };
-
-  template
-  auto make_xe_2d_copy(Tensor gtensor)
-  {
-    using GTensor = Tensor;
-    using Traits = Copy_Traits;
-    Traits traits{gtensor};
-    return Copy_Atom{traits};
-  }
-}
\ No newline at end of file
+namespace cute {
+template struct Copy_Traits {
+  // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per
+  // subgroup) - but static_assert fails
+  using ThrID = Layout<_1>;
+  using NumBits = Int; // hacky: does vec of 8
+  // Map from (src-thr,src-val) to bit
+  using SrcLayout = Layout>; // TODO: is _1 correct?
+  // Map from (dst-thr,dst-val) to bit
+  using DstLayout = Layout>;
+  // Reference map from (thr,val) to bit
+  using RefLayout = SrcLayout;
+
+  GTensor tensor;
+
+  template
+  CUTE_HOST_DEVICE friend constexpr void copy_unpack(
+      Copy_Traits const &traits,
+      Tensor>, SLayout> const &src,
+      Tensor &dst) {
+    static_assert(is_rmem::value);
+    int H = size<0>(traits.tensor);
+    // int W = size<1>(traits.tensor) * sizeof(typename
+    // decltype(traits.tensor)::engine_type::value_type);
+    int W =
+        size<1>(traits.tensor) *
+        sizeof(typename TD::value_type); // TODO: inconsistent to give the size
+                                         // in elements but use vector for copy
+    auto [y, x] = src.data().coord_;
+    XE_2D_LOAD::copy(traits.tensor.data().get(), W, H, W, int2_{x, y},
+                     &*dst.data());
+  }
+
+  template
+  CUTE_HOST_DEVICE friend constexpr void
+  copy_unpack(Copy_Traits const &traits, Tensor const &src,
+              Tensor>, DLayout> &dst) {
+    static_assert(is_rmem::value);
+    int H = size<0>(traits.tensor);
+    int W = size<1>(traits.tensor) *
+            sizeof(typename decltype(traits.tensor)::engine_type::value_type);
+    auto [y, x] = dst.data().coord_;
+    XE_2D_SAVE::copy(traits.tensor.data().get(), W, H, W, int2_{x, y},
+                     &*src.data());
+  }
+};
+
+template struct Copy_Traits {
+  // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per
+  // subgroup) - but static_assert fails
+  using ThrID = Layout<_16>;
+  // Map from (src-thr,src-val) to bit
+  using SrcLayout = Layout, Stride<_0, _1>>;
+  // Map from (dst-thr,dst-val) to bit
+  using DstLayout =
+      Layout, Shape<_16, _2>>>,
+             Stride<_16, Stride, Stride<_1, _256>>>>;
+  // Reference map from (thr,val) to bit
+  using RefLayout = DstLayout;
+
+  // using ThrID = Layout<_1>;
+  // using NumBits = Int;
+  // // hacky: does vec of 8
+  // // Map from (src-thr,src-val) to bit
+  // using SrcLayout = Layout>; // TODO: is _1 correct?
+  // // Map from (dst-thr,dst-val) to bit
+  // using DstLayout = Layout>;
+  // // Reference map from (thr,val) to bit
+  // using RefLayout = SrcLayout;
+
+  GTensor tensor;
+
+  template
+  CUTE_HOST_DEVICE friend constexpr void copy_unpack(
+      Copy_Traits const &traits,
+      Tensor>, SLayout> const &src,
+      Tensor &dst) {
+    static_assert(is_rmem::value);
+    int H = size<0>(traits.tensor);
+    // int W = size<1>(traits.tensor) * sizeof(typename
+    // decltype(traits.tensor)::engine_type::value_type);
+    int W =
+        size<1>(traits.tensor) *
+        sizeof(typename TD::value_type); // TODO: inconsistent to give the size
+                                         // in elements but use vector for copy
+    auto [y, x] = src.data().coord_;
+    XE_2D_U16X4X2_N::copy(traits.tensor.data().get(), W, H, W, int2_{x, y},
+                          &*dst.data());
+  }
+};
+
+template struct Copy_Traits {
+  // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per
+  // subgroup) - but static_assert fails
+  using ThrID = Layout<_16>;
+  // Map from (src-thr,src-val) to bit
+  using SrcLayout = Layout, Stride<_0, _1>>;
+  // Map from (dst-thr,dst-val) to bit
+  using DstLayout =
+      Layout, Shape<_16, _2>>>,
+             Stride<_16, Stride, Stride<_1, _256>>>>;
+  // Reference map from (thr,val) to bit
+  using RefLayout = DstLayout;
+
+  GTensor tensor;
+
+  template
+  CUTE_HOST_DEVICE friend constexpr void copy_unpack(
+      Copy_Traits const &traits,
+      Tensor>, SLayout> const &src,
+      Tensor &dst) {
+    static_assert(is_rmem::value);
+    int H = size<0>(traits.tensor);
+    // int W = size<1>(traits.tensor) * sizeof(typename
+    // decltype(traits.tensor)::engine_type::value_type);
+    int W =
+        size<1>(traits.tensor) *
+        sizeof(typename TD::value_type); // TODO: inconsistent to give the size
+                                         // in elements but use vector for copy
+    auto [y, x] = src.data().coord_;
+    XE_2D_U16X2X2_N::copy(traits.tensor.data().get(), W, H, W, int2_{x, y},
+                          &*dst.data());
+  }
+};
+
+template struct Copy_Traits {
+  // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per
+  // subgroup) - but static_assert fails
+  using ThrID = Layout<_16>;
+  // Map from (src-thr,src-val) to bit
+  using SrcLayout = Layout, Stride<_0, _1>>;
+  // Map from (dst-thr,dst-val) to bit
+  using DstLayout = Layout>>,
+                           Stride<_16, Stride<_512, Stride<_1, _256>>>>;
+  // Reference map from (thr,val) to bit
+  using RefLayout = DstLayout;
+
+  GTensor tensor;
+
+  template
+  CUTE_HOST_DEVICE friend constexpr void copy_unpack(
+      Copy_Traits const &traits,
+      Tensor>, SLayout> const &src,
+      Tensor &dst) {
+    static_assert(is_rmem::value);
+    int H = size<0>(traits.tensor);
+    // int W = size<1>(traits.tensor) * sizeof(typename
+    // decltype(traits.tensor)::engine_type::value_type);
+    int W =
+        size<1>(traits.tensor) *
+        sizeof(typename TD::value_type); // TODO: inconsistent to give the size
+                                         // in elements but use vector for copy
+    auto [y, x] = src.data().coord_;
+    XE_2D_U16X1X2_N::copy(traits.tensor.data().get(), W, H, W, int2_{x, y},
+                          &*dst.data());
+  }
+};
+
+template struct Copy_Traits {
+  // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per
+  // subgroup) - but static_assert fails
+  using ThrID = Layout<_16>;
+  // Map from (src-thr,src-val) to bit
+  using SrcLayout = Layout, Stride<_0, _1>>;
+  // Map from (dst-thr,dst-val) to bit
+  using DstLayout =
+      Layout>, Stride<_16, Stride<_256, _1>>>;
+  // Reference map from (thr,val) to bit
+  using RefLayout = DstLayout;
+
+  GTensor tensor;
+
+  template
+  CUTE_HOST_DEVICE friend constexpr void copy_unpack(
+      Copy_Traits const &traits,
+      Tensor>, SLayout> const &src,
+      Tensor &dst) {
+    static_assert(is_rmem::value);
+    int H = size<0>(traits.tensor);
+    // int W = size<1>(traits.tensor) * sizeof(typename
+    // decltype(traits.tensor)::engine_type::value_type);
+    int W =
+        size<1>(traits.tensor) *
+        sizeof(typename TD::value_type); // TODO: inconsistent to give the size
+                                         // in elements but use vector for copy
+    auto [y, x] = src.data().coord_;
+    XE_2D_U16X4X1_N::copy(traits.tensor.data().get(), W, H, W, int2_{x, y},
+                          &*dst.data());
+  }
+};
+
+template struct Copy_Traits {
+  using ThrID = Layout<_16>;
+  // Map from (src-thr,src-val) to bit
+  using SrcLayout = Layout, Stride<_0, _1>>;
+  // Map from (dst-thr,dst-val) to bit
+  using DstLayout =
+      Layout>, Stride<_32, Stride<_512, _1>>>;
+  // Reference map from (thr,val) to bit
+  using RefLayout = DstLayout;
+  // 32 bits register file
+  using CopyInternalType = uint;
+
+  GTensor tensor;
+
+  template
+  CUTE_HOST_DEVICE friend constexpr void copy_unpack(
+      Copy_Traits const &traits,
+      Tensor>, SLayout> const &src,
+      Tensor &dst) {
+    static_assert(is_rmem::value);
+    int H = size<0>(traits.tensor);
+    // int W = size<1>(traits.tensor) * sizeof(typename
+    // decltype(traits.tensor)::engine_type::value_type);
+    int W = size<1>(traits.tensor) *
+            sizeof(CopyInternalType); // TODO: inconsistent to give the size in
+                                      // elements but use vector for copy
+    auto [y, x] = src.data().coord_;
+    XE_2D_U32X2X1_N::copy(traits.tensor.data().get(), W, H, W, int2_{x, y},
+                          &*dst.data());
+  }
+};
+
+template struct Copy_Traits {
+  // // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is
+  // per subgroup) - but static_assert fails
+  using ThrID = Layout<_16>;
+  // Map from (src-thr,src-val) to bit
+  using SrcLayout = Layout, Stride<_0, _1>>;
+  // Map from (dst-thr,dst-val) to bit
+  using DstLayout =
+      Layout>, Stride<_32, Stride<_512, _1>>>;
+  // Reference map from (thr,val) to bit
+  using RefLayout = DstLayout;
+  // 32 bits register file
+  using CopyInternalType = uint;
+
+  GTensor tensor;
+
+  template
+  CUTE_HOST_DEVICE friend constexpr void copy_unpack(
+      Copy_Traits const &traits,
+      Tensor>, SLayout> const &src,
+      Tensor &dst) {
+    static_assert(is_rmem::value);
+    int H = size<0>(traits.tensor);
+    // int W = size<1>(traits.tensor) * sizeof(typename
+    // decltype(traits.tensor)::engine_type::value_type);
+    int W = size<1>(traits.tensor) *
+            sizeof(CopyInternalType); // TODO: inconsistent to give the size in
+                                      // elements but use vector for copy
+    auto [y, x] = src.data().coord_;
+    XE_2D_U16X2X1_N::copy(traits.tensor.data().get(), W, H, W, int2_{x, y},
+                          &*dst.data());
+  }
+};
+
+template
+auto make_xe_2d_copy(Tensor gtensor) {
+  using GTensor = Tensor;
+  using Traits = Copy_Traits;
+  Traits traits{gtensor};
+  return Copy_Atom{traits};
+}
+
+template
+auto make_xe_2d_A_copy(Tensor gtensor) {
+  using GTensor = Tensor;
+  using Traits = Copy_Traits;
+  Traits traits{gtensor};
+  return Copy_Atom{traits};
+}
+
+template
+auto make_xe_2d_B_copy(Tensor gtensor) {
+  using GTensor = Tensor;
+  using Traits = Copy_Traits;
+  Traits traits{gtensor};
+  return Copy_Atom{traits};
+}
+} // namespace cute
\ No newline at end of file
diff --git a/include/cute/util/sycl_vec.hpp b/include/cute/util/sycl_vec.hpp
index 2f523db69..77ae9cc7c 100644
--- a/include/cute/util/sycl_vec.hpp
+++ b/include/cute/util/sycl_vec.hpp
@@ -1,20 +1,25 @@
 #pragma once
 
-//fwd declare OCL function and OCL types
+// fwd declare OCL function and OCL types
 
 #include //for sycl::vec
 
-#ifdef __SYCL_DEVICE_ONLY__
-template using vector_t = typename sycl::vec::vector_t;
-#else
-template using vector_t = sycl::vec;
+#ifdef __SYCL_DEVICE_ONLY__
+template using vector_t = typename sycl::vec::vector_t;
+#else
+template using vector_t = sycl::vec;
 #endif
 
 // using float8 = vector_t;
 // using short8 = vector_t;
 // using ushort8 = vector_t;
-using int2_ = vector_t; //conflicts with vector_types
+using int2_ = vector_t; // conflicts with vector_types
 // using int8 = vector_t;
 // using uint8 = vector_t;
 // using ushort16 = vector_t;
 // using uint16 = vector_t;
+
+typedef ushort __attribute__((ext_vector_type(8))) ushort8_t;
+typedef ushort __attribute__((ext_vector_type(16))) ushort16;
+typedef ushort __attribute__((ext_vector_type(32))) ushort32;
+typedef ushort __attribute__((ext_vector_type(64))) ushort64;
+typedef uint __attribute__((ext_vector_type(32))) uint32;
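
Review note (not part of the patch): the `make_xe_2d_copy` / `make_xe_2d_A_copy` / `make_xe_2d_B_copy` helpers added above share one pattern: stash the global-memory tensor inside `Copy_Traits`, then wrap those traits in a `Copy_Atom`, so the Xe 2D block intrinsics (`XE_2D_LOAD`, `XE_2D_U16X4X2_N`, and friends) are driven through the ordinary CuTe copy path, with `copy_unpack` recovering the surface height and width from the stored tensor. Below is a minimal construction sketch, not a definitive usage: it assumes the copy operation is passed as an explicit template argument (the template headers in this hunk are truncated, so this is inferred from the `Copy_Traits` usage rather than shown), and it uses only stock CuTe glue (`make_tensor`, `make_gmem_ptr`, `make_shape`, `make_stride`). The function name, matrix name, and element type are illustrative.

    // Sketch only: build a 2D block-load atom over a row-major M x K
    // matrix in global memory. make_xe_2d_copy and XE_2D_LOAD come from
    // this patch; everything else is standard CuTe.
    #include <cute/tensor.hpp>
    using namespace cute;

    template <class T>
    auto make_block_load_atom(T *A, int M, int K) {
      // View the raw device pointer as an (M,K) row-major CuTe tensor.
      auto gA = make_tensor(make_gmem_ptr(A),
                            make_shape(M, K),
                            make_stride(K, Int<1>{}));
      // Copy_Traits stores gA; copy_unpack later derives the block read's
      // height (size<0>) and width/pitch (size<1> * element size) from it.
      return make_xe_2d_copy<XE_2D_LOAD>(gA);
    }

The operand-specific makers are presumably split because each block shape fixes a different `DstLayout` (the `Shape<_16, ...>` / `Stride<_16, ...>` maps above describe how a 16-lane subgroup's register values tile the loaded block), so which `Copy_Traits` specialization the atom carries depends on whether it feeds the A or the B operand of the DPAS MMA.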