Implement full feature of copy/gemm for PVC backend #174
base: sycl-develop
Conversation
Thanks for adding support for all the data layouts. I’ve left some comments.
Please don't force-push PRs when they are already being reviewed. It makes reviewing harder because one can't look at the commits added since the last review.
Implement Feature

1. Implement the full features of copy/MMA for the PVC backend.
Full copy/gemm functions were not implemented before this commit because the CUTLASS CuTe copy/MMA API is not fully compatible with the PVC backend. The register layout loaded by the PVC subgroup intrinsic does not satisfy the cute::gemm requirement, which leads to problems including but not limited to:
(1) GEMM can only support specific combinations of tile sizes and copy traits, and produces wrong results if you change the tile size configuration or the copy traits. For example, the case "examples/sycl/pvc/pvc_gemm.cpp" fails if you change sg_tile_k from 32 to 64. So we must retile the register data layout before cute::gemm.
(2) We have to hardcode changes to the register layout to satisfy the requirements of the CUTLASS CuTe APIs. For example, the data from "partition_fragment_B" needs to be hardcoded.

2. Support different GEMM layouts and data types.
(1) Support different combinations of RowMajor and ColumnMajor for matrices A and B. Refer to test/unit/cute/intel_xe/gemm_data_type.cpp.
(2) Add GEMM test cases for int8/uint8/fp16/bf16. Refer to test/unit/cute/intel_xe/gemm_layout.cpp.

This PR implements the features above while keeping performance unchanged.

Refine Code

1. Refine the layout convention for GEMM. For GEMM C = A x B, let A be (m, k, l), B be (n, k, l), and C be (m, n, l), and hide backend-related differences inside the implementation of the PVC copy traits (copy_traits_xe.hpp). This makes it easier for upper-level users to write code for Intel Xe GPUs following the usual CUTLASS conventions, without hardcoding for Intel Xe GPUs.

2. Refine the API "get_pvc_tensor". Before this PR, K-slicing and the coordinate tensor were mixed together, which made the interface parameters unclear and difficult to understand. Actually, K-slicing is for MMA use, while the coordinate tensor is only for copy. They are two different things and must be kept functionally independent, so we supply a helper function "append_pvc_tensor".
I pushed the wrong branch and have just restored it.
@@ -80,16 +80,49 @@ using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration<
    TiledMMA<MMAAtom, Layout<Shape<_1,_4,_1>>>,
    XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V>;

using PvcGemmBF16BF16FP32_RRR_6 = cutlass::gemm::device::GemmConfiguration<
RRR stands for the ABC data layout: R is row-major, C is column-major.
In this case it should be PvcGemmBF16BF16FP32_RCR.
@@ -497,4 +497,107 @@ gemm(MMA_Atom<MMA> const& mma,
  }
}

#if defined(SYCL_INTEL_TARGET)
Why is this needed?
@@ -2022,4 +2068,103 @@ namespace detail
  }
} // end namespace detail

template <class TiledCopy, class ThrIdx>
Why is all this needed?
  static constexpr Params
- to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+ to_underlying_arguments(TensorMKL const & tensorA, TensorNKL const &tensorB, void* workspace) {
Why changing the API?
- for (int i = 0; i < SG_K / SubgroupSize; i++) {
-   cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum);
- }
+ cute::gemm(tiled_mma, gmem_tiled_copy_a, gmem_tiled_copy_b, tCrA, tCrB, accum);
Why does gemm need the copy operation?
What happens with src_accum?
static constexpr auto construct_mkl_tensor_A(Arguments const &args) {
  using LayoutA = cutlass::detail::StrideToLayoutTagA_t<StrideA>;

  auto [M, N, K, L] = args.problem_shape;

  if constexpr (std::is_same_v<LayoutA, cutlass::layout::RowMajor>) {
    return make_tensor(make_gmem_ptr(static_cast<ElementA const*>(args.mainloop.ptr_A)),
                       make_layout(make_shape(M,K,L), make_stride((int64_t)K, _1{}, (int64_t)M * K)));
  } else {
    return make_tensor(make_gmem_ptr(static_cast<ElementA const*>(args.mainloop.ptr_A)),
                       make_layout(make_shape(M,K,L), make_stride(_1{}, (int64_t)M, (int64_t)M * K)));
  }
}

static constexpr auto construct_nkl_tensor_B(Arguments const &args) {
  using LayoutB = cutlass::detail::StrideToLayoutTagB_t<StrideB>;

  auto [M, N, K, L] = args.problem_shape;

  if constexpr (std::is_same_v<LayoutB, cutlass::layout::RowMajor>) {
    return make_tensor(make_gmem_ptr(static_cast<ElementB const*>(args.mainloop.ptr_B)),
                       make_layout(make_shape(N,K,L), make_stride(_1{}, (int64_t)N, (int64_t)N * K)));
  } else {
    return make_tensor(make_gmem_ptr(static_cast<ElementB const*>(args.mainloop.ptr_B)),
                       make_layout(make_shape(N,K,L), make_stride((int64_t)K, _1{}, (int64_t)N * K)));
  }
}
This is not needed. StrideA and StrideB are already provided by the user in args.
There are many different changes in this PR. It would be much easier to review if you separated each change into a different PR.