Merge commit 'f436c9ec497e3c39a94340bb0796d65aa4782bf0'
whitneywhtsang committed Jan 9, 2025
2 parents ab58512 + f436c9e commit a15b458
Showing 21 changed files with 606 additions and 495 deletions.
18 changes: 12 additions & 6 deletions include/triton/Dialect/Triton/IR/Dialect.h
@@ -55,13 +55,19 @@ class DialectInferLayoutInterface

// Tries to compute the encoding for the result of a reshape operation that
// makes the reshape a "nop", i.e. the same GPU threads contain the same
// elements as before the reshape. Note that this is not always possible (in
// which case you'd need to choose a different layout for the input to the
// reshape).
// elements as before the reshape using legacy layouts. This is not always
// possible (in which case we fall back to using LinearLayouts).
// In the future we'll always use LinearLayouts.
virtual LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;

// Check if two layouts are structurally the same, even if their names are
// different.
virtual LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const = 0;

virtual LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
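
The renamed interface splits reshape handling into two steps: infer an encoding for the result (with linear layouts this should always succeed), then check that a user-supplied encoding is structurally equal to the inferred one. Below is a minimal caller-side sketch of that pattern; it is not code from the commit, it mirrors the ReshapeOp::verify change later in this diff and assumes srcEnc, dstEnc, srcTy, dstTy, and loc are in scope.

// Sketch only: infer the dst encoding for a reshape, then accept any dst
// encoding that is structurally equivalent, even if the attribute differs
// (e.g. an equivalent #linear instead of a #blocked encoding).
auto *inferLayout = cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
Attribute inferredDstEnc;
if (failed(inferLayout->inferReshapeOpEncoding(srcTy.getShape(), srcEnc,
                                               dstTy.getShape(),
                                               inferredDstEnc, loc)))
  return failure();
return inferLayout->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc,
                                          dstEnc, loc);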
58 changes: 51 additions & 7 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -11,14 +11,60 @@
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"

// LinearLayoutCache Utils
using CacheKey =
std::tuple<std::vector<int64_t>, mlir::Attribute, std::optional<int32_t>>;

namespace llvm {
template <typename T> size_t hash_value(const std::vector<T> &vec) {
return hash_combine_range(vec.begin(), vec.end());
}
} // namespace llvm

namespace std {
template <> struct hash<CacheKey> {
size_t operator()(const CacheKey &key) const noexcept {
using llvm::hash_value;
size_t seed = 0;
std::apply(
[&seed](const auto &...elems) {
((seed = llvm::hash_combine(seed, hash_value(elems))), ...);
},
key);
return seed;
}
};
} // namespace std

namespace mlir::triton::gpu {

class LinearLayoutCache {
public:
std::optional<LinearLayout> get(const CacheKey &key) {
std::shared_lock lock(mutex);
auto it = cache.find(key);
if (it != cache.end()) {
return it->second;
}
return std::nullopt;
}

void set(CacheKey key, LinearLayout result) {
std::scoped_lock lock(mutex);
cache.emplace(std::move(key), std::move(result));
}

private:
std::unordered_map<CacheKey, LinearLayout> cache;
llvm::sys::SmartRWMutex<true> mutex;
};
} // namespace mlir::triton::gpu

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"

namespace mlir {
namespace triton {
namespace gpu {

namespace mlir::triton::gpu {
struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
StringRef getName() final { return "<SharedMemory>"; }
};
@@ -240,8 +286,6 @@ llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
llvm::SmallVector<unsigned>
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);

} // namespace gpu
} // namespace triton
} // namespace mlir
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
7 changes: 7 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
@@ -43,6 +43,13 @@ def TritonGPU_Dialect : Dialect {
}
return cast<IntegerAttr>(threadsPerWarp).getInt();
}

std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth);

private:
LinearLayoutCache llCache;
}];

let useDefaultTypePrinterParser = 1;
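
The body of the new toLinearLayout member is defined outside the hunks shown here. The sketch below is an assumption about how it could consult llCache, with toLinearLayoutImpl standing in for whatever uncached, layout-specific conversion the dialect actually calls.

// Hypothetical sketch (not the code from this commit): memoize the
// linear-layout conversion per (shape, layout, elemBitWidth) key.
std::optional<LinearLayout>
TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                                 std::optional<int32_t> elemBitWidth) {
  CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout,
               elemBitWidth};
  if (std::optional<LinearLayout> cached = llCache.get(key))
    return cached;                          // cache hit: reuse the result
  std::optional<LinearLayout> result =
      toLinearLayoutImpl(shape, layout, elemBitWidth); // hypothetical helper
  if (result)
    llCache.set(std::move(key), *result);   // only cache successful conversions
  return result;
}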
7 changes: 3 additions & 4 deletions lib/Analysis/Utility.cpp
@@ -225,15 +225,14 @@ bool ReduceOpHelper::isSupportedLayout() {
}

auto srcLayout = getSrcLayout();
if (isa<BlockedEncodingAttr>(srcLayout)) {
if (isa<BlockedEncodingAttr, LinearEncodingAttr, SliceEncodingAttr>(
srcLayout)) {
return true;
}

if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout)) {
return mmaLayout.supportReduction();
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(srcLayout)) {
return true;
}
return false;
}

30 changes: 14 additions & 16 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -8,6 +8,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Tools/LinearLayout.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
@@ -701,24 +702,21 @@ LogicalResult ReshapeOp::verify() {
"encodings, or (b) neither does.");
}

if (srcEnc && !getAllowReorder()) {
Attribute inferredDstEnc;
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc,
dstTy.getShape(), inferredDstEnc,
getLoc())
.failed()) {
return emitError("This reshape is impossible without reordering, but "
"reordering is not allowed. Try choosing a different "
"encoding for the input tensor (or allow reordering).");
}
if (inferredDstEnc != dstEnc) {
return emitError("Expected result encoding ")
<< inferredDstEnc << " but was " << dstEnc;
}
if (!srcEnc || getAllowReorder()) {
return success();
}

return success();
// Check that we can infer the dst encoding from the src encoding
// and that the inferred dst encoding is the same as the given dst encoding
Attribute inferredDstEnc;
auto result =
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
inferredDstEnc, getLoc());
assert(succeeded(result));
return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
getLoc());
}

//-- FpToFpOp --
125 changes: 83 additions & 42 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1470,7 +1470,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,

SmallVector<unsigned> ret(rank, 1);
auto nonZero = [](auto val) { return val != 0; };
int nonZeroIdx = -1;
int nonZeroIdx = 0;
for (const auto &basis : bases) {
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
// Bases can have one or zero non-zero elements
@@ -1482,7 +1482,6 @@
} else if (!skipBroadcast) {
// If we've seen a non-zero basis, we double the size of the previous dim
// This is just needed to count the CTAsPerCGA
assert(nonZeroIdx != -1);
ret[nonZeroIdx] *= 2;
}
}
@@ -1627,12 +1626,14 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

SmallVector<unsigned>
LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
// We can relax this assert by calling toLinearLayout rather than
// getLinearLayout
SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
auto ll = getLinearLayout();
return basesPerDim(ll, StringAttr::get(getContext(), "register"));
// When broadcasting the layout the shape changes; otherwise the shape is
// the same as the shape of the tensor.
// We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
// the invariant that the shape of the LL is that of the tensor.
// We choose the former for BC.
auto ll = *toLinearLayout(shape);
return basesPerDim(ll, StringAttr::get(getContext(), "register"),
/*skipBroadcast=*/false);
}

// Start of Selection
@@ -2705,8 +2706,8 @@ struct TritonGPUInferLayoutInterface
// contains elements [a,b,c,d] before the reshape, it contains those same
// elements after the reshape, they're just "renamed".
//
// A dst encoding that satisfies this property does not exist for all inputs.
// Here are some positive and negative examples.
// Using legacy layouts, a dst encoding that satisfies this property may not
// exist. Here are some positive and negative examples.
//
// - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
// dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2720,17 +2721,19 @@
// - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
// contain the same elements as before.
//
// With linear layouts, we can always find a dst encoding that satisfies
// this property. See inferReshapeOpEncoding.
//
// Users of this function require that it is symmetrical: if
// (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
// srcEnc.
LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
Attribute srcEnc,
ArrayRef<int64_t> dstShape,
Attribute &dstEnc) const {
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
if (!src) {
return emitOptionalError(
loc, "Non-reordering reshape only supports BlockedEncoding");
return failure();
}

// Nop reshape; we can always infer an encoding.
@@ -2763,9 +2766,7 @@
// to handle CTASplitNum.
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
!all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
return emitOptionalError(
loc, "Non-reordering reshape does not currently support multi-CTA "
"layouts other than the default layout.");
return failure();
}

// Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2775,12 +2776,7 @@
for (int dim = 0; dim < srcShape.size(); dim++) {
if (srcShape[dim] >= subblock[dim] &&
srcShape[dim] % subblock[dim] != 0) {
return emitOptionalError(loc,
"Can't do a non-reordering reshape because "
"the size of dimension ",
dim, " (", srcShape[dim], ")",
" is not divisible by ", name, "[", dim, "]",
" = ", subblock[dim]);
return failure();
}
}
return success();
@@ -2805,11 +2801,7 @@
// physical order, with `a` being the most major.
for (const auto &[srcDims, dstDims] : decomp) {
if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
return emitOptionalError(loc,
"Cannot do a non-reordering reshape given "
"this src encoding order. Dimensions [",
join(srcDims),
"] must be physically consecutive.");
return failure();
}
}

@@ -2856,11 +2848,7 @@
// Check that more-minor dims all have 1 in shapeRemaining.
for (int j = i + 1; j < srcDims.size(); j++) {
if (shapeRemaining[j] != 1) {
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Must use "
"up sizePerThread / threadsPerWarp / warpsPerCTA for "
"more-minor dimensions before more major-dims can use them.");
return failure();
}
}

@@ -2875,13 +2863,7 @@
// only if we're the most-major dimension of the chunk and in all
// future chunks, only this most-major dim has a non-1 size.
if (shapeRemaining[i] == 0 && i != 0) {
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Block "
"size in dimension ",
dim,
" is larger than the shape that dimension, but this is only "
"allowed for the most-major dimension of a reshape chunk");
return failure();
}
}
return success();
@@ -2971,6 +2953,65 @@
return success();
}

LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const override {
if (expected == got) {
return success();
}
// Check whether the encodings are structurally the same.
auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
auto gotLL = triton::gpu::toLinearLayout(shape, got);
if (expectedLL != gotLL) {
return emitError(loc, "Expected result encoding ")
<< expected << " but was " << got;
}
return success();
}

LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
auto result =
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
if (succeeded(result)) {
return result;
}

// If the legacy encoding failed, use LinearLayouts.
// Once LinearLayouts are more widely used, we can remove
// inferReshapeOpLegacyEncoding and simply use LLs.
auto *ctx = getContext();
auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
if (!src) {
return emitOptionalError(loc,
"src encoding does not support linear layout");
}

if (product(srcShape) != product(dstShape)) {
return emitOptionalError(loc, "numel of dst shape does not match "
"numel of src shape");
}

auto newRank = dstShape.size();
SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
for (auto [dim, size] :
llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
newOutDims.emplace_back(dim, size);
}
auto srcOutDims = llvm::to_vector(src->getOutDimNames());
// reshapeOp assumes minor-to-major, so we need to transpose the out dims
// before the reshape
std::reverse(srcOutDims.begin(), srcOutDims.end());
std::reverse(newOutDims.begin(), newOutDims.end());
auto dst = src->transposeOuts(srcOutDims)
.reshapeOuts(newOutDims)
.transposeOuts(standardOutDimNames(ctx, newRank));
dstEnc = LinearEncodingAttr::get(ctx, dst);
return success();
}

LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
std::optional<Location> loc) const override {
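As a standalone illustration (not part of the commit) of the "nop reshape" examples in the comment above inferReshapeOpLegacyEncoding: flattening a 4x4 tensor to 16 elements keeps dim 1 fastest-changing, so a physical order of [1,0] already matches the reshaped order, while order [0,1] would force elements to move between positions.

// Standalone demo: count how many of the 16 elements would land in a
// different position if the physical order were [0,1] (dim 0 fastest).
#include <cstdio>

int main() {
  const int rows = 4, cols = 4;
  int mismatches = 0;
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j) {
      int dstPos = i * cols + j;   // where (i,j) lands after the reshape to 16
      int order01 = j * rows + i;  // physical position if dim 0 is fastest
      // With order [1,0] the physical position is i * cols + j, i.e. dstPos,
      // so that layout is already a "nop"; order [0,1] is not.
      if (order01 != dstPos)
        ++mismatches;
    }
  std::printf("order [0,1] would move %d of 16 elements\n", mismatches);
  return 0;
}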
