Skip to content

Commit

Permalink
Add Ampere tf32 example
Browse files Browse the repository at this point in the history
  • Loading branch information
aacostadiaz committed May 20, 2024
1 parent 00db8f6 commit 856985b
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/sycl/ampere/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ cutlass_example_add_executable(
ampere_gemm_bf16_bf16_fp32_tensor_op_fp32
ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cu
)

cutlass_example_add_executable(
ampere_gemm_tf32_tf32_fp32_tensor_op_fp32
ampere_gemm_tf32_tf32_fp32_tensor_op_fp32.cu
)
153 changes: 153 additions & 0 deletions examples/sycl/ampere/ampere_gemm_tf32_tf32_fp32_tensor_op_fp32.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

#include "../common/example_runner.hpp"
#include "gemm_configuration.hpp"

int main(int argc, const char** argv)
{
//
// Parse options
//

Options options;

options.parse(argc, argv);

if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}

if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}

//
// Run examples
//

// The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;

// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementInputA = tfloat32_t; // <- data type of elements in input matrix A
using ElementInputB = tfloat32_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D

using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using LayoutD = cutlass::layout::ColumnMajor;

using TileShape = Shape<_128, _128, _32>;

using TiledMma = TiledMMA<
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
Layout<Shape<_2,_2,_1>, Stride<_2, _1, _1>>, // 2x2x1 thread group
Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group

static constexpr int kAlignmentA = 4;
using DefaultOperandA = DefaultGemm_TensorOpSm80_OperandA<
ElementInputA, LayoutA, kAlignmentA, 32>;
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;

static constexpr int kAlignmentB = 4;
using DefaultOperandB = DefaultGemm_TensorOpSm80_OperandB<
ElementInputB, LayoutB, kAlignmentB, 32>;
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;

using Stages = Int<3>;

// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
// memory access. For a byte, it's 16
// elements. This becomes the vector width of
// math instructions in the epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function

using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsync<Stages{}>;

// Define strides (mixed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using StrideC = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideC_t<LayoutD>;

using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
StrideC,
StrideD,
EpilogueOp,
cutlass::gemm::EpilogueDefault>;

// Mainloop
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
DispatchPolicy,
TileShape,
ElementInputA,
StrideA,
ElementInputB,
StrideB,
TiledMma,
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
>;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue
>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

ExampleRunner<Gemm> runner;

runner.run(options, hw_info);

return 0;
}
73 changes: 73 additions & 0 deletions examples/sycl/ampere/gemm_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,76 @@ template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{};

/////////////////////////////////////////////////////////////////////////

// TFloat32

/// Operand A - Row-major (K-major) (kBlock = 32)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, cutlass::layout::RowMajor, 4, 32>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<3,2,3>{},
Layout<Shape < _8,_32>,
Stride<_32, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, tfloat32_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, tfloat32_t>{},
Layout<Shape <_16,_8>,
Stride< _8,_1>>{},
Layout<Shape < _1,_4>>{}));
};

/// Operand A - Row-major (K-major) (kBlock = 16)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, cutlass::layout::RowMajor, 4, 16>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<2,2,3>{},
Layout<Shape < _8,_16>,
Stride<_16, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, tfloat32_t>;
// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, tfloat32_t>{},
Layout<Shape <_32,_4>,
Stride< _4,_1>>{},
Layout<Shape < _1,_4>>{}));
};

/// Operand A - Column-major (M-major)
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, cutlass::layout::ColumnMajor, 4, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<2,3,2>{},
Layout<Shape <_32, _8>,
Stride< _1,_32>>{}));
using SmemCopyAtom = Copy_Atom<UniversalCopy<tfloat32_t>, tfloat32_t>;
// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, tfloat32_t>{},
Layout<Shape <_16, _8>,
Stride< _1,_16>>{},
Layout<Shape < _4, _1>>{}));
};

// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands

// Operand B - Column-Major (K-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<tfloat32_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, cutlass::layout::RowMajor, Alignment, SizeK>
{};

// Operand B - Row-Major (N-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<tfloat32_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{};

0 comments on commit 856985b

Please sign in to comment.