diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 56a1aa7032..39d006cc65 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -55,13 +55,19 @@ class DialectInferLayoutInterface // Tries to compute the encoding for the result of a reshape operation that // makes the reshape a "nop", i.e. the same GPU threads contain the same - // elements as before the reshape. Note that this is not always possible (in - // which case you'd need to choose a different layout for the input to the - // reshape). + // elements as before the reshape using legacy layouts. This is not always + // possible (in which case we fallback to using LinearLayouts) + // In the future we'll always use LinearLayouts virtual LogicalResult - inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const = 0; + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; + + // Check if two layouts are structurally the same, even if their names are + // different + virtual LogicalResult verifyLayoutsAreEqual(ArrayRef shape, + Attribute expected, Attribute got, + Location loc) const = 0; virtual LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index b81ecf103a..c4c9ebff6a 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -11,14 +11,60 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Types.h" +// LinearLayoutCache Utils +using CacheKey = + std::tuple, mlir::Attribute, std::optional>; + +namespace llvm { +template size_t hash_value(const std::vector &vec) { + return hash_combine_range(vec.begin(), vec.end()); +} +} // namespace llvm + +namespace std { +template <> struct hash { + size_t operator()(const CacheKey &key) const noexcept { + using llvm::hash_value; + size_t seed = 0; + std::apply( + [&seed](const auto &...elems) { + ((seed = llvm::hash_combine(seed, hash_value(elems))), ...); + }, + key); + return seed; + } +}; +} // namespace std + +namespace mlir::triton::gpu { + +class LinearLayoutCache { +public: + std::optional get(const CacheKey &key) { + std::shared_lock lock(mutex); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + return std::nullopt; + } + + void set(CacheKey key, LinearLayout result) { + std::scoped_lock lock(mutex); + cache.emplace(std::move(key), std::move(result)); + } + +private: + std::unordered_map cache; + llvm::sys::SmartRWMutex mutex; +}; +} // namespace mlir::triton::gpu + #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" #include "triton/Dialect/TritonGPU/IR/Ops.h.inc" -namespace mlir { -namespace triton { -namespace gpu { - +namespace mlir::triton::gpu { struct SharedMemory : public SideEffects::Resource::Base { StringRef getName() final { return ""; } }; @@ -240,8 +286,6 @@ llvm::SmallVector expandMatrixShapeWithBatch(llvm::ArrayRef s); llvm::SmallVector expandMatrixOrderWithBatch(llvm::ArrayRef o); -} // namespace gpu -} // namespace triton -} // namespace mlir +} // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td 
b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index be8487be1e..95b6718b53 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -43,6 +43,13 @@ def TritonGPU_Dialect : Dialect { } return cast(threadsPerWarp).getInt(); } + + std::optional + toLinearLayout(ArrayRef shape, Attribute layout, + std::optional elemBitWidth); + + private: + LinearLayoutCache llCache; }]; let useDefaultTypePrinterParser = 1; diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 2b10095fe7..8f76285249 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -225,15 +225,14 @@ bool ReduceOpHelper::isSupportedLayout() { } auto srcLayout = getSrcLayout(); - if (isa(srcLayout)) { + if (isa( + srcLayout)) { return true; } + if (auto mmaLayout = dyn_cast(srcLayout)) { return mmaLayout.supportReduction(); } - if (auto sliceLayout = dyn_cast(srcLayout)) { - return true; - } return false; } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index ab32bd992b..12a237924a 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -8,6 +8,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Tools/LinearLayout.h" #include "llvm/Support/ErrorHandling.h" namespace mlir { @@ -701,24 +702,21 @@ LogicalResult ReshapeOp::verify() { "encodings, or (b) neither does."); } - if (srcEnc && !getAllowReorder()) { - Attribute inferredDstEnc; - if (cast(&srcEnc.getDialect()) - ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, - dstTy.getShape(), inferredDstEnc, - getLoc()) - .failed()) { - return emitError("This reshape is impossible without reordering, but " - "reordering is not allowed. 
Try choosing a different " - "encoding for the input tensor (or allow reordering)."); - } - if (inferredDstEnc != dstEnc) { - return emitError("Expected result encoding ") - << inferredDstEnc << " but was " << dstEnc; - } + if (!srcEnc || getAllowReorder()) { + return success(); } - return success(); + // Check that we can infer the dst encoding from the src encoding + // and that the inferred dst encoding is the same as the given dst encoding + Attribute inferredDstEnc; + auto result = + cast(&srcEnc.getDialect()) + ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(), + inferredDstEnc, getLoc()); + assert(succeeded(result)); + return cast(&srcEnc.getDialect()) + ->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc, + getLoc()); } //-- FpToFpOp -- diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index ad54ff0c93..143a36ba05 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1470,7 +1470,7 @@ SmallVector basesPerDim(const LinearLayout::BasesT &namedBases, SmallVector ret(rank, 1); auto nonZero = [](auto val) { return val != 0; }; - int nonZeroIdx = -1; + int nonZeroIdx = 0; for (const auto &basis : bases) { auto it = std::find_if(basis.begin(), basis.end(), nonZero); // Bases can have one or zero non-zero elements @@ -1482,7 +1482,6 @@ SmallVector basesPerDim(const LinearLayout::BasesT &namedBases, } else if (!skipBroadcast) { // If we've seen a non-zero basis, we double the size of the previous dim // This is just needed to count the CTAsPerCGA - assert(nonZeroIdx != -1); ret[nonZeroIdx] *= 2; } } @@ -1627,12 +1626,14 @@ LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { SmallVector LinearEncodingAttr::getElemsPerThread(ArrayRef shape, Type) const { - // We can relax this assert by calling toLinearLayout rather than - // getLinearLayout - SmallVector shapeVec(shape.begin(), shape.end()); - assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes())); - auto ll = getLinearLayout(); - return basesPerDim(ll, StringAttr::get(getContext(), "register")); + // When broadcasting the layout the shape changes, otherwise the shape is + // the same as the shape of the tensor + // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep + // the invariant that the shape of the LL is that of the tensor + // We choose the former for BC + auto ll = *toLinearLayout(shape); + return basesPerDim(ll, StringAttr::get(getContext(), "register"), + /*skipBroadcast=*/false); } // Start of Selection @@ -2705,8 +2706,8 @@ struct TritonGPUInferLayoutInterface // contains elements [a,b,c,d] before the reshape, it contains those same // elements after the reshape, they're just "renamed". // - // A dst encoding that satisfies this property does not exist for all inputs. - // Here are some positive and negative examples. + // Using legacy layouts, a dst encoding that satisfies this property may not + // exist. Here are some positive and negative examples. // // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so // dim 1 is the fastest-changing in the dst, but the src has the opposite @@ -2720,17 +2721,19 @@ struct TritonGPUInferLayoutInterface // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will // contain the same elements as before. // + // With linear layouts, we can always find a dst encoding that satisfies + // this property. See inferReshapeOpEncoding. 
+ // // Users of this function require that it is symmetrical: if // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => // srcEnc. - LogicalResult - inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const override { + LogicalResult inferReshapeOpLegacyEncoding(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape, + Attribute &dstEnc) const { auto src = mlir::dyn_cast(srcEnc); if (!src) { - return emitOptionalError( - loc, "Non-reordering reshape only supports BlockedEncoding"); + return failure(); } // Nop reshape; we can always infer an encoding. @@ -2763,9 +2766,7 @@ struct TritonGPUInferLayoutInterface // to handle CTASplitNum. if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { - return emitOptionalError( - loc, "Non-reordering reshape does not currently support multi-CTA " - "layouts other than the default layout."); + return failure(); } // Cowardly refuse to handle encodings where shape[dim] is not divisible by @@ -2775,12 +2776,7 @@ struct TritonGPUInferLayoutInterface for (int dim = 0; dim < srcShape.size(); dim++) { if (srcShape[dim] >= subblock[dim] && srcShape[dim] % subblock[dim] != 0) { - return emitOptionalError(loc, - "Can't do a non-reordering reshape because " - "the size of dimension ", - dim, " (", srcShape[dim], ")", - " is not divisible by ", name, "[", dim, "]", - " = ", subblock[dim]); + return failure(); } } return success(); @@ -2805,11 +2801,7 @@ struct TritonGPUInferLayoutInterface // physical order, with `a` being the most major. for (const auto &[srcDims, dstDims] : decomp) { if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { - return emitOptionalError(loc, - "Cannot do a non-reordering reshape given " - "this src encoding order. Dimensions [", - join(srcDims), - "] must be physically consecutive."); + return failure(); } } @@ -2856,11 +2848,7 @@ struct TritonGPUInferLayoutInterface // Check that more-minor dims all have 1 in shapeRemaining. for (int j = i + 1; j < srcDims.size(); j++) { if (shapeRemaining[j] != 1) { - return emitOptionalError( - loc, - "Invalid src encoding for non-reordering reshape. Must use " - "up sizePerThread / threadsPerWarp / warpsPerCTA for " - "more-minor dimensions before more major-dims can use them."); + return failure(); } } @@ -2875,13 +2863,7 @@ struct TritonGPUInferLayoutInterface // only if we're the most-major dimension of the chunk and in all // future chunks, only this most-major dim has a non-1 size. if (shapeRemaining[i] == 0 && i != 0) { - return emitOptionalError( - loc, - "Invalid src encoding for non-reordering reshape. Block " - "size in dimension ", - dim, - " is larger than the shape that dimension, but this is only " - "allowed for the most-major dimension of a reshape chunk"); + return failure(); } } return success(); @@ -2971,6 +2953,65 @@ struct TritonGPUInferLayoutInterface return success(); } + LogicalResult verifyLayoutsAreEqual(ArrayRef shape, + Attribute expected, Attribute got, + Location loc) const override { + if (expected == got) { + return success(); + } + // Check whether the encodings are structurally the same. 
+ auto expectedLL = triton::gpu::toLinearLayout(shape, expected); + auto gotLL = triton::gpu::toLinearLayout(shape, got); + if (expectedLL != gotLL) { + return emitError(loc, "Expected result encoding ") + << expected << " but was " << got; + } + return success(); + } + + LogicalResult + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + auto result = + inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc); + if (succeeded(result)) { + return result; + } + + // If the legacy encoding failed use LinearLayouts. + // Once LinearLayouts are more widely used, we can remove + // inferReshapeOpLegacyEncoding and simply use LLs. + auto *ctx = getContext(); + auto src = triton::gpu::toLinearLayout(srcShape, srcEnc); + if (!src) { + return emitOptionalError(loc, + "src encoding does not support linear layout"); + } + + if (product(srcShape) != product(dstShape)) { + return emitOptionalError(loc, "numel of dst shape does not match " + "numel of src shape"); + } + + auto newRank = dstShape.size(); + SmallVector> newOutDims; + for (auto [dim, size] : + llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) { + newOutDims.emplace_back(dim, size); + } + auto srcOutDims = llvm::to_vector(src->getOutDimNames()); + // reshapeOp assumes minor-to-major, so we need to transpose the out dims + // before the reshape + std::reverse(srcOutDims.begin(), srcOutDims.end()); + std::reverse(newOutDims.begin(), newOutDims.end()); + auto dst = src->transposeOuts(srcOutDims) + .reshapeOuts(newOutDims) + .transposeOuts(standardOutDimNames(ctx, newRank)); + dstEnc = LinearEncodingAttr::get(ctx, dst); + return success(); + } + LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, std::optional loc) const override { diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 21aaca1c08..4f5d438fc7 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -879,22 +879,39 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { } std::optional -toLinearLayout(ArrayRef shape, Attribute layout, - std::optional elemBitWidth /*= std::nullopt*/) { - // Layouts are distributed or shared +TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, + std::optional elemBitWidth) { + CacheKey key{std::vector(shape.begin(), shape.end()), layout, + elemBitWidth}; + auto result = llCache.get(key); + if (result.has_value()) { + return result; + } + + // Layouts are distributed or shared in triton core if (auto distributed = dyn_cast(layout)) { - return distributed.toLinearLayout(shape); + result = distributed.toLinearLayout(shape); } else if (auto shared = dyn_cast(layout)) { if (shared.getHasLeadingOffset()) { assert(elemBitWidth.has_value()); - return sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); + result = sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); } else { - return sharedToLinearLayoutNoLeadingOffset(shape, shared); + result = sharedToLinearLayoutNoLeadingOffset(shape, shared); } } - // Third party layouts - return std::nullopt; + if (result.has_value()) { + llCache.set(std::move(key), *result); + } + return result; +} + +std::optional +toLinearLayout(ArrayRef shape, Attribute layout, + std::optional elemBitWidth /*= std::nullopt*/) { + auto *ctx = layout.getContext(); + return ctx->getLoadedDialect()->toLinearLayout( + shape, layout, 
elemBitWidth); } LinearLayout getLayoutWithinBlock(const LinearLayout &layout) { diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index f0b3217382..dd4ed8441e 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -43,6 +43,16 @@ struct CanonicalizeConvertFromReshape auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); + // If the layouts are structurally the same, the convert is trivial + auto srcType = convert.getSrc().getType(); + auto dstType = convert.getType(); + auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding()); + auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding()); + if (srcLL && dstLL && *srcLL == *dstLL) { + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getAllowReorder()); + return mlir::success(); + } if (isExpensiveView(convert.getSrc().getType(), op.getType())) return failure(); if (!op.getAllowReorder() || op.getEfficientLayout()) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 2b9e12c3ac..1f93d894b5 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -1025,9 +1025,7 @@ void LayoutRematerialization::backwardRematerialization( // we don't handle conversions to DotOperandEncodingAttr // this is a heuristic to accommodate fused attention RankedTensorType targetType = convertOp.getType(); - // We stop the rematerialization of linear layouts as we have to be a bit more - // careful with the heuristics for both correctness and perf - if (isa(targetType.getEncoding())) + if (isa(targetType.getEncoding())) return; Value oldV = convertOp.getSrc(); LDBG("check backward remat with source " << oldV << " encoding " @@ -1069,11 +1067,8 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention - // We stop the rematerialization of linear layouts as we have to be a bit more - // careful with the heuristics for both correctness and perf RankedTensorType targetType = convertOp.getType(); - if (mlir::isa( - targetType.getEncoding())) + if (isa(targetType.getEncoding())) return; auto isExtOrBroadcastOp = [](Operation *op) { diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 46dfce695c..27cb71638f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -407,14 +407,13 @@ static Attribute inferReshapeOpDstEncoding(ArrayRef srcShape, return {}; Attribute dstEnc; - if (succeeded( - srcEnc.getDialect() - .getRegisteredInterface() - ->inferReshapeOpNoReorderEncoding( - srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) { - return dstEnc; - } - return {}; + auto result = + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc, + /*loc=*/std::nullopt); + assert(succeeded(result)); + return dstEnc; } static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) { diff --git a/python/triton/testing.py b/python/triton/testing.py index 1dd079ab98..a2690cde62 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,5 +1,7 @@ import functools +import math import os +import statistics import subprocess import sys from 
contextlib import contextmanager @@ -64,16 +66,42 @@ def nvsmi(attrs): return ret +# pure Python implementation of np.quantile/torch.quantile +# to avoid unnecessary runtime dependency on numpy/torch + + +def _quantile(a, q): + n = len(a) + a = sorted(a) + + def get_quantile(q): + if not (0 <= q <= 1): + raise ValueError("Quantiles must be in the range [0, 1]") + point = q * (n - 1) + lower = math.floor(point) + upper = math.ceil(point) + t = point - lower + return (1 - t) * a[lower] + t * a[upper] + + return [get_quantile(q) for q in q] + + def _summarize_statistics(times, quantiles, return_mode): - import torch if quantiles is not None: - ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + ret = _quantile(times, quantiles) if len(ret) == 1: ret = ret[0] return ret if return_mode == "all": - return times.tolist() - return getattr(torch, return_mode)(times).item() + return times + elif return_mode == "min": + return min(times) + elif return_mode == "max": + return max(times) + elif return_mode == "mean": + return statistics.mean(times) + elif return_mode == "median": + return statistics.median(times) def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"): @@ -86,7 +114,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod :type rep: int :param grad_to_none: Reset the gradient of the provided tensor to None :type grad_to_none: torch.tensor, optional - :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean". :type return_mode: str """ import torch @@ -136,7 +164,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod end_event.record() torch.cuda.synchronize() ret += [start_event.elapsed_time(end_event) / n_repeat] - return _summarize_statistics(torch.tensor(ret), quantiles, return_mode) + return _summarize_statistics(ret, quantiles, return_mode) def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): @@ -154,10 +182,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m :type grad_to_none: torch.tensor, optional :param quantiles: Performance percentile to return in addition to the median. :type quantiles: list[float], optional - :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean". 
+ :type return_mode: str """ assert return_mode in ["min", "max", "mean", "median", "all"] - import torch di = runtime.driver.active.get_device_interface() @@ -173,7 +201,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m start_event.record() for _ in range(5): - cache.zero_() + runtime.driver.active.clear_cache(cache) fn() if USE_WALL_TIME: di.synchronize() @@ -199,7 +227,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m for x in grad_to_none: x.grad = None # we clear the L2 cache before each run - cache.zero_() + runtime.driver.active.clear_cache(cache) if USE_WALL_TIME: di.synchronize() # record time of `fn` @@ -211,7 +239,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m # Record clocks if not USE_WALL_TIME: di.synchronize() - times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)] return _summarize_statistics(times, quantiles, return_mode) diff --git a/test/Conversion/reduce_to_llvm.mlir b/test/Conversion/reduce_to_llvm.mlir new file mode 100644 index 0000000000..0bbcecbd93 --- /dev/null +++ b/test/Conversion/reduce_to_llvm.mlir @@ -0,0 +1,70 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +#linear = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: @reduce_linear_layout +tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1 + // CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2 + // CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3 + + // The layout looks like + // [[ T0:0, T32:0, T0:1, T32:1, ... + // [ T4:0, T36:0, T4:1, T36:1, ... + // [ T0:2, T32:2, T0:3, T32:3, ... + // [ T4:2, T36:2, T4:3, T36:3, + // ... + // + // A reduction along axis=0 consists of adding registers (0, 2) and (1, 3) + // before shuffling. + // + // Columns along axis=0 are contained within a warp, so reduction across warps + // is not needed. + + // Reduce within threads + // CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]] + // CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]] + + // Reduce within warp.
+ // CHECK-NEXT: [[W0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM0]], i32 16, i32 31) + // CHECK-NEXT: [[WSUM0:%.*]] = add i32 [[W0]], [[SUM0]] + // CHECK-NEXT: [[W1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM0]], i32 8, i32 31) + // CHECK-NEXT: [[WSUM1:%.*]] = add i32 [[WSUM0]], [[W1]] + // CHECK-NEXT: [[W2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM1]], i32 4, i32 31) + // CHECK-NEXT: [[WSUM2:%.*]] = add i32 [[WSUM1]], [[W2]] + // CHECK-NEXT: [[W3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM2]], i32 2, i32 31) + // CHECK-NEXT: [[WSUM3:%.*]] = add i32 [[WSUM2]], [[W3]] + + // CHECK-NEXT: [[W4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM1]], i32 16, i32 31) + // CHECK-NEXT: [[WSUM4:%.*]] = add i32 [[W4]], [[SUM1]] + // CHECK-NEXT: [[W5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM4]], i32 8, i32 31) + // CHECK-NEXT: [[WSUM5:%.*]] = add i32 [[WSUM4]], [[W5]] + // CHECK-NEXT: [[W6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM5]], i32 4, i32 31) + // CHECK-NEXT: [[WSUM6:%.*]] = add i32 [[WSUM5]], [[W6]] + // CHECK-NEXT: [[W7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM6]], i32 2, i32 31) + // CHECK-NEXT: [[WSUM7:%.*]] = add i32 [[WSUM6]], [[W7]] + + // CHECK-NEXT: [[DST0:%.*]] = insertvalue { i32, i32 } undef, i32 [[WSUM3]], 0 + // CHECK-NEXT: [[DST1:%.*]] = insertvalue { i32, i32 } [[DST0]], i32 [[WSUM7]], 1 + + %0 = "tt.reduce"(%arg0) ({ + ^bb0(%arg1: i32, %arg2: i32): + %1 = arith.addi %arg1, %arg2 : i32 + tt.reduce.return %1 : i32 + }) {axis = 0 : i32} : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> + + // CHECK-NEXT: ret { i32, i32 } [[DST1]] + tt.return %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> +} + +tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<32x16xi32, #linear>) { + %0 = tt.call @reduce_linear_layout(%arg0) : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> + %1 = builtin.unrealized_conversion_cast %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> to !llvm.struct<(i32, i32)> + llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr + tt.return +} + +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index b8e8c34b33..382e4eb9c6 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2115,3 +2115,28 @@ tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #m } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [4, 4, 2], warpsPerCTA = [8, 1, 1], order = [2, 1, 0]}> +#linear = #ttg.linear<{register = [[0, 0], [0, 0], [0, 0], [0, 0]], lane = [[0, 0], [0, 1], [0, 2], [1, 0], [2, 0]], warp = [[4, 0], [8, 0], [16, 0]], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: expand_dims_linear_layout +tt.func private @expand_dims_linear_layout() -> tensor<1x4xi32, #linear> { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x4xi32, #linear> + // CHECK: return %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, 
i32, i32, i32, i32)> + tt.return %1 : tensor<1x4xi32, #linear> +} + +// CHECK-LABEL: reshape_linear_layout_broadcasting +tt.func private @reshape_linear_layout_broadcasting(%arg0: tensor<32x4xbf16, #linear>) -> tensor<32x4x1xbf16, #blocked> { + // CHECK-COUNT-16: extractvalue + // CHECK-COUNT-16: insertvalue + %0 = tt.reshape %arg0 : tensor<32x4xbf16, #linear> -> tensor<32x4x1xbf16, #blocked> + tt.return %0 : tensor<32x4x1xbf16, #blocked> +} + +} diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 56e078463d..3363b5c9a1 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -255,6 +255,22 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#smem = #ttg.shared_memory +// CHECK-LABEL: distribute_to_shared_st_matrix_local_store +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<128x128xf16, #mma>) { + // CHECK-COUNT-16: nvgpu.stmatrix + // CHECK: llvm.return + %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + ttg.local_store %a, %b : tensor<128x128xf16, #mma> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} { diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index cd45d1ee05..ab124f82ab 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2829,3 +2829,26 @@ tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) { } } + +// ----- + +#linear = #ttg.linear<{register = [[1, 0], [0, 8], [0, 16]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 2], [0, 4]], block = []}> +#blocked = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: reduce_linear_layouts +tt.func @reduce_linear_layouts(%arg0: tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> { + // CHECK-NOT: convert_layout + %0 = ttg.convert_layout %arg0 : tensor<32x32xi32, #linear> -> tensor<32x32xi32, #blocked> + // CHECK-NEXT: tt.reduce + %1 = "tt.reduce" (%0) ({ + ^bb0(%arg1: i32, %arg2: i32): + tt.reduce.return %arg1 : i32 + // CHECK: (tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}> + }) {axis = 1 : i32} : (tensor<32x32xi32, #blocked>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = ttg.convert_layout %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> + tt.return %2 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> +} + +} 
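The reshape inference added above (inferReshapeOpEncoding) and the makeFlattenedCContig check in the reworked DialectTest.cpp unit tests below rely on the same invariant: a reshape is a layout "nop" exactly when the source and destination linear layouts, flattened C-contiguously over their logical shapes, send every hardware index to the same flat element. The sketch below is a minimal, self-contained Python model of that criterion, not the actual C++ LinearLayout API; the basis lists, the single flat hardware index standing in for register/lane/warp bits, and the helper names (apply_layout, flatten_c_contiguous, reshape_is_nop) are illustrative assumptions.

# Minimal model of a linear layout: each bit of a flat hardware index has a
# basis vector of logical coordinates, and the selected bases combine by XOR.
def apply_layout(bases, hw_index):
    rank = len(bases[0])
    coord = [0] * rank
    for bit, basis in enumerate(bases):
        if (hw_index >> bit) & 1:
            coord = [c ^ b for c, b in zip(coord, basis)]
    return coord


def flatten_c_contiguous(coord, shape):
    # Row-major (C-contiguous) flattening of a logical coordinate.
    flat = 0
    for c, size in zip(coord, shape):
        flat = flat * size + c
    return flat


def reshape_is_nop(src_bases, src_shape, dst_bases, dst_shape):
    # The reshape keeps every element in the same hardware location iff the
    # flattened src and dst layouts agree for every hardware index.
    assert len(src_bases) == len(dst_bases)
    return all(
        flatten_c_contiguous(apply_layout(src_bases, i), src_shape)
        == flatten_c_contiguous(apply_layout(dst_bases, i), dst_shape)
        for i in range(1 << len(src_bases)))


# Row-major 2x2 -> 4: the flattened layouts match.
print(reshape_is_nop([[0, 1], [1, 0]], (2, 2), [[1], [2]], (4,)))  # True
# Same 2x2 source flattened column-major: they do not.
print(reshape_is_nop([[0, 1], [1, 0]], (2, 2), [[2], [1]], (4,)))  # False

The failing case is the same kind of order mismatch called out in the "NOT OK" example in the inference interface comments above; with legacy blocked encodings such cases had no valid destination encoding, whereas a LinearEncodingAttr can always express one.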
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 24b17d55be..fd02d53270 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -500,8 +500,11 @@ def get_device_interface(self): @staticmethod def is_active(): - import torch - return torch.version.hip is not None + try: + import torch + return torch.version.hip is not None + except ImportError: + return False def get_current_target(self): device = self.get_current_device() @@ -525,3 +528,6 @@ def get_empty_cache_for_benchmark(self): # It's the same as the Nvidia backend. cache_size = 256 * 1024 * 1024 return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') + + def clear_cache(self, cache): + cache.zero_() diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 1018f6112c..2d53fe52ef 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -630,8 +630,11 @@ def get_device_interface(self): @staticmethod def is_active(): - import torch - return torch.xpu.is_available() + try: + import torch + return torch.xpu.is_available() + except ImportError: + return False def get_benchmarker(self): from triton.testing import do_bench @@ -645,3 +648,6 @@ def get_empty_cache_for_benchmark(self): # doesn't contain any input data before the run cache_size = 256 * 1024 * 1024 return torch.empty(int(cache_size // 4), dtype=torch.int, device='xpu') + + def clear_cache(self, cache): + cache.zero_() diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index c6ffa50aa1..de2f019fd8 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -646,10 +646,26 @@ struct TritonIntelGPUInferLayoutInterface return success(); } + LogicalResult verifyLayoutsAreEqual(ArrayRef shape, + Attribute expected, Attribute got, + Location loc) const override { + if (expected == got) { + return success(); + } + // Check whether the encodings are structurally the same. 
+ auto expectedLL = triton::gpu::toLinearLayout(shape, expected); + auto gotLL = triton::gpu::toLinearLayout(shape, got); + if (expectedLL != gotLL) { + return emitError(loc, "Expected result encoding ") + << expected << " but was " << got; + } + return success(); + } + LogicalResult - inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const override { + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { // TODO return failure(); } diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 0e80c8f607..75145d76f0 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -572,8 +572,11 @@ def get_device_interface(self): @staticmethod def is_active(): - import torch - return torch.cuda.is_available() and (torch.version.hip is None) + try: + import torch + return torch.cuda.is_available() and (torch.version.hip is None) + except ImportError: + return False def get_benchmarker(self): from triton.testing import do_bench @@ -587,3 +590,6 @@ def get_empty_cache_for_benchmark(self): # doesn't contain any input data before the run cache_size = 256 * 1024 * 1024 return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') + + def clear_cache(self, cache): + cache.zero_() diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 380f549cc6..749031d553 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -123,6 +123,78 @@ struct LocalLoadOpConversion } }; +LogicalResult lowerDistributedToSharedStmatrix( + Location loc, TypedValue src, MemDescType memDescType, + Value adaptorSrc, Value smemBase, const TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { + auto mmaEncoding = + dyn_cast(src.getType().getEncoding()); + if (!mmaEncoding) + return failure(); + auto sharedLayout = + cast(memDescType.getEncoding()); + if (!sharedLayout.getHasLeadingOffset()) + return failure(); + int swizzleByteSize = 0; + if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) + swizzleByteSize = 32; + else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) + swizzleByteSize = 64; + else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) + swizzleByteSize = 128; + else + return failure(); + + RankedTensorType srcTy = src.getType(); + SmallVector shape = + convertType(srcTy.getShape()); + auto order = sharedLayout.getOrder(); + if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, swizzleByteSize)) { + return failure(); + } + + auto *ctx = rewriter.getContext(); + + auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, shape, + order, swizzleByteSize); + auto llvmElemTy = typeConverter->convertType(memDescType.getElementType()); + auto smemPtrTy = ptr_ty(ctx, 3); + + auto kRegister = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(layout.getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + + auto 
regBase = applyLinearLayout(loc, rewriter, layout, + {{kRegister, i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})[0] + .second; + auto srcVals = unpackLLElements(loc, adaptorSrc, rewriter); + auto srcVec = layout.getNumConsecutiveInOut(); + for (int i = 0; i < srcVals.size(); i += srcVec) { + auto regIdx = + layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] + .second; + Value offset = xor_(regBase, i32_val(regIdx)); + auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); + vecAddr.setInbounds(true); + SmallVector inValsVec; + for (int j = 0; j < srcVec; j++) + inValsVec.push_back(srcVals[i + j]); + Value valsVec = packLLVector(loc, inValsVec, rewriter); + targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + } + return success(); +} + struct LocalAllocOpConversion : public ConvertOpToLLVMPattern { LocalAllocOpConversion(const LLVMTypeConverter &converter, @@ -136,82 +208,61 @@ struct LocalAllocOpConversion ConversionPatternRewriter &rewriter) const override { if (!op.getSrc()) return failure(); - auto mmaEncoding = dyn_cast( - op.getSrc().getType().getEncoding()); - if (!mmaEncoding) - return failure(); + MemDescType memDescType = op.getType(); auto sharedLayout = - cast(op.getType().getEncoding()); - if (!sharedLayout.getHasLeadingOffset()) - return failure(); - int swizzleByteSize = 0; - if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) - swizzleByteSize = 32; - else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) - swizzleByteSize = 64; - else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) - swizzleByteSize = 128; - else - return failure(); - - auto *ctx = rewriter.getContext(); - Location loc = op->getLoc(); - + cast(memDescType.getEncoding()); RankedTensorType srcTy = op.getSrc().getType(); - SmallVector shape = - convertType(srcTy.getShape()); - auto order = sharedLayout.getOrder(); - if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, - swizzleByteSize)) { - return failure(); - } - auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, - shape, order, swizzleByteSize); - Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); - auto smemPtrTy = ptr_ty(ctx, 3); - - auto kRegister = str_attr("register"); - auto kLane = str_attr("lane"); - auto kWarp = str_attr("warp"); - auto kBlock = str_attr("block"); - - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(layout.getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); - - auto regBase = applyLinearLayout(loc, rewriter, layout, - {{kRegister, i32_val(0)}, - {kLane, laneId}, - {kWarp, warpId}, - {kBlock, i32_val(0)}})[0] - .second; - auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - auto srcVec = layout.getNumConsecutiveInOut(); Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); - for (int i = 0; i < srcVals.size(); i += srcVec) { - auto regIdx = - layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] - .second; - Value offset = xor_(regBase, i32_val(regIdx)); - auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); - vecAddr.setInbounds(true); - SmallVector inValsVec; - for (int j = 0; j < srcVec; j++) - inValsVec.push_back(srcVals[i + j]); - Value valsVec = packLLVector(loc, inValsVec, rewriter); - targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + Value smemBase = + 
LLVM::getSharedMemoryBase(op.getLoc(), rewriter, targetInfo, op); + + if (lowerDistributedToSharedStmatrix(op.getLoc(), op.getSrc(), memDescType, + adaptor.getSrc(), smemBase, + typeConverter, rewriter, targetInfo) + .failed()) { + return failure(); } auto resultTy = cast(op.getType()); auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, - sharedLayout, loc, rewriter); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + sharedLayout, op.getLoc(), rewriter); + auto retVal = + getStructFromSharedMemoryObject(op.getLoc(), smemObj, rewriter); rewriter.replaceOp(op, retVal); return success(); } +private: + const NVIDIA::TargetInfo &targetInfo; +}; + +struct LocalStoreOpConversion + : public ConvertOpToLLVMPattern { + LocalStoreOpConversion(const LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type llvmElemTy = + getTypeConverter()->convertType(op.getDst().getType().getElementType()); + SharedMemoryObject smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + MemDescType memDescType = op.getDst().getType(); + if (lowerDistributedToSharedStmatrix( + op.getLoc(), op.getSrc(), memDescType, adaptor.getSrc(), + smemObj.getBase(), getTypeConverter(), rewriter, targetInfo) + .failed()) { + return failure(); + } + rewriter.eraseOp(op); + return success(); + } + private: const NVIDIA::TargetInfo &targetInfo; }; @@ -223,6 +274,8 @@ void mlir::triton::NVIDIA::populateMemoryOpToLLVMPatterns( // Backend optimized memory ops get higher benefit patterns.add(typeConverter, targetInfo, benefit.getBenefit() + 1); + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); patterns.add(typeConverter, benefit.getBenefit() + 1); mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo, patterns, benefit); diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index 3fab64c0e8..a3cc65605f 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -77,135 +77,6 @@ int64_t getFlatIdx(ArrayRef idx, ArrayRef shape, return flatIdx; } -// Represents the many indices of one element of a tensor with a -// BlockedEncoding. -// -// The purpose of this class is we can say, if two MultiIdx's have the same -// flatFoo values before and after a reshape, then the same GPU thread contains -// the same element (and the reshape is a nop, at least for that element). -struct MultiIdx { - using Vec = SmallVector; - - // Logical index into the tensor. - Vec idx; - - // If the tensor's encoding has e.g. numPerThread = [2,2], then idxInThread - // tells us which of the four elements per thread this is. Same for idxInWarp - // and idxInCTA. - Vec idxInThread; - Vec idxInWarp; - Vec idxInCTA; - - // If the tensor's encoding defines a block of size [x,y,z], the tensor itself - // may be larger than this, comprising multiple blocks. This tells us which - // block we're in. - Vec idxOuter; - - // flatIdx is flattened according to the tensor's logical order (i.e. ignoring - // the encoding). The others are flattened according to the tensor's physical - // encoding. 
- int64_t flatIdx; - int64_t flatIdxInThread; - int64_t flatIdxInWarp; - int64_t flatIdxInCTA; - int64_t flatIdxOuter; -}; - -bool sameFlatIdxs(const MultiIdx &a, const MultiIdx &b) { - return a.flatIdx == b.flatIdx && // - a.flatIdxInThread == b.flatIdxInThread && - a.flatIdxInWarp == b.flatIdxInWarp && - a.flatIdxInCTA == b.flatIdxInCTA && // - a.flatIdxOuter == b.flatIdxOuter; -} - -std::string multiIdxsToString(ArrayRef> idxs) { - std::stringstream ss; - for (const auto &idxPtr : idxs) { - const MultiIdx &idx = *idxPtr; - ss // - << " [" << triton::join(idx.idx, ",") << "] (" << idx.flatIdx << ") " - << "elem=[" << triton::join(idx.idxInThread, ",") << "] (" - << idx.flatIdxInThread << ") " - << "thread=[" << triton::join(idx.idxInWarp, ",") << "] (" - << idx.flatIdxInWarp << ") " - << "warp=[" << triton::join(idx.idxInCTA, ",") << "] (" - << idx.flatIdxInCTA << ") " - << "outer=[" << triton::join(idx.idxOuter, ",") << "] (" - << idx.flatIdxOuter << ")\n"; - } - return ss.str(); -} - -std::vector> getMultiIdxs(ArrayRef shape, - BlockedEncodingAttr enc) { - using Vec = MultiIdx::Vec; - - const unsigned rank = shape.size(); - auto sizePerThread = enc.getSizePerThread(); - auto threadsPerWarp = enc.getThreadsPerWarp(); - auto warpsPerCTA = enc.getWarpsPerCTA(); - auto order = enc.getOrder(); - - Vec numBlocks; - for (int i = 0; i < rank; i++) { - numBlocks.push_back(ceil( - shape[i], sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i])); - } - - Vec idxInThread(rank, 0); - Vec idxInWarp(rank, 0); - Vec idxInCTA(rank, 0); - Vec idxOuter(rank, 0); - - int64_t nElems = product(sizePerThread) * product(threadsPerWarp) * - product(warpsPerCTA) * product(numBlocks); - - // We eventually sort this array, and if the elements are plain MultiIdx - // elements rather than pointers, we have to swap them, which ends up being - // expensive. - std::vector> elems; - elems.reserve(nElems); - - for (int64_t i = 0; i < nElems; i++) { - auto e = std::make_unique(); - e->idxInThread = idxInThread; - e->idxInWarp = idxInWarp; - e->idxInCTA = idxInCTA; - e->idxOuter = idxOuter; - - for (int i = 0; i < rank; i++) { - e->idx.push_back( // - idxInThread[i] + // - idxInWarp[i] * sizePerThread[i] + - idxInCTA[i] * sizePerThread[i] * threadsPerWarp[i] + - idxOuter[i] * sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]); - } - - e->flatIdxInThread = getFlatIdx(e->idxInThread, sizePerThread, order); - e->flatIdxInWarp = getFlatIdx(e->idxInWarp, threadsPerWarp, order); - e->flatIdxInCTA = getFlatIdx(e->idxInCTA, warpsPerCTA, order); - e->flatIdxOuter = getFlatIdx(e->idxOuter, numBlocks, order); - e->flatIdx = getFlatIdx(e->idx, shape, - llvm::to_vector(llvm::reverse(llvm::seq(rank)))); - - elems.push_back(std::move(e)); - - if (advance(idxInThread, sizePerThread, order)) { - if (advance(idxInWarp, threadsPerWarp, order)) { - if (advance(idxInCTA, warpsPerCTA, order)) { - advance(idxOuter, numBlocks, order); - } - } - } - } - llvm::sort(elems, [](const std::unique_ptr &a, - const std::unique_ptr &b) { - return a->flatIdx < b->flatIdx; - }); - return elems; -} - class InferLayoutTest : public ::testing::Test { public: InferLayoutTest() @@ -221,25 +92,12 @@ class InferLayoutTest : public ::testing::Test { /*static*/ MLIRContext InferLayoutTest::ctx; -// The optional outparam couldReshape tells the caller whether the reshape -// worked. You might want this to be a return value instead, but gtest ASSERT -// and FAIL have an implicit `return`, so only work in fns that return void. 
void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, std::optional expectedDstEnc, - std::optional expectSuccess, DialectInferLayoutInterface *inferLayout, - bool longErrors = true, bool *couldReshape = nullptr) { - std::unique_ptr couldReshapeStorage; - if (!couldReshape) { - couldReshapeStorage = std::make_unique(); - couldReshape = couldReshapeStorage.get(); - } - *couldReshape = false; + bool longErrors = true) { MLIRContext *ctx = srcTy.getContext(); - ASSERT_TRUE(expectSuccess || !dstTy.getEncoding()) - << "dstTy shouldn't have an expected encoding if we're expecting the " - "reshape to be impossible!"; // Capture any errors from calling inferReshapeNoOpReorderEncoding, so we can // print them if we expected the reshape to succeed but it failed. @@ -249,29 +107,17 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, { ScopedDiagnosticHandler scopedHandler( ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); - result = inferLayout->inferReshapeOpNoReorderEncoding( + result = inferLayout->inferReshapeOpEncoding( srcTy.getShape(), srcTy.getEncoding(), dstTy.getShape(), inferredEnc, UnknownLoc::get(ctx)); } - if (!expectSuccess.has_value() && !succeeded(result)) { - // We didn't know whether or not it was supposed to succeed, and it didn't. - // Test passes! - return; - } - - if (expectSuccess.has_value() && !*expectSuccess) { - EXPECT_FALSE(succeeded(result)) - << "Expected reshape to be impossible, but got dst encoding: " - << stringifyLLVMType(inferredEnc); - *couldReshape = true; - return; - } + // We expect the reshape to succeed as long as the inputs have the same + // number of elements + EXPECT_TRUE(succeeded(result)) + << "Expected reshape to succeed, but it didn't! Error(s):\n" + << join(diags, "\n"); - if (!succeeded(result)) { - FAIL() << "Expected reshape to succeed, but it didn't! Error(s):\n" - << join(diags, "\n"); - } if (auto expectedEnc = dstTy.getEncoding()) { EXPECT_EQ(inferredEnc, expectedEnc); } @@ -279,12 +125,14 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, // We know that infer(srcShape, srcEnc, dstShape) => dstEnc. Check that it // works the other way around too: infer(dstShape, dstEnc, srcShape) => // srcEnc. (This is an invariant of the inference function.) + // Even more, we check that the inferred encoding is structurally the same as + // the src encoding, showing that the inference is consistent. { std::vector diags; ScopedDiagnosticHandler scopedHandler( ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); Attribute inferredSrcEnc; - auto result = inferLayout->inferReshapeOpNoReorderEncoding( + auto result = inferLayout->inferReshapeOpEncoding( dstTy.getShape(), inferredEnc, srcTy.getShape(), inferredSrcEnc, UnknownLoc::get(ctx)); EXPECT_TRUE(succeeded(result)) @@ -292,56 +140,40 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, << " " << stringifyLLVMType(inferredEnc) << " -> " << triton::join(srcTy.getShape(), "x") << "failed:\n" << join(diags, "\n"); - if (succeeded(result)) { - EXPECT_EQ(inferredSrcEnc, srcTy.getEncoding()) - << "Inverse encoding inference (" - << triton::join(dstTy.getShape(), "x") << " " - << stringifyLLVMType(inferredEnc) << " -> " - << triton::join(srcTy.getShape(), "x") - << " gave the wrong result. 
Expected " - << stringifyLLVMType(srcTy.getEncoding()) << " but got " - << stringifyLLVMType(inferredSrcEnc) << ".\n"; - } + auto srcLinear = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + auto inferredSrcLinear = toLinearLayout(srcTy.getShape(), inferredSrcEnc); + EXPECT_EQ(inferredSrcLinear, srcLinear) + << "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x") + << " " << stringifyLLVMType(inferredEnc) << " -> " + << triton::join(srcTy.getShape(), "x") + << " gave the wrong result. Expected " << srcLinear->toString() + << " but " + << "got " << inferredSrcLinear->toString() << ".\n"; } - std::vector> srcMultiIdxs = - getMultiIdxs(SmallVector(srcTy.getShape()), - mlir::cast(srcTy.getEncoding())); - - std::vector> dstMultiIdxs = - getMultiIdxs(SmallVector(dstTy.getShape()), - mlir::cast(inferredEnc)); - - if (srcMultiIdxs.size() != dstMultiIdxs.size() || - !llvm::all_of(llvm::zip_equal(srcMultiIdxs, dstMultiIdxs), - [](const auto &pair) { - const auto &[a, b] = pair; - return sameFlatIdxs(*a, *b); - })) { - SCOPED_TRACE(longErrors ? "dst indices:\n" + multiIdxsToString(dstMultiIdxs) - : ""); - SCOPED_TRACE(longErrors ? "src indices:\n" + multiIdxsToString(srcMultiIdxs) - : ""); - ADD_FAILURE() << "Reified indices do not match for encodings:\n" - << " src: [" << triton::join(srcTy.getShape(), "x") << "] " - << stringifyLLVMType(srcTy.getEncoding()) << "\n" - << " dst: [" << triton::join(dstTy.getShape(), "x") << "] " - << stringifyLLVMType(inferredEnc); - } else { - *couldReshape = true; - } + // The funtional characterisation of resize is that, if we have a srcLayout + // and a dstLayout, then the flattened layouts are views of the same data + // when considered as C-contiguous. + auto makeFlattenedCContig = [](ArrayRef shape, Attribute layout) { + auto ctx = layout.getContext(); + auto linear = *toLinearLayout(shape, layout); + auto dims = standardOutDimNames(ctx, shape.size()); + std::reverse(dims.begin(), dims.end()); + return linear.transposeOuts(dims).reshapeOuts( + {{dims.back(), linear.getTotalOutDimSize()}}); + }; + EXPECT_EQ(makeFlattenedCContig(srcTy.getShape(), srcTy.getEncoding()), + makeFlattenedCContig(dstTy.getShape(), inferredEnc)); } -class InferReshapeOpNoReorderEncodingTest +class InferReshapeOpEncodingTest : public InferLayoutTest, public ::testing::WithParamInterface< - std::tuple> {}; + std::tuple> {}; -TEST_P(InferReshapeOpNoReorderEncodingTest, DoIt) { +TEST_P(InferReshapeOpEncodingTest, DoIt) { std::string srcTyStr = expandTyStr(std::get<0>(GetParam())); std::string dstTyStr = expandTyStr(std::get<1>(GetParam())); - bool expectSuccess = std::get<2>(GetParam()); auto src = mlir::parseType(srcTyStr, &ctx); if (!src) @@ -357,7 +189,7 @@ TEST_P(InferReshapeOpNoReorderEncodingTest, DoIt) { } testReshape(cast(src), cast(dst), - expectedDstEnc, expectSuccess, inferLayout, /*longErrors=*/true); + expectedDstEnc, inferLayout, /*longErrors=*/true); } // A testcase of {a, b, c} means: @@ -368,158 +200,72 @@ TEST_P(InferReshapeOpNoReorderEncodingTest, DoIt) { // encoding that makes the reshape a nop, and // - if b has an encoding, check that the inferred encoding matches b's. INSTANTIATE_TEST_SUITE_P( - Reshapes, InferReshapeOpNoReorderEncodingTest, - ::testing::ValuesIn(std::vector< - std::tuple>({ + Reshapes, InferReshapeOpEncodingTest, + ::testing::ValuesIn(std::vector>({ // Use raw strings in here so clang-format doesn't try to wrap them. 
{R"(T<128x64xf32, #B<{spt=[1,1], tpw=[1,32], wpc=[1,1], ord=[1,0]}>>)", - R"(T<8192xf32, #B<{spt=[1], tpw=[32], wpc=[1], ord=[0]}>>)", - true}, + R"(T<8192xf32, #B<{spt=[1], tpw=[32], wpc=[1], ord=[0]}>>)"}, {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", - R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - true}, + R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)"}, {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", - R"(T<16x8xf32, #B<{spt=[1,4], tpw=[16,2], wpc=[1,1], ord=[1,0]}>>)", - true}, + R"(T<16x8xf32, #B<{spt=[1,4], tpw=[16,2], wpc=[1,1], ord=[1,0]}>>)"}, {R"(T<32x32xf32, #B<{spt=[2,2], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - "T<128xf32>", false}, + "T<1024xf32>"}, {R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - R"(T<2x16x2x2xf32, #B<{spt=[1,1,2,2], tpw=[2,16,1,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", - true}, + R"(T<2x16x2x2xf32, #B<{spt=[1,1,2,2], tpw=[2,16,1,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, {R"(T<4x32xf32, #B<{spt=[4,1], tpw=[1,32], wpc=[1,1], ord=[0,1]}>>)", - R"(T<2x2x2x16xf32, #B<{spt=[2,2,1,1], tpw=[1,1,2,16], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)", - true}, + R"(T<2x2x2x16xf32, #B<{spt=[2,2,1,1], tpw=[1,1,2,16], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)"}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<2x16x2x16xf32, #B<{spt=[1,4,1,4], tpw=[1,4,2,4], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", - true}, + R"(T<2x16x2x16xf32, #B<{spt=[1,4,1,4], tpw=[1,4,2,4], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<16x2x16x2xf32, #B<{spt=[2,2,2,2], tpw=[4,1,8,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", - true}, + R"(T<16x2x16x2xf32, #B<{spt=[2,2,2,2], tpw=[4,1,8,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[0,1]}>>)", - R"(T<16x2x16x2xf32>)", true}, + R"(T<16x2x16x2xf32>)"}, // nop reshape, but the block size is 2x larger than the tensor. 
{R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)", - R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)", - true}, + R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)"}, {R"(T<2x4x2x4xf32, #B<{spt=[1,2,2,1], tpw=[1,2,1,2], wpc=[1,2,2,1], ord=[2,1,0,3]}>>)", - R"(T<4x2x2x4xf32>)", false}, + R"(T<4x2x2x4xf32>)"}, {R"(T<1x2x2x4xf32, #B<{spt=[1,32,4,4], tpw=[4,4,16,16], wpc=[8,8,8,1], ord=[0,1,2,3]}>>)", - R"(T<2x2x4x1xf32>)", false}, + R"(T<2x2x4x1xf32>)"}, {R"(T<2x2x2x2xf32, #B<{spt=[2,2,2,2], tpw=[1,1,1,1], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)", - R"(T<4x4xf32>)", true}, + R"(T<4x4xf32>)"}, {R"(T<16x8xf32, #B<{spt=[1,2], tpw=[2,4], wpc=[2,1], ord=[1,0]}>>)", - R"(T<128xf32>)", true}, + R"(T<128xf32>)"}, {R"(T<16x1x8xf32, #B<{spt=[8,1,1], tpw=[2,1,1], wpc=[1,1,8], ord=[2,1,0]}>>)", - R"(T<128x1xf32>)", false}, + R"(T<128x1xf32>)"}, {R"(T<16x1x8xf32, #B<{spt=[1,1,8], tpw=[2,1,1], wpc=[8,1,1], ord=[2,1,0]}>>)", - R"(T<128x1xf32>)", true}, + R"(T<128x1xf32>)"}, {R"(T<32x32xf32, #B<{spt=[1,2], tpw=[1,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<1024xf32>)", true}, + R"(T<1024xf32>)"}, {R"(T<4x4xf32, #B<{spt=[1,1], tpw=[2,4], wpc=[2,1], ord=[0,1]}>>)", - R"(T<16xf32>)", false}, + R"(T<16xf32>)"}, {R"(T<32xf32, #B<{spt=[2], tpw=[32], wpc=[2], ord=[0]}>>)", - R"(T<16x2xf32, #B<{spt=[1,2], tpw=[32,1], wpc=[2,1], ord=[1,0]}>>)", - true}, + R"(T<16x2xf32, #B<{spt=[1,2], tpw=[32,1], wpc=[2,1], ord=[1,0]}>>)"}, {R"(T<2x1x2xf32, #B<{spt=[2,1,1], tpw=[2,1,2], wpc=[4,1,8], ord=[2,1,0]}>>)", - R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)", - true}, + R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)"}, }))); -TEST_F(InferLayoutTest, FuzzReshape) { - const int numTests = 1000; // Increase to get more coverage. - - std::minstd_rand rng(/*seed=*/0); - auto randPow2Vec = [&](int rank, int maxPow2) { - SmallVector ret; - for (int i = 0; i < rank; i++) { - int pow2 = std::uniform_int_distribution(0, maxPow2)(rng); - if (pow2 == maxPow2 && maxPow2 > 0) { - maxPow2--; - } - ret.push_back(1 << pow2); - } - return ret; - }; - - int numSuccess = 0; - for (int i = 0; i < numTests; i++) { - SCOPED_TRACE("iteration " + std::to_string(i)); - int rank = std::uniform_int_distribution(1, 4)(rng); - - SmallVector srcShape( - convertType(randPow2Vec(rank, /*maxPow2=*/4))); - SmallVector dstShape = srcShape; - std::shuffle(dstShape.begin(), dstShape.end(), rng); - - // Optionally merge some dimensions in dst. 
- for (int i = 1; i < dstShape.size(); i++) { - if (std::uniform_real_distribution(0, 1)(rng) > 1.0 / rank) { - dstShape[i - 1] *= dstShape[i]; - dstShape.erase(dstShape.begin() + i); - i--; - } - } - - SmallVector sizePerThread = randPow2Vec(rank, /*maxPow2=*/3); - SmallVector threadsPerWarp = randPow2Vec(rank, /*maxPow2=*/3); - SmallVector warpsPerCTA = randPow2Vec(rank, /*maxPow2=*/3); - - SmallVector order(llvm::to_vector(llvm::seq(rank))); - std::shuffle(order.begin(), order.end(), rng); - - auto ctaLayout = CTALayoutAttr::get( - &ctx, SmallVector(rank, 1), SmallVector(rank, 1), - llvm::to_vector(llvm::reverse(llvm::seq(rank)))); - - auto srcTy = RankedTensorType::get( - srcShape, FloatType::getF32(&ctx), - BlockedEncodingAttr::get(&ctx, sizePerThread, threadsPerWarp, - warpsPerCTA, order, ctaLayout)); - auto dstTy = RankedTensorType::get(dstShape, FloatType::getF32(&ctx)); - - bool couldReshape = false; - testReshape(srcTy, dstTy, /*expectedDstEnc=*/std::nullopt, - /*expectSuccess=*/std::nullopt, inferLayout, - /*longErrors=*/false, &couldReshape); - if (couldReshape) - numSuccess++; - } - - // We don't expect or want 100% success, but if only a tiny fraction of tests - // actually exercise the successful reshape logic, then that gives us bad - // coverage. I'm currently getting 35% success, which seems good enough, - // especially since the successful cases take a lot longer to run because of - // the MultiIdx checks (so we're spending most of our time on successful - // cases, even if they're only 1/3 of the iterations). - // - // Run ctest with --verbose to see this output. For example: - // $ cd python/build/cmake.blah.blah - // $ ninja - // $ $(git rev-parse --show-toplevel)/.venv/bin/ctest --verbose - printf("Fuzz success rate: %d/%d = %.2f%%\n", numSuccess, numTests, - 100.0 * numSuccess / numTests); -} - class AMDLayoutTest : public ::testing::Test { public: AMDLayoutTest() {