From d913c68af1a8e7d78fa22b816aa120b91e9003c9 Mon Sep 17 00:00:00 2001 From: Finlay Marno Date: Thu, 9 Jan 2025 16:31:51 +0000 Subject: [PATCH] Reverted the change of cD to rw_coord in consumer store args --- .../epilogue/collective/xe_epilogue.hpp | 4 +-- .../cutlass/epilogue/fusion/xe_callbacks.hpp | 11 ++++---- .../cutlass/epilogue/fusion/xe_visitor.hpp | 27 ++++++++++++++++++- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 476d677cf..63a79b20d 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -321,8 +321,8 @@ class CollectiveEpilogue< tile_coord_mnkl, tiled_mma, SubgroupTileShape{}, // Epilogue tile - params.xe_load_c, - rw_coord, + params.xe_store_d, + cD, residue_mn, cD, residue_mn, diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 96dba1a0b..bfacaeda6 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -171,6 +171,7 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// template< + class CtaTileShapeMNK, class StrideAux, class CopyOpG2R, template class ActivationFn, @@ -184,7 +185,7 @@ template< using XeLinCombDeEltAct = Sm90EVT, // activation(beta * C + (alpha * acc), aux) Sm90LinearCombination, // beta * C + (alpha * acc) - XeAuxLoad // aux + XeAuxLoad // aux >; // Z = Aux @@ -215,8 +216,8 @@ struct FusionCallbacks< EpilogueTile, CopyOpG2R > : XeLinCombDeEltAct< - cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, ElementOutput_, - ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle + CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, + ElementOutput_, ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle > { using ElementOutput = ElementOutput_; @@ -224,8 +225,8 @@ struct FusionCallbacks< using Impl = XeLinCombDeEltAct< - cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, ElementOutput, - ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle + CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle >; using Operation = fusion::LinCombDeEltAct< diff --git a/include/cutlass/epilogue/fusion/xe_visitor.hpp b/include/cutlass/epilogue/fusion/xe_visitor.hpp index 44bcc6807..20b991ae4 100644 --- a/include/cutlass/epilogue/fusion/xe_visitor.hpp +++ b/include/cutlass/epilogue/fusion/xe_visitor.hpp @@ -49,6 +49,7 @@ using namespace cutlass::epilogue::fusion; ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class CtaTileShapeMNK, class Element, class StrideMNL, class CopyOpG2R, @@ -190,9 +191,33 @@ struct XeAuxLoad { CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto xe_copy_aux = params_ptr->xe_load_aux; - Tensor rw_coord = args.cD; Tensor trAux = make_tensor_like(args.tCrC); + using TiledMma = decltype(args.tiled_mma); + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + + static constexpr auto BLK_M = get<0>(CtaTileShapeMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileShapeMNK{}); + + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr auto SG_M = BLK_M / ATOM_M; + static constexpr auto SG_N = BLK_N / ATOM_N; + + static constexpr int FragsM = SG_M / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = SG_N / get<1>(MmaAtomShape()); // B frags per sub_group + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + auto m_offset = m_coord * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + Tensor tOuti = args.tiled_copy.get_pvc_tensor( + make_coord(m_offset, n_offset, 0), + make_shape(_, Int{}, Int{}, L), + make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{}, _1{})); + Tensor rw_coord = tOuti(_,_,_,l_coord); + return ConsumerStoreCallbacks( rw_coord, xe_copy_aux, cute::move(trAux), params_ptr );