
Commit

Merge commit '4f6f76874ff623562903d5452d499cae3d40d448'
whitneywhtsang committed Nov 5, 2024
2 parents 1442ff4 + 4f6f768 commit 49a52a2
Showing 45 changed files with 1,116 additions and 722 deletions.
1 change: 1 addition & 0 deletions docs/python-api/triton.language.rst
@@ -59,6 +59,7 @@ Linear Algebra Ops
     :nosignatures:
 
     dot
+    dot_scaled
 
 
 Memory/Pointer Ops
2 changes: 0 additions & 2 deletions include/triton/Analysis/Utility.h
@@ -214,8 +214,6 @@ bool atomicNeedsSharedMemory(Value result);
 
 bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
 
-bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
-
 // Return true if the src and dst layout match.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy);
9 changes: 5 additions & 4 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr<
   let cppNamespace = "::mlir::triton";
 }
 
-// Type for F8F6F4 kind of floats.
-def TT_F8F6F4TypeAttr : I32EnumAttr<
-    "F8F6F4Type", "",
+// Type for ScaleDotElemType kind of floats.
+def TT_ScaleDotElemTypeAttr : I32EnumAttr<
+    "ScaleDotElemType", "",
     [
       I32EnumAttrCase<"E4M3", 0, "e4m3">,
       I32EnumAttrCase<"E5M2", 1, "e5m2">,
       I32EnumAttrCase<"E2M3", 2, "e2m3">,
      I32EnumAttrCase<"E3M2", 3, "e3m2">,
-      I32EnumAttrCase<"E2M1", 4, "e2m1">
+      I32EnumAttrCase<"E2M1", 4, "e2m1">,
+      I32EnumAttrCase<"BF16", 5, "bf16">
 
     ]>{
   let cppNamespace = "::mlir::triton";
16 changes: 8 additions & 8 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -685,15 +685,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
 
   let arguments = (
     ins
-    // inputs are integer types as they are packed types and we currently
-    // don't have a representation for those.
-    TT_IntTensor:$lhs,
-    TT_IntTensor:$rhs,
+    // inputs are floats if we have a type for them, otherwise (fp4),
+    // they are packed in pairs in an I8Tensor
+    RankedTensorOf<[TT_Float,I8]>:$lhs,
+    RankedTensorOf<[TT_Float,I8]>:$rhs,
     TT_FloatTensor:$c,
-    TT_IntTensor:$lhs_scale,
-    Optional<TT_IntTensor>:$rhs_scale,
-    TT_F8F6F4TypeAttr:$lhs_type,
-    TT_F8F6F4TypeAttr:$rhs_type
+    RankedTensorOf<[I8]>:$lhs_scale,
+    Optional<RankedTensorOf<[I8]>>:$rhs_scale,
+    TT_ScaleDotElemTypeAttr:$lhs_type,
+    TT_ScaleDotElemTypeAttr:$rhs_type
   );
 
   let results = (outs TT_FloatTensor:$d);
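Note: the new operand comments above rest on the fp4 (e2m1) packing convention, where two 4-bit values share one i8 element. A minimal standalone sketch of that convention, assuming low-nibble-first order (the nibble order is an illustration, not something this diff specifies):

#include <cstdint>

// Pack two e2m1 values (1 sign, 2 exponent, 1 mantissa bit each) into a
// single byte, low nibble first (assumed order, for illustration only).
inline uint8_t packE2M1Pair(uint8_t lo, uint8_t hi) {
  return static_cast<uint8_t>((lo & 0xF) | ((hi & 0xF) << 4));
}

// Recover the two 4-bit values from one packed byte.
inline void unpackE2M1Pair(uint8_t packed, uint8_t &lo, uint8_t &hi) {
  lo = packed & 0xF;
  hi = packed >> 4;
}

This is why $lhs and $rhs admit I8 alongside float element types: an fp4 operand of logical shape [M, K] is carried as an i8 tensor of shape [M, K/2].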
44 changes: 29 additions & 15 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false.
       return get(context, vec, perPhase, maxPhase, order, CTALayout);
     }
 
-    // ---- begin Ampere ----
-    if (mmaEnc.isAmpere()) {
+    // ---- begin Ampere & Hopper ----
+    if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
       int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
       perPhase = std::max<int>(perPhase, 1);
       std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
@@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false.
       llvm_unreachable("invalid operand index");
     }
 
-    // ---- begin version 3 ----
-    if (mmaEnc.isHopper()) {
-      llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
-                       " is Hopper has not been implemented yet");
-      return $_get(context, 1, 1, 1, order, CTALayout, true);
-    }
-
     // ---- not implemented ----
     llvm_unreachable("unsupported swizzling for provided MMA version");
   }]>,
@@ -1237,7 +1230,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
   SmallVector<int> getMMAv1Rep(int opIdx) const;
   SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
   int getMMAv1Vec(int opIdx) const;
-  SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
+  SmallVector<int64_t> getMMAv2OrV3RepForOperand(ArrayRef<int64_t> shape,
                                              int bitwidth, int kWidth, int opIdx) const;
 
   bool supportReduction() const {
@@ -1336,6 +1329,27 @@ The parent field is the layout of d.
 kWidth defines number of consecutive elements stored by one thread along k dimension.
 Some layouts do not use this parameter, either because they have a fixed number of
 elements along the K dim, or they use all elements of the tensor along the K dim.
+
+# WGMMA Notes
+We require kWidth to be provided for Hopper because the dtype at loading might be
+different from the dtype at WGMMA, due to casting. The kWidth is determined by the
+dtype at WGMMA.
+
+The encoded tensor consists of operand A for possibly multiple wgmma instructions.
+For each wgmma, each warp in a warp group feeds a single "warp matrix".
+Each warp matrix consists of 2x2 "quads".
+Each thread holds several elements in each quad. Right before a wgmma,
+the sum of the bitwidths of
+the elements in each quad should add up to 32.
+
+These values are stored unrolled in `elements`.
+The ordering of dimensions is as follows by convention:
+  batch (only 1 batch for Hopper currently)
+  matM (m-index of the "warp matrix")
+  matK (k-index of the "warp matrix")
+  quadK (k-index of the "quad" in the core matrix)
+  quadM (m-index of the "quad" in the core matrix)
+  vecIdx (index of the element in the quad; this is always along the k-dim)
 }];
 
 let parameters = (
@@ -1346,16 +1360,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
   );
 
   let builders = [
-    // Specially for MMAV1(Volta)
     AttrBuilder<(ins "unsigned":$opIdx,
                      "Attribute":$parent,
                      "Type":$eltTy), [{
       NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
-      if (!parentAttr || !parentAttr.isAmpere())
-        return $_get(context, opIdx, parent, 0);
+      if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper()))
+        return $_get(context, opIdx, parent, 0); // For MMAV1
+      // For MMAV2 and V3
      unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
-      unsigned MMAv2kWidth = 32 / bitwidth;
-      return $_get(context, opIdx, parent, MMAv2kWidth);
+      unsigned kWidth = 32 / bitwidth;
+      return $_get(context, opIdx, parent, kWidth);
     }]>
  ];
 
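A small C++ sketch of the kWidth rule stated in the WGMMA notes: each thread's slot in a quad is 32 bits wide, so kWidth follows directly from the element bitwidth at WGMMA time. The example dtypes below are illustrative assumptions:

#include <cassert>

// kWidth = consecutive k-elements per thread per quad, fixed by the dtype
// at WGMMA time (not the dtype at load time, which may differ via casting).
inline unsigned wgmmaKWidth(unsigned bitwidthAtWgmma) {
  assert(32 % bitwidthAtWgmma == 0 && "a quad slot is exactly 32 bits");
  return 32 / bitwidthAtWgmma;
}

// Examples (assumed dtypes): f16/bf16 (16-bit) -> kWidth 2;
// fp8 (8-bit) -> kWidth 4; f32/tf32 (32-bit) -> kWidth 1.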
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -268,7 +268,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
   let arguments = (ins
     TT_Tensor:$src,
     TT_Tensor:$scale,
-    TT_F8F6F4TypeAttr:$fp_type);
+    TT_ScaleDotElemTypeAttr:$fp_type);
   let results = (outs TT_Tensor:$result);
 
   let assemblyFormat = [{
2 changes: 1 addition & 1 deletion lib/Analysis/Allocation.cpp
@@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   Attribute srcLayout = srcTy.getEncoding();
   Attribute dstLayout = dstTy.getEncoding();
 
-  assert(!isMfmaToDotShortcut(srcTy, dstTy));
+  assert(cvtNeedsSharedMemory(srcTy, dstTy));
 
   // FIXME This is NOT entirely correct
   // This should be getElemOrder, but we don't have such a method
19 changes: 1 addition & 18 deletions lib/Analysis/Utility.cpp
@@ -612,22 +612,6 @@ bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
   return matrixDimsCompatible && bDimCompatible;
 }
 
-bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
-  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
-  if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
-    return false;
-  // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
-  // improved. In addition, we can enable this shortcut for regular MFMA
-  // layout when opIdx == 1.
-  return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
-         dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
-         dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] &&
-         dotOperandLayout.getParent() == mfmaLayout &&
-         (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
-         (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
-}
-
 // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy) {
@@ -708,8 +692,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) &&
         !isBlockedToDotShortcut(srcTy, dstTy) &&
-         !matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
-         !isMfmaToDotShortcut(srcTy, dstTy);
+         !matchMmaV3AndDotOperandLayout(srcTy, dstTy);
 }
 
 bool atomicNeedsSharedMemory(Value value) {
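For context (not part of this commit), cvtNeedsSharedMemory is the predicate a layout-conversion lowering consults before allocating scratch space. A hypothetical call site, assuming the MLIR/Triton headers already used in this file:

// Sketch only: lowerConvertLayout is a hypothetical helper, shown to
// illustrate how the predicate separates the two lowering strategies.
void lowerConvertLayout(RankedTensorType srcTy, RankedTensorType dstTy) {
  if (cvtNeedsSharedMemory(srcTy, dstTy)) {
    // General path: stage values through a shared-memory scratch buffer.
  } else {
    // Shortcut path (e.g. the MMAv3 dot-operand match above): remap values
    // register-to-register with no scratch allocation.
  }
}

This is also why Allocation.cpp above can now simply assert cvtNeedsSharedMemory(srcTy, dstTy) when computing a scratch configuration.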
21 changes: 20 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -11,6 +11,25 @@ using namespace mlir::triton::gpu;
 
 namespace mlir::triton::gpu {
 
+namespace {
+
+bool isDotOpTensorAndPacked(Type srcTy) {
+  auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
+  if (!tensorTy)
+    return false;
+  auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
+  if (!encoding)
+    return false;
+  auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
+  // By code convention, values for Hopper's dotOp-encoded tensors are not
+  // packed
+  if (!parentEnc || parentEnc.isHopper())
+    return false;
+  return true;
+}
+
+} // namespace
+
 Type getElementType(Value value) {
   auto type = value.getType();
   if (auto tensorType = dyn_cast<RankedTensorType>(type))
@@ -33,7 +52,7 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   // If the parent of the dot operand is in block encoding, we don't need to
   // reorder elements
   auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent());
-  if (!parentEncoding)
+  if (!parentEncoding || parentEncoding.isHopper())
     return values;
   size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
   size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
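A sketch of the dispatch this enables (an illustrative call site, not taken from this diff): element reordering after a cast only applies to packed dot-operand values, which the new helper rules out for Hopper.

// Hypothetical caller: reorder only when the operand values are packed
// (Ampere); Hopper's in-register operand A keeps its natural element order.
SmallVector<Value> lowerCastedOperand(const SmallVector<Value> &values,
                                      Type inType, Type ouType) {
  if (!isDotOpTensorAndPacked(inType))
    return values;
  return reorderValues(values, inType, ouType);
}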
33 changes: 24 additions & 9 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1099,13 +1099,18 @@ LogicalResult DotOperandEncodingAttr::verify(
     return emitError() << "triton_gpu.dot_op parent paramenter cannot be null";
   }
   if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-    if (kWidth != 0 && !parentAttr.isAmpere())
+    if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper()))
       return emitError() << "triton_gpu.dot_op kWidth parameter can only be "
-                            "non-zero for Ampere MMA parent";
-    if (kWidth == 0 && parentAttr.isAmpere())
+                            "non-zero for Ampere or Hopper MMA parent";
+    if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper()))
       return emitError()
              << "triton_gpu.dot_op kWidth parameter is mandatory for "
-                "Ampere MMA parent";
+                "Ampere or Hopper MMA parent";
+    if (opIdx != 0 && parentAttr.isHopper())
+      return emitError()
+             << "triton_gpu.dot_op opIdx parameter must be 0 for "
+                "Hopper MMA parent, since Hopper WGMMA only allows first "
+                "operand to be in registers";
     return success();
   }
 
@@ -2053,17 +2058,20 @@ SmallVector<int> NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const {
 int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const {
   return 2 * getMMAv1Rep(opIdx)[opIdx];
 }
-SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
+SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2OrV3RepForOperand(
     ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
+  assert(isAmpere() || (isHopper() && opIdx == 0));
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
 
   // {batch, m, n, k}
+  // Hopper path never uses the n value, since this method is only invoked
+  // for in-RF (dotOpEnc) operands, but WGMMA only supports A to be in RF
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3
           ? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
           : 1;
-  assert(isAmpere());
 
   if (opIdx == 0)
     return {numRepBatch,
@@ -2078,19 +2086,26 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
             warpsPerCTA[rank - 1]))};
   }
 }
 
 unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto shapePerCTA = getShapePerCTA(*this, shape);
   int warpsPerCTAM = getWarpsPerCTA()[0];
   int warpsPerCTAN = getWarpsPerCTA()[1];
   // H100
   if (isHopper()) {
-    return getTotalElemsPerThread(shape, eltTy);
+    assert(opIdx == 0);
+    auto instrMNK = getInstrShape();
+    int repM = ceil<unsigned>(shapePerCTA[0], instrMNK[0] * warpsPerCTAM);
+    int repK = ceil<unsigned>(shapePerCTA[1], instrMNK[2]);
+    // For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds
+    // kWidth elements for each quadrant. WGMMA is repeated repM * repK times.
+    return 4 * kWidth * repM * repK;
   }
   // A100
   if (isAmpere()) {
-    auto rep = getMMAv2RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(),
-                                     kWidth, opIdx);
+    auto rep = getMMAv2OrV3RepForOperand(
+        shapePerCTA, eltTy.getIntOrFloatBitWidth(), kWidth, opIdx);
     if (opIdx == 0)
       return 4 * rep[0] * rep[1] * rep[2];
     if (opIdx == 1)
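To make the new Hopper branch concrete, here is the same arithmetic as a standalone sketch. Every concrete number below is an assumed example, including the instrShape convention ({16, N, 16}, so that instrM * warpsPerCTAM spans one wgmma m64 tile):

// Per-thread element count for WGMMA operand A, mirroring the branch above.
unsigned hopperElemsPerThreadA(unsigned M, unsigned K, unsigned instrM,
                               unsigned instrK, unsigned warpsPerCTAM,
                               unsigned kWidth) {
  unsigned repM = (M + instrM * warpsPerCTAM - 1) / (instrM * warpsPerCTAM);
  unsigned repK = (K + instrK - 1) / instrK; // ceiling division
  return 4 * kWidth * repM * repK;           // 4 quads per 2x2 fragment
}

// Worked example (assumed values): M = 128, K = 64, f16 so kWidth = 2,
// instrM = 16, instrK = 16, warpsPerCTAM = 4 (one warp group):
//   repM = ceil(128 / 64) = 2, repK = ceil(64 / 16) = 4
//   -> 4 * 2 * 2 * 4 = 64 elements per thread,
// which matches 128 * 64 f16 values spread over 128 threads.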
8 changes: 4 additions & 4 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -34,13 +34,13 @@ LogicalResult UpcastMXFPOp::verify() {
                       "operands must have the same number of dimensions, at least 2");
   }
 
-  if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 ||
-        fpType == F8F6F4Type::E5M2)) {
+  if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
+        fpType == ScaleDotElemType::E5M2)) {
     return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
   }
 
   // Change to support fp8 types
-  const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 2 : 1;
+  const auto elems_packed = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
 
   if (xShape.back() != (32 / elems_packed) * scaleShape.back()) {
     return emitOpError("last dimension of first operand must be 16 times "
@@ -93,7 +93,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
     return emitOptionalError(loc, "expected a dotOperand encoding");
   }
 
-  if (typeEncoded == F8F6F4Type::E2M1) {
+  if (typeEncoded == ScaleDotElemType::E2M1) {
     auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
     auto newVEncoding = DotOperandEncodingAttr::get(
         ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),