Commit e592cab

Merge commit 'd9fd9c59a68dea63f1dcf8e2e20d9eda16589d68'
whitneywhtsang committed Jan 25, 2025
2 parents 47c150d + d9fd9c5 commit e592cab
Showing 25 changed files with 329 additions and 203 deletions.
6 changes: 6 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1017,6 +1017,12 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
triton::gpu::MemDescType srcTy,
Type elemLlvmTy,
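A note on the new overload above: it lets callers supply the register-side LinearLayout directly instead of having it derived from a RankedTensorType encoding. The sketch below is a hypothetical caller (the helper name, includes, and surrounding variables are assumptions, not part of this commit) pairing it with the chooseLdMatrixLayout signature introduced further down:

  // Hedged caller sketch: choose an ldmatrix-friendly register layout
  // explicitly and hand it to the new LinearLayout-based overload.
  static bool copyWithLdMatrixLayout(Attribute dotEnc,
                                     triton::gpu::MemDescType sharedTy,
                                     Type elemLlvmTy,
                                     const SharedMemoryObject &smemObj,
                                     Location loc, RewriterBase &rewriter,
                                     const TargetInfoBase &target) {
    auto regLayout = triton::gpu::chooseLdMatrixLayout(
        dotEnc, sharedTy.getShape(), /*needTrans=*/false,
        elemLlvmTy.getIntOrFloatBitWidth());
    return emitTransferBetweenRegistersAndShared(
        regLayout, sharedTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj,
        loc, rewriter, target, [&](VectorType vecTy, Value shmemAddr) {
          // Issue one vectorized shared-memory access per callback.
        });
  }
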
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -178,11 +178,11 @@ unsigned getNumCTAs(Attribute layout);
// len(shape) == rank.
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);

// Return the order that represents that the dot operand is in kMajor
// Return the order that represents that the dot operand is in kContig
// (contiguous in the inner dimension) or it's contiguous on the outer
// dimension.
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor);
bool kContig);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

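The kMajor → kContig rename above is purely a naming change; the flag still means "the K dimension is the contiguous (fastest-varying) one". As a worked example, assuming Triton's convention that order[0] is the most contiguous dimension and using rowMajor = bool(opIdx) != kContig from Dialect.cpp below:

  // opIdx == 0 operates on [m, k]; opIdx == 1 operates on [k, n] (rank == 2).
  getOrderForDotOperand(/*opIdx=*/0, /*rank=*/2, /*kContig=*/true);  // {1, 0}: k fastest
  getOrderForDotOperand(/*opIdx=*/1, /*rank=*/2, /*kContig=*/true);  // {0, 1}: k fastest
  getOrderForDotOperand(/*opIdx=*/0, /*rank=*/2, /*kContig=*/false); // {0, 1}: m fastest
  getOrderForDotOperand(/*opIdx=*/1, /*rank=*/2, /*kContig=*/false); // {1, 0}: n fastest
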
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -243,8 +243,8 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,

// The primary goal of this function is to efficiently store 2D tiles of a
// tensor into shared memory using the `ldmatrix` instruction.
LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
Attribute dotEnc, ArrayRef<int64_t> shape);
LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
bool needTrans, int32_t elemBitWidth);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
@@ -258,7 +258,7 @@ def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global",
}];
}

def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> {
def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
let summary = "wait until all the inputs are read.";
let arguments = (ins I32Attr:$pendings);
let description = [{
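The def rename TTNG_TMAStoreWait → TTNG_TMAStoreWaitOp keeps the async_tma_store_wait mnemonic, so only the generated C++ op class name changes (MLIR ODS strips the TTNG_ dialect prefix from the def name). A hedged sketch of an updated call site — the helper, include path, and builder overload are assumptions based on standard generated-op APIs, not code from this commit:

  #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

  // Previously the generated class would have been named TMAStoreWait.
  static void emitStoreWait(mlir::OpBuilder &b, mlir::Location loc) {
    b.create<mlir::triton::nvidia_gpu::TMAStoreWaitOp>(loc, /*pendings=*/0);
  }
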
30 changes: 9 additions & 21 deletions lib/Analysis/AxisInfo.cpp
@@ -1,5 +1,4 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

@@ -232,13 +231,13 @@ class MakeRangeOpAxisInfoVisitor final
}
};

template <typename OpTy>
class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
class ConstantOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<arith::ConstantOp> {
public:
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
using AxisInfoVisitorImpl::AxisInfoVisitorImpl;

AxisInfo
getAxisInfo(OpTy op,
getAxisInfo(arith::ConstantOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto intAttr = dyn_cast<IntegerAttr>(op.getValue());
auto boolAttr = dyn_cast<BoolAttr>(op.getValue());
@@ -323,8 +322,7 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value()) {
if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
std::is_same_v<OpTy, LLVM::AddOp>) {
if constexpr (std::is_same_v<OpTy, arith::AddIOp>) {
return {lhs.getConstantValue().value() +
rhs.getConstantValue().value()};
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
@@ -1013,15 +1011,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
// TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
// when scf.for supports integer induction variables
visitors.append<MakeRangeOpAxisInfoVisitor>();
visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
visitors.append<ConstantOpAxisInfoVisitor>();
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
AddSubOpAxisInfoVisitor<arith::AddIOp>,
AddSubOpAxisInfoVisitor<arith::SubIOp>,
AddSubOpAxisInfoVisitor<LLVM::AddOp>>();
AddSubOpAxisInfoVisitor<arith::SubIOp>>();
visitors.append<MulIOpAxisInfoVisitor>();
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
DivOpAxisInfoVisitor<arith::DivUIOp>>();
@@ -1138,17 +1132,11 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,

if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (auto fun = dyn_cast<FunctionOpInterface>(op))
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);
// llvm codegen check alignment to generate vector load/store
// would be nice if this wasn't the case
else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op))
if (auto fun = dyn_cast<FunctionOpInterface>(op)) {
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);
else if (isa<RegionBranchOpInterface>(op)) {
} else if (isa<RegionBranchOpInterface>(op)) {
// scf::ForOp, scf::IfOp, scf::WhileOp
// Control flow operations are initialized with "unknown" state:
// the maximum possible divisibility, contiguity, and constancy.
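The AxisInfo.cpp changes above drop the LLVM-dialect special cases (the deleted TODO existed only because scf.for once lacked integer induction variables): ConstantOpAxisInfoVisitor is now specific to arith::ConstantOp, and AddSubOpAxisInfoVisitor folds constants only for triton.addptr, arith.addi, and arith.subi. A minimal standalone model of the constant propagation the visitor keeps (types simplified; the real code operates on the AxisInfo lattice):

  #include <cstdint>
  #include <optional>

  // The result carries a known constant only when both operands do.
  std::optional<int64_t> foldAddSub(std::optional<int64_t> lhs,
                                    std::optional<int64_t> rhs, bool isAdd) {
    if (lhs && rhs)
      return isAdd ? *lhs + *rhs : *lhs - *rhs;
    return std::nullopt;
  }
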
26 changes: 18 additions & 8 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -300,10 +300,9 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
} // namespace

bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems,
const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
MLIRContext *ctx = rewriter.getContext();

@@ -313,8 +312,6 @@ bool emitTransferBetweenRegistersAndShared(
StringAttr kWarp = str_attr("warp");

auto shape = sharedTy.getShape();
LinearLayout regLayout =
triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
LinearLayout sharedLayout = triton::gpu::toLinearLayout(
shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout);
@@ -360,14 +357,13 @@ bool emitTransferBetweenRegistersAndShared(
// Thus we use `pseudoinvert` instead of `invert` here for simplicity.
auto allocShape = sharedTy.getAllocShape();
LinearLayout invertAllocSharedLayout =
triton::gpu::toLinearLayout(allocShape.take_back(registerTy.getRank()),
triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()),
sharedTy.getEncoding(),
elemLlvmTy.getIntOrFloatBitWidth())
.pseudoinvert();

int numElems = regToSharedLayout.getInDimSize(kRegister);
auto vecTy = vec_ty(elemLlvmTy, vecElems);
Value zero = i32_val(0);
SmallVector<Value> ret;
for (int i = 0; i < numElems / vecElems; i++) {
auto regId = i32_val(i * vecElems);
@@ -379,6 +375,20 @@ bool emitTransferBetweenRegistersAndShared(
return true;
}

bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems,
const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
auto regLayout = triton::gpu::toLinearLayout(
registerTy.getShape(), registerTy.getEncoding(),
elemLlvmTy.getIntOrFloatBitWidth());
return emitTransferBetweenRegistersAndShared(
regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
target, perVectorCallback);
}

SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
triton::gpu::MemDescType srcTy,
Type elemLlvmTy,
16 changes: 8 additions & 8 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -242,15 +242,15 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
}

SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor) {
// kMajor: if true, the matrix is fastest-running on k,
bool kContig) {
// kContig: if true, the matrix is fastest-running on k,
// otherwise it is on m (resp. n)
// opIdx=0: [batch, m, k] if rank == 3 else [m, k]
// opIdx=1: [batch, k, n] if rank == 3 else [k, n]
// batch (if rank == 3) is always the slowest running dimension
assert(rank == 2 || rank == 3);
assert(opIdx == 0 || opIdx == 1);
auto rowMajor = bool(opIdx) != kMajor;
auto rowMajor = bool(opIdx) != kContig;
return getMatrixOrder(rank, rowMajor);
}

@@ -283,7 +283,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto rank = dotLayout.getWarpsPerCTA().size();
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kContig*/ true);
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
Expand Down Expand Up @@ -1002,7 +1002,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
}
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ true);
/*kContig*/ true);
}

LogicalResult DotOperandEncodingAttr::verify(
@@ -2004,7 +2004,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
SmallVector<unsigned>
AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
}

SmallVector<unsigned>
@@ -2072,7 +2072,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
SmallVector<unsigned>
AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
}

SmallVector<unsigned>
@@ -2264,7 +2264,7 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
SmallVector<unsigned>
NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
}

SmallVector<unsigned>
109 changes: 53 additions & 56 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -1097,80 +1097,80 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

LinearLayout chooseLdMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
SharedEncodingAttr shared,
DotOperandEncodingAttr dot,
ArrayRef<int64_t> shape) {
LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
ArrayRef<int64_t> shape, bool needTrans,
int32_t elemBitWidth) {
auto ctx = dot.getContext();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
auto rank = shape.size();
auto opIdx = dot.getOpIdx();
int kDim = opIdx == 0 ? rank - 1 : rank - 2;
int kDim = (opIdx == 0) ? rank - 1 : rank - 2;

StringAttr kReg = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");
StringAttr kBlock = S("block");
StringAttr kInner = opIdx == 0 ? S("dim1") : S("dim0");
StringAttr kOuter = opIdx == 0 ? S("dim0") : S("dim1");

std::vector<std::vector<int>> basesReg = {{0, 1}, {0, 2}, {0, 4}};
std::vector<std::vector<int>> basesLane;
auto numRowsPerTile = 16;
auto numColsPerTile = 16;
int vecSize = shared.getVec();
int perPhase = shared.getPerPhase();
int maxPhase = shared.getMaxPhase();
auto warpsPerCTA = mma.getWarpsPerCTA();
// Construct a 16x16 tile consisting of 4 sub-tiles to use ldmatrix
StringAttr kInner = opIdx == 0 ? (needTrans ? S("dim0") : S("dim1"))
: (needTrans ? S("dim1") : S("dim0"));
StringAttr kOuter = opIdx == 0 ? (needTrans ? S("dim1") : S("dim0"))
: (needTrans ? S("dim0") : S("dim1"));

std::vector<std::vector<int>> basesReg;
for (int logReg = 0; logReg < llvm::Log2_32(8 * 16 / elemBitWidth);
logReg++) {
auto reg = 1 << logReg;
basesReg.push_back({0, reg});
}
std::vector<std::vector<int>> basesLane = {{1, 0}, {2, 0}, {4, 0}};
int numTileCols;
// Construct a tile consisting of 4 8x8x16bits sub-tiles to use ldmatrix
// efficiently. opIdx=0 and opIdx=1 are handled differently.
if (opIdx == 0) {
// The matrix elements of thread 0 are distributed in the following pattern:
// The matrix elements of thread 0 are distributed in the following pattern
// (fp16):
//
// col0 col8
// row0 reg[0-1] reg[4-5]
// row8 reg[2-3] reg[6-7]
for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) {
int row = 1 << logRow;
basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
}
basesLane.push_back({0, numColsPerTile / 2});
// Expand the `register` dimension so the size of columns matches `K`.
for (int logCol = 0; logCol < llvm::Log2_32(shape[kDim] / numColsPerTile);
logCol++) {
int col = 1 << logCol;
basesReg.push_back({0, numColsPerTile * col});
if (needTrans) {
assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
"supported in the transposed mode");
basesLane.push_back({0, 8});
basesLane.push_back({8, 0});
} else {
basesLane.push_back({8, 0});
basesLane.push_back({0, 8 * 16 / elemBitWidth});
}
numTileCols = 16 * 16 / elemBitWidth;
} else {
// The matrix elements of thread 0 are distributed in the following pattern:
// The matrix elements of thread 0 are distributed in the following pattern
// (fp16):
//
// col0 col8 col16 col24
// row0 reg[0-1] reg[2-3] reg[4-5] reg[6-7]
// 8x8
for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile / 2); logRow++) {
int row = 1 << logRow;
basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
}
// 8x16
basesLane.push_back({0, numColsPerTile / 2});
// 8x32
basesLane.push_back({0, numColsPerTile});
// Expand the `register` dimension so the size of columns matches `K`.
for (int logCol = 0;
logCol < llvm::Log2_32(shape[kDim] / (numColsPerTile * 2)); logCol++) {
int col = 1 << logCol;
basesReg.push_back({0, (numColsPerTile * 2) * col});
if (needTrans) {
assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
"supported in the transposed mode");
basesLane.push_back({8, 0});
basesLane.push_back({16, 0});
} else {
basesLane.push_back({0, 8 * 16 / elemBitWidth});
basesLane.push_back({0, 16 * 16 / elemBitWidth});
}
numTileCols = 32 * 16 / elemBitWidth;
}
auto layout = LinearLayout(
{{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, {kOuter, kInner});
// Expand the `register` dimension so the size of columns matches `K`.
auto layout =
LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}},
{kOuter, kInner}) *
LinearLayout::identity1D(shape[kDim] / numTileCols, kReg,
S("dim" + std::to_string(kDim)));
// Expand the `warp` dimension according to warpsPerCTA.
auto warpsPerCTA = mma.getWarpsPerCTA();
layout *= broadcastedDotOperandLayout(ctx, warpsPerCTA, mma.getWarpOrder(),
kDim, kWarp)
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
auto ret = combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
return ret.transposeOuts({kInner, kOuter})
.reshapeOuts(
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
return combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
}

} // anonymous namespace
@@ -1184,13 +1184,10 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
return chooseStMatrixLayoutLeadingOffset(ctx, tensorTy, swizzleByteSize);
}

LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
Attribute dotEnc, ArrayRef<int64_t> shape) {
auto shared = cast<SharedEncodingAttr>(sharedEnc);
auto dot = cast<DotOperandEncodingAttr>(dotEnc);
assert(!shared.getHasLeadingOffset() &&
"Ldmatrix does not support leading offset yet");
return chooseLdMatrixLayoutNoLeadingOffset(ctx, shared, dot, shape);
LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
bool needTrans, int32_t elemBitWidth) {
auto dot = cast<DotOperandEncodingAttr>(enc);
return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth);
}

} // namespace mlir::triton::gpu
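
The rewritten chooseDotLdMatrixLayout parameterizes the ldmatrix tile on elemBitWidth instead of hard-coding 16x16 fp16 tiles: the register bases cover 8 * 16 / elemBitWidth contiguous elements, the tile is numTileCols = 16 * 16 / elemBitWidth columns wide for opIdx == 0 (32 * 16 / elemBitWidth for opIdx == 1), and the trailing identity1D factor repeats that tile along K. A small self-contained check of the arithmetic (it only reproduces the formulas, not the layout construction):

  #include <cstdio>

  int main() {
    for (int elemBitWidth : {16, 8}) {
      int regElems = 8 * 16 / elemBitWidth;   // elements spanned by basesReg
      int tileColsA = 16 * 16 / elemBitWidth; // numTileCols for opIdx == 0
      int tileColsB = 32 * 16 / elemBitWidth; // numTileCols for opIdx == 1
      std::printf("bits=%d regElems=%d cols(opIdx=0)=%d cols(opIdx=1)=%d\n",
                  elemBitWidth, regElems, tileColsA, tileColsB);
    }
    return 0;
  }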