Skip to content

Commit

Permalink
Add Ampere bfloat-float example
Browse files Browse the repository at this point in the history
  • Loading branch information
aacostadiaz committed May 20, 2024
1 parent aec9d0c commit 7180fae
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 5 deletions.
5 changes: 5 additions & 0 deletions examples/sycl/ampere/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ cutlass_example_add_executable(
ampere_gemm_fp16_fp16_fp32_tensor_op_fp32
ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp
)

# Ampere GEMM example: bf16 A/B inputs, fp32 tensor-op accumulation and
# fp32 output (bf16 counterpart of the fp16 example registered above).
cutlass_example_add_executable(
ampere_gemm_bf16_bf16_fp32_tensor_op_fp32
ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
)
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,12 @@ int main(int argc, const char** argv)
// to use a GPU other than that with device ID 0.
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

bool passed;

// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementInputA = half_t; // <- data type of elements in input matrix A
using ElementInputB = half_t; // <- data type of elements in input matrix B
using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A
using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D

using LayoutA = cutlass::layout::ColumnMajor;
Expand All @@ -82,7 +80,7 @@ int main(int argc, const char** argv)
using TileShape = Shape<_128, _128, _32>;

using TiledMma = TiledMMA<
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>,
Layout<Shape<_2,_2,_1>>, // 2x2x1 thread group
Tile<_32,_32,_16>>; // 32x32x8 MMA for LDSM, 1x2x1 value group

Expand Down
75 changes: 75 additions & 0 deletions examples/sycl/ampere/gemm_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,78 @@ template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<half_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{};

/////////////////////////////////////////////////////////////////////////

// Bfloat

/// Operand A - Row-major (K-Major)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cutlass::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape < _8,_64>,
Stride<_64, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, bfloat16_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bfloat16_t>{},
Layout<Shape <_16,_8>,
Stride< _8,_1>>{},
Layout<Shape < _1,_8>>{}));
};

/// Operand A - Column-major (M-major): bf16, 8-element alignment, any SizeK.
/// Shared-memory layout atom plus smem/gmem copy atoms for an M-major
/// A operand feeding the SM80 tensor-op mainloop.
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem: 64x8 M-major atom (stride 1 along M) with Swizzle<3,3,3> composed on top.
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape <_64, _8>,
Stride< _1,_64>>{}));
// Transposed LDSM variant (note the _T suffix) used for the M-major case.
using SmemCopyAtom = Copy_Atom<SM75_U16x8_LDSM_T, bfloat16_t>;

// Gmem: 16x8 threads with M-major strides; each thread copies an 8x1 column
// of bf16 values (128 bits) via cp.async with the cache-global hint.
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bfloat16_t>{},
Layout<Shape <_16, _8>,
Stride< _1,_16>>{},
Layout<Shape < _8, _1>>{}));
};

/// Operand A - Row-major (K-major): bf16, 8-element alignment, SizeK == 32.
/// Narrower-K variant of the row-major specialization: 8x32 smem atom with a
/// Swizzle<2,3,3>, and a 32x4 gmem thread arrangement.
template <>
struct DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem: 8x32 K-major atom; the reduced K extent pairs with Swizzle<2,3,3>
// (vs Swizzle<3,3,3> for the SizeK == 64 case).
using SmemLayoutAtom = decltype(
composition(Swizzle<2,3,3>{},
Layout<Shape < _8,_32>,
Stride<_32, _1>>{}));
// Non-transposed LDSM (U32x4) load of K-major bf16 fragments.
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, bfloat16_t>;

// Gmem: 32x4 thread arrangement (K-major strides), each thread copying
// 1x8 contiguous bf16 values — a 128-bit cp.async per thread.
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bfloat16_t>{},
Layout<Shape <_32,_4>,
Stride< _4,_1>>{},
Layout<Shape < _1,_8>>{}));
};

// Because the F32BF16BF16F32 TiledMMA is A-B symmetric, we can reuse the
// OperandA definitions for operand B by swapping the layout: a K-major
// (column-major) B tile has the same atom structure as a K-major
// (row-major) A tile.

// Operand B - Column-Major (K-major): inherits the row-major A atoms.
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
{};

// Operand B - Row-Major (N-major): inherits the column-major (M-major)
// A atoms, relying on the A-B symmetry of the TiledMMA.
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{};

0 comments on commit 7180fae

Please sign in to comment.