From 30e7a2e12b449a3f5a4c3f9b85fa2fcb9e1f32ba Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Tue, 29 Oct 2024 08:46:53 -0400 Subject: [PATCH 01/13] [Proton] Adding Sorting of Kernels (#4987) --- third_party/proton/README.md | 9 +- third_party/proton/proton/viewer.py | 20 ++- .../proton/test/example_leaf_nodes.json | 168 ++++++++++++++++++ third_party/proton/test/test_viewer.py | 16 ++ 4 files changed, 210 insertions(+), 3 deletions(-) create mode 100644 third_party/proton/test/example_leaf_nodes.json diff --git a/third_party/proton/README.md b/third_party/proton/README.md index fede11cedb..b2d297f0ed 100644 --- a/third_party/proton/README.md +++ b/third_party/proton/README.md @@ -140,9 +140,16 @@ By default, proton profiles are in the *json* format and can be read by *Hatchet pip install llnl-hatchet proton-viewer -m time/s ``` - NOTE: `pip install hatchet` does not work because the API is slightly different. +### Visualizing sorted profile data +In addition visualizing the profile data on terminal through Hatchet. A sorted list of the kernels by the first metric can be done using the --print-sorted flag with proton-viewer + +```bash +proton-viewer -m time/ns,time/% --print-sorted +``` +prints the sorted kernels by the time/ns since it is the first listed. + More options can be found by running the following command. ```bash diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py index fe7c98807c..82a5c178f3 100644 --- a/third_party/proton/proton/viewer.py +++ b/third_party/proton/proton/viewer.py @@ -189,7 +189,7 @@ def filter_frames(gf, include=None, exclude=None, threshold=None, metric=None): return gf -def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=100, format=None): +def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=100, format=None, print_sorted=False): with open(filename, "r") as f: gf, raw_metrics, device_info = get_raw_metrics(f) gf = format_frames(gf, format) @@ -199,6 +199,15 @@ def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=1 # TODO: generalize to support multiple metrics, not just the first one gf = filter_frames(gf, include, exclude, threshold, metrics[0]) print(gf.tree(metric_column=metrics, expand_name=True, depth=depth, render_header=False)) + if print_sorted: + print("Sorted kernels by metric " + metrics[0].strip("(inc)")) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + for row in range(1, len(sorted_df)): + if len(sorted_df.iloc[row]['name']) > 100: + kernel_name = sorted_df.iloc[row]['name'][:100] + "..." + else: + kernel_name = sorted_df.iloc[row]['name'] + print("{:105} {:.4}".format(kernel_name, sorted_df.iloc[row][metrics[0]])) emit_warnings(gf, metrics) @@ -298,6 +307,12 @@ def main(): - function_line: include the function name and line number. - file_function: include the file name and function name. 
""") + argparser.add_argument( + "--print-sorted", + action='store_true', + default=False, + help="Sort output by metric value instead of chronologically", + ) args, target_args = argparser.parse_known_args() assert len(target_args) == 1, "Must specify a file to read" @@ -309,12 +324,13 @@ def main(): threshold = args.threshold depth = args.depth format = args.format + print_sorted = args.print_sorted if include and exclude: raise ValueError("Cannot specify both include and exclude") if args.list: show_metrics(file_name) elif metrics: - parse(metrics, file_name, include, exclude, threshold, depth, format) + parse(metrics, file_name, include, exclude, threshold, depth, format, print_sorted) if __name__ == "__main__": diff --git a/third_party/proton/test/example_leaf_nodes.json b/third_party/proton/test/example_leaf_nodes.json new file mode 100644 index 0000000000..5930664dd2 --- /dev/null +++ b/third_party/proton/test/example_leaf_nodes.json @@ -0,0 +1,168 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_1_2_2", + "type": "function" + }, + "metrics": { + "count": 402, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 78190414 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_1_3_1", + "type": "function" + }, + "metrics": { + "count": 502, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 24125138 + } + } + ], + "frame": { + "name": "kernel_1_2_1", + "type": "function" + }, + "metrics": { + "bytes": 3997237248, + "flops": 1534939103232 + } + } + ], + "frame": { + "name": "kernel_1_1_1", + "type": "function" + }, + "metrics": {} + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_2_2_2", + "type": "function" + }, + "metrics": { + "count": 120, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 23174888 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_2_3_1", + "type": "function" + }, + "metrics": { + "count": 149, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 1040322 + } + } + ], + "frame": { + "name": "kernel_2_2_1", + "type": "function" + }, + "metrics": { + "bytes": 58589184, + "flops": 4999610368 + } + } + ], + "frame": { + "name": "kernel_2_1_1", + "type": "function" + }, + "metrics": {} + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_3_2_2", + "type": "function" + }, + "metrics": { + "count": 480, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 93036508 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_3_2_1", + "type": "function" + }, + "metrics": { + "count": 599, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 6306402 + } + } + ], + "frame": { + "name": "kernel_3_2_1", + "type": "function" + }, + "metrics": { + "bytes": 529956864, + "flops": 67834478592 + } + } + ], + "frame": { + "name": "kernel_3_1_1", + "type": "function" + }, + "metrics": {} + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "bytes": 0, + "count": 0, + "flops": 0, + "time (ns)": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + } + } + } +] diff --git a/third_party/proton/test/test_viewer.py b/third_party/proton/test/test_viewer.py index b2d4d39f9b..e5d4672160 100644 --- a/third_party/proton/test/test_viewer.py +++ b/third_party/proton/test/test_viewer.py @@ -7,6 +7,7 @@ cuda_example_file = 
file_path.replace("test_viewer.py", "example_cuda.json") hip_example_file = file_path.replace("test_viewer.py", "example_hip.json") frame_example_file = file_path.replace("test_viewer.py", "example_frame.json") +leaf_example_file = file_path.replace("test_viewer.py", "example_leaf_nodes.json") def test_help(): @@ -15,6 +16,21 @@ def test_help(): assert ret == 0 +def test_sort(): + with open(leaf_example_file, "r") as f: + gf, raw_metrics, device_info = get_raw_metrics(f) + gf = format_frames(gf, None) + gf.update_inclusive_columns() + metrics = ["time/s", "time/ms", "time/us", "time/ns"] + metrics = derive_metrics(gf, metrics, raw_metrics, device_info) + gf = filter_frames(gf, None, None, None, metrics[0]) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + actual = sorted_df.iloc[0:5]['name'].values + expected = ['ROOT', 'kernel_1_1_1', 'kernel_3_1_1', 'kernel_3_2_2', 'kernel_1_2_2'] + assert len(actual) == len(expected) + assert all([a == b for a, b in zip(actual, expected)]) + + @pytest.mark.parametrize("option", ["full", "file_function_line", "function_line", "file_function"]) def test_format_frames(option): with open(frame_example_file, "r") as f: From ef614882219f690a613cbfcad8f11136b45a8052 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Tue, 29 Oct 2024 16:39:54 +0000 Subject: [PATCH 02/13] [FRONTEND] Fix transpose with tuple dims (#5006) Fixes #4879 --- python/test/unit/language/test_core.py | 3 +++ python/triton/language/core.py | 1 + 2 files changed, 4 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 1cebd25779..43f4cc8c12 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1111,6 +1111,9 @@ def kernel(): a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + a = tl.arange(0, 64).reshape(2, 4, 8).trans((2, 1, 0)) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + a = tl.arange(0, 64).view(2, 4, 8) tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e2c57b388b..c9f22e3991 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1290,6 +1290,7 @@ def trans(input: tensor, *dims, _builder=None): :py:func:`permute` is equivalent to this function, except it doesn't have the special case when no permutation is specified. """ + dims = _unwrap_iterable(dims) if not dims: dims = (1, 0) return semantic.permute(input, dims, _builder) From 69f656cda7f9b2fed998602b0d4b1cb00d5e00f1 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Tue, 29 Oct 2024 18:46:32 +0100 Subject: [PATCH 03/13] [AMD] remove redundant LDS bypass checks (#5002) This commit removes special cases for MFMA -> Dot Operand LDS shortcuts. Now it is supported by common linear layout infrastructure. No tests are added, mfma-shortcut.mlir already testing this. 
--- include/triton/Analysis/Utility.h | 2 - lib/Analysis/Allocation.cpp | 2 +- lib/Analysis/Utility.cpp | 19 +------ .../ConvertLayoutOpToLLVM.cpp | 51 ------------------- .../DecomposeUnsupportedConversions.cpp | 6 ++- 5 files changed, 6 insertions(+), 74 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 37d24ac929..df6029db0d 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -214,8 +214,6 @@ bool atomicNeedsSharedMemory(Value result); bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); -bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); - // Return true if the src and dst layout match. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 276a6e7004..665b97aeeb 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - assert(!isMfmaToDotShortcut(srcTy, dstTy)); + assert(cvtNeedsSharedMemory(srcTy, dstTy)); // FIXME This is NOT entirely correct // This should be getElemOrder, but we don't have such a method diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 30ba11c317..aa9f8b01ea 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -605,22 +605,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { return matrixDimsCompatible && bDimCompatible; } -bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { - auto mfmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); - if (mfmaLayout == nullptr || dotOperandLayout == nullptr) - return false; - // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is - // improved. In addition, we can enable this shortcut for regular MFMA - // layout when opIdx == 1. - return mfmaLayout.getWarpsPerCTA()[1] == 1 && - dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && - dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] && - dotOperandLayout.getParent() == mfmaLayout && - (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && - (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); -} - // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy) { @@ -738,8 +722,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { // supported yet in Triton's backend. 
return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !matchMmaV3AndDotOperandLayout(srcTy, dstTy) && - !isMfmaToDotShortcut(srcTy, dstTy); + !matchMmaV3AndDotOperandLayout(srcTy, dstTy); } bool atomicNeedsSharedMemory(Value value) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index b7ee4efc72..d3ffaed2e8 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -115,56 +115,6 @@ struct LocalLoadOpConversion } }; -struct ConvertLayoutOpConversion - : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern< - triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = op.getSrc(); - Value dst = op.getResult(); - auto srcTy = cast(src.getType()); - auto dstTy = cast(dst.getType()); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - if (isa(srcLayout) && - isa(dstLayout)) { - return lowerMfmaToDotOperand(op, adaptor, rewriter); - } - return failure(); - } - -private: - LogicalResult - lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - RankedTensorType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - if (isMfmaToDotShortcut(srcTy, dstTy)) { - // vecSize is an number of sequential elements stored by one thread - // - For MFMA encoding (encoding of the result tensor of dot - // operation) it is 4 - // - For MFMA operand encoding it is - // dotOperandEncoding::kWidth, - // which is 4 in certain cases (e.g. fp16 and bfloat16 dtypes with kpack - // = 1) - // - // For cases where these two values are equal MFMA and MFMA operand - // layouts are the same. 
- auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - Value view = - packLLElements(loc, getTypeConverter(), vals, rewriter, dstTy); - rewriter.replaceOp(op, view); - return success(); - } - return failure(); - } -}; } // namespace namespace mlir::triton::AMD { @@ -172,7 +122,6 @@ void populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index cece47227e..bce126ea4d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -38,8 +38,10 @@ struct DecomposeUnsupportedAMDConversions triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); - triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - isMfmaToDotShortcut); + auto isShortcut = + mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory)); + + triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut); /* -------------------------------- */ // Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op` From bf6bd0b76c0c9d5366550769732d10e1a4e402d4 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 29 Oct 2024 18:05:28 +0000 Subject: [PATCH 04/13] [AMD] Skip scaled_dot tests for gfx11 and gfx12 (#5008) `scaled_dot` is not yet implemented on `gfx11` and `gfx12` so disable unit tests for now. --- python/test/unit/language/test_core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 43f4cc8c12..ced1a5352a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3343,6 +3343,9 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP") if mma == 16 and K == 64: pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot") + arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in arch or "gfx12" in arch: + pytest.skip("scaled_dot not yet implemented for gfx11 and gfx12") @triton.jit def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, From ebce7f3a62af5242bbb3fe05876c5b3995eb2988 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Tue, 29 Oct 2024 19:31:23 -0400 Subject: [PATCH 05/13] Add string representation for AttrsDescriptor (#4888) The string representation allows PyTorch Inductor to serialize/derserialize the `AttrsDescriptor` to the `@triton.heuristics` block in the generated code. 
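For illustration only (not part of this patch): the round trip this enables can be sketched as below. It assumes nothing beyond the `from_dict`/`to_dict` pair that the new `__repr__` is built on, and it leaves out how the descriptor is originally constructed, since that is outside the scope of this change.

```python
from triton.backends.compiler import AttrsDescriptor

def roundtrip(desc: AttrsDescriptor) -> AttrsDescriptor:
    # With the new __repr__, `text` is a valid Python expression such as
    # "AttrsDescriptor.from_dict({...})" that a code generator (e.g. Inductor)
    # can paste verbatim into a generated @triton.heuristics block.
    text = repr(desc)
    restored = eval(text, {"AttrsDescriptor": AttrsDescriptor})
    # Assumed invariant of the repr: rebuilding from the dict preserves contents.
    assert restored.to_dict() == desc.to_dict()
    return restored
```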
--- python/triton/backends/compiler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index f2ba8eac80..cac42a6631 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -210,6 +210,9 @@ def get_property_key(val, align): return "1" return "N" + def __repr__(self): + return f"AttrsDescriptor.from_dict({self.to_dict()!r})" + @dataclass(frozen=True) class GPUTarget(object): From cfddb090981a49e54872c30325fcf54382704993 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Wed, 30 Oct 2024 01:11:33 -0700 Subject: [PATCH 06/13] [BACKEND][NVIDIA] Add Lowering for Shared-to-MMAv3-DotOp Copy (#5009) Allows for upcasting in DotOp encoding in RF. This lowering path is not currently in use; pending https://github.com/triton-lang/triton/pull/5003 --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 44 ++++--- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 21 +++- lib/Dialect/TritonGPU/IR/Dialect.cpp | 33 ++++-- .../Transforms/OptimizeDotOperands.cpp | 3 +- test/Conversion/tritongpu_to_llvm_hopper.mlir | 24 +++- test/TritonGPU/dot-operands.mlir | 8 +- test/TritonGPU/invalid-attributes.mlir | 14 ++- test/TritonGPU/loop-pipeline-hopper.mlir | 12 +- .../pipeline-hopper-remove-wait.mlir | 4 +- .../lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt | 2 +- .../ConvertLayoutOpToLLVM.cpp | 19 ++- ...v2.cpp => SharedToDotOperandMMAv2OrV3.cpp} | 108 +++++++++++++----- .../DotOpToLLVM/MMAv2.cpp | 14 ++- .../DotOpToLLVM/WGMMA.cpp | 5 + 14 files changed, 225 insertions(+), 86 deletions(-) rename third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/{SharedToDotOperandMMAv2.cpp => SharedToDotOperandMMAv2OrV3.cpp} (88%) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index c8512fce57..382bc23182 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false. return get(context, vec, perPhase, maxPhase, order, CTALayout); } - // ---- begin Ampere ---- - if (mmaEnc.isAmpere()) { + // ---- begin Ampere & Hopper ---- + if (mmaEnc.isAmpere() || mmaEnc.isHopper()) { int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); perPhase = std::max(perPhase, 1); std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; @@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false. llvm_unreachable("invalid operand index"); } - // ---- begin version 3 ---- - if (mmaEnc.isHopper()) { - llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr" - " is Hopper has not been implemented yet"); - return $_get(context, 1, 1, 1, order, CTALayout, true); - } - // ---- not implemented ---- llvm_unreachable("unsupported swizzling for provided MMA version"); }]>, @@ -1224,7 +1217,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: SmallVector getMMAv1Rep(int opIdx) const; SmallVector getMMAv1ShapePerWarp(int opIdx) const; int getMMAv1Vec(int opIdx) const; - SmallVector getMMAv2RepForOperand(ArrayRef shape, + SmallVector getMMAv2OrV3RepForOperand(ArrayRef shape, int bitwidth, int kWidth, int opIdx) const; bool supportReduction() const { @@ -1319,6 +1312,27 @@ The parent field is the layout of d. kWidth defines number of consecutive elements stored by one thread along k dimension. 
Some layouts do not use this parameter, either because they have a fixed number of elements along the K dim, or they use all elements of the tensor along the K dim. + +# WGMMA Notes +We require kWidth to be provided for Hopper because the dtype at loading might be +different from the dtype at WGMMA, due to casting. The kWidth is determined by the +dtype at WGMMA. + +The encoded tensor consists of operand A for possibly multiple wgmma instructions. +For each wgmma, each warp in a warp group feeds a single "warp matrix" +Each warp matrix consists of 2x2 "quads". +Each thread holds several elements in each quad. Right before a wgmma, +the sum of bitwidth of +the elements in each quad should add up to 32. + +These values are stored unrolled in `elements`. +The ordering of dimensions is as follows by convention: +batch (only 1 batch for Hopper currently) +matM (m-index of the "warp matrix") +matK (k-index of the "warp matrix") +quadK (k-index of the "quad" in the core matrix) +quadM (m-index of the "quad" in the core matrix) +vecIdx (index of the element in the quad; this is always along the k-dim) }]; let parameters = ( @@ -1329,16 +1343,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim ); let builders = [ - // Specially for MMAV1(Volta) AttrBuilder<(ins "unsigned":$opIdx, "Attribute":$parent, "Type":$eltTy), [{ NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); - if (!parentAttr || !parentAttr.isAmpere()) - return $_get(context, opIdx, parent, 0); + if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper())) + return $_get(context, opIdx, parent, 0); // For MMAV1 + // For MMAV2 and V3 unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); - unsigned MMAv2kWidth = 32 / bitwidth; - return $_get(context, opIdx, parent, MMAv2kWidth); + unsigned kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, kWidth); }]> ]; diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 470e8b32b5..1b7088870c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -11,6 +11,25 @@ using namespace mlir::triton::gpu; namespace mlir::triton::gpu { +namespace { + +bool isDotOpTensorAndPacked(Type srcTy) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return false; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!encoding) + return false; + auto parentEnc = dyn_cast(encoding.getParent()); + // By code convention, values for Hopper's dotOp-encoded tensors are not + // packed + if (!parentEnc || parentEnc.isHopper()) + return false; + return true; +} + +} // namespace + Type getElementType(Value value) { auto type = value.getType(); if (auto tensorType = dyn_cast(type)) @@ -33,7 +52,7 @@ SmallVector reorderValues(const SmallVector &values, Type inType, // If the parent of the dot operand is in block encoding, we don't need to // reorder elements auto parentEncoding = dyn_cast(ouEncoding.getParent()); - if (!parentEncoding) + if (!parentEncoding || parentEncoding.isHopper()) return values; size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 71506ecbb9..3b5316ecc0 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1074,13 +1074,18 @@ LogicalResult 
DotOperandEncodingAttr::verify( return emitError() << "triton_gpu.dot_op parent paramenter cannot be null"; } if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 0 && !parentAttr.isAmpere()) + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) return emitError() << "triton_gpu.dot_op kWidth parameter can only be " - "non-zero for Ampere MMA parent"; - if (kWidth == 0 && parentAttr.isAmpere()) + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) return emitError() << "triton_gpu.dot_op kWidth parameter is mandatory for " - "Ampere MMA parent"; + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) + return emitError() + << "triton_gpu.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent, since Hopper WGMMA only allows first " + "operand to be in registers"; return success(); } @@ -2013,17 +2018,20 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { return 2 * getMMAv1Rep(opIdx)[opIdx]; } -SmallVector NvidiaMmaEncodingAttr::getMMAv2RepForOperand( +SmallVector NvidiaMmaEncodingAttr::getMMAv2OrV3RepForOperand( ArrayRef shape, int bitwidth, int kWidth, int opIdx) const { + assert(isAmpere() || (isHopper() && opIdx == 0)); auto rank = shape.size(); auto warpsPerCTA = getWarpsPerCTA(); + // {batch, m, n, k} + // Hopper path never uses the n value, since this method is only invoked + // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF SmallVector shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; int numRepBatch = rank == 3 ? std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) : 1; - assert(isAmpere()); if (opIdx == 0) return {numRepBatch, @@ -2038,6 +2046,7 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv2RepForOperand( warpsPerCTA[rank - 1]))}; } } + unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto shapePerCTA = getShapePerCTA(*this, shape); @@ -2045,12 +2054,18 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( int warpsPerCTAN = getWarpsPerCTA()[1]; // H100 if (isHopper()) { - return getTotalElemsPerThread(shape, eltTy); + assert(opIdx == 0); + auto instrMNK = getInstrShape(); + int repM = ceil(shapePerCTA[0], instrMNK[0] * warpsPerCTAM); + int repK = ceil(shapePerCTA[1], instrMNK[2]); + // For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds + // kWidth elements for each quadrant. WGMMA is repeated repM * repK times. 
+ return 4 * kWidth * repM * repK; } // A100 if (isAmpere()) { - auto rep = getMMAv2RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(), - kWidth, opIdx); + auto rep = getMMAv2OrV3RepForOperand( + shapePerCTA, eltTy.getIntOrFloatBitWidth(), kWidth, opIdx); if (opIdx == 0) return 4 * rep[0] * rep[1] * rep[2]; if (opIdx == 1) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 9f3d8fff49..4695984acf 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -286,8 +286,9 @@ struct MMAV3UseRegOperand dstEnc.getVersionMajor() != 3) return failure(); auto srcTy = cast(alloc.getSrc().getType()); + auto kWidth = 32 / srcTy.getElementTypeBitWidth(); auto dotOperandEnc = DotOperandEncodingAttr::get( - dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); + dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/kWidth); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); if (!matchMmaV3AndDotOperandLayout(srcTy, newTy)) diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 113ec3cf66..65ab0194a9 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -97,9 +97,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: - tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } } @@ -114,10 +114,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Generate a wgmma where the first operand is a struct. 
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} - tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { + tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : - tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> + tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> + tt.return + } +} +// +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func @dot_reg_operand_upcast(%a_desc: !tt.memdesc<128x64xi8, #shared>, %b: !tt.memdesc<64x64xf16, #shared>, %acc: tensor<128x64xf32, #mma>) { + %a_dotop = triton_gpu.local_load %a_desc : !tt.memdesc<128x64xi8, #shared> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %res = triton_nvidia_gpu.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } } @@ -193,7 +207,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: prmt.b32 // CHECK: prmt.b32 tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) { - %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 82fc1ddf7b..2bdc443671 100644 --- 
a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -164,8 +164,8 @@ tt.func @update_kwidth_slice( #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !tt.memdesc<128x64xf16, #shared1> %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> @@ -180,8 +180,8 @@ tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdes #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A_fp8 -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1> %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> diff --git a/test/TritonGPU/invalid-attributes.mlir b/test/TritonGPU/invalid-attributes.mlir index c8b3c2ef6b..26a7c0773b 100644 --- a/test/TritonGPU/invalid-attributes.mlir +++ b/test/TritonGPU/invalid-attributes.mlir @@ -2,7 +2,7 @@ // expected-error@+2 {{triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: 2}} #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 
8], warpsPerCTA = [1, 1], order = [1, 0]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked}> +#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked, kWidth = 2}> // ----- @@ -12,19 +12,25 @@ // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere or Hopper MMA parent}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> + +// ----- + +// expected-error@+2 {{triton_gpu.dot_op opIdx parameter must be 0 for Hopper MMA parent, since Hopper WGMMA only allows first operand to be in registers}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index d391be688c..2c2182154d 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -398,8 +398,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, 
tensor<64x16x!tt.ptr, #blocked> } @@ -481,7 +481,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 @@ -519,7 +519,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -624,7 +624,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 @@ -685,7 +685,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. 
%prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index 74fd2e0555..a7064ea822 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -113,7 +113,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %115 = triton_nvidia_gpu.warp_group_dot %113, %114, %cst :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> %117 = triton_gpu.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared> - %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // The first dot gets converted to dot-async + wait. The second one // doesn't have a wait because the first wait is sufficient. 
// CHECK: triton_nvidia_gpu.warp_group_dot @@ -121,7 +121,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: triton_nvidia_gpu.warp_group_dot // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait // CHECK: scf.yield - %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> + %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %121 = arith.addf %120, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index a944da1c83..b26c73b88d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -1,6 +1,6 @@ add_triton_library(TritonNVIDIAGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp ConvertLayoutOpToLLVM.cpp DotOpToLLVM/MMAv1.cpp DotOpToLLVM/MMAv2.cpp diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 96289bbb2e..3f3a2817de 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -34,13 +34,13 @@ Value convertLayout(int opIdx, Value tensor, const SharedMemoryObject &smemObj, } // namespace SharedToDotOperandMMAv1 -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread); -} +} // namespace SharedToDotOperandMMAv2OrV3 namespace { @@ -88,11 +88,20 @@ struct LocalLoadOpConversion auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); Value res; - if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 - res = SharedToDotOperandMMAv2::convertLayout( + + if (isOuter) { + assert(false && "MMA Layout does not support outer product"); + return res; + } + + if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3 + if (mmaLayout.isHopper()) + assert(dotOperandLayout.getOpIdx() == 0); + + res = SharedToDotOperandMMAv2OrV3::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, smemObj, typeConverter, getThreadId(rewriter, loc)); - } else if (!isOuter && mmaLayout.isVolta() && isMMA) { // tensor core v1 + } else if (mmaLayout.isVolta() && isMMA) { // tensor core v1 bool isMMAv1Row = mmaLayout.getMMAv1IsRow(dotOperandLayout.getOpIdx()); auto srcSharedLayout = cast(src.getType().getEncoding()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp similarity index 88% rename from 
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp rename to third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 21c2bee584..6094a91118 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -25,6 +25,7 @@ class MMA16816SmemLoader { ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, int elemBytes, + int mmaElemBytes, bool isHopper, ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, const Location &loc); @@ -67,6 +68,8 @@ class MMA16816SmemLoader { int perPhase; int maxPhase; int elemBytes; + int mmaElemBytes; + bool isHopper; ConversionPatternRewriter &rewriter; const Location &loc; MLIRContext *ctx{}; @@ -203,10 +206,10 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value lane, Value cSwizzleOffset) { // vecWidth // <-------> // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || *t0 ... *t0 t1 ... t1 t2 ... t2 t3 ... t3 /|\ -// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 | -// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height -// ... | -// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/ +// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 | +// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height +// ... | +// t28 ... t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/ // --------------------------------------------- || -------------------------------------------- // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 // t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 @@ -364,6 +367,7 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)}; } else { // base pointers + // ptrs[k][...] holds `vec` pointers each for (quadK == k) std::array, 2> ptrs; for (int i = 0; i < vecWidth; i++) ptrs[0][i] = getPtr(ptrIdx + i); @@ -383,11 +387,13 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, i0 = add(i0, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); i1 = add(i1, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); } + // ii[m] holds the offset for (quadM == m) std::array ii = {i0, i1}; // load 4 32-bit values from shared memory // (equivalent to ldmatrix.x4) SmallVector> vptrs(4, SmallVector(vecWidth)); + // i iterates the 2x2 quads, m-first for (int i = 0; i < 4; ++i) for (int j = 0; j < vecWidth; ++j) { vptrs[i][j] = gep(ptr_ty(ctx, 3), shemTy, ptrs[i / 2][j], ii[i % 2]); @@ -402,7 +408,9 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, int canonWidth = (8 * elemBytes * inc) / canonBits; Type canonInt = int_ty(canonBits); std::array retElems; - retElems.fill(undef(vec_ty(canonInt, 32 / canonBits))); + // don't pack to 32b for Hopper + int vecSize = isHopper ? 
1 : 32 / canonBits; + retElems.fill(undef(vec_ty(canonInt, vecSize))); for (int r = 0; r < 2; ++r) { for (int em = 0; em < 2 * vecWidth; em += inc) { int e = em % vecWidth; @@ -421,8 +429,11 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, } if (isActualTrans) std::swap(retElems[1], retElems[2]); - return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty), - bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)}; + + auto iTy = isHopper ? int_ty(8 * elemBytes * inc) : i32_ty; + + return {bitcast(retElems[0], iTy), bitcast(retElems[1], iTy), + bitcast(retElems[2], iTy), bitcast(retElems[3], iTy)}; } } @@ -432,8 +443,9 @@ MMA16816SmemLoader::MMA16816SmemLoader( ArrayRef smemStrides, ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, - int elemBytes, ConversionPatternRewriter &rewriter, - const LLVMTypeConverter *typeConverter, const Location &loc) + int elemBytes, int mmaElemBytes, bool isHopper, + ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, + const Location &loc) : nPerWarp(nPerWarp), order(order.begin(), order.end()), warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder), kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()), @@ -441,17 +453,29 @@ MMA16816SmemLoader::MMA16816SmemLoader( matShape(matShape.begin(), matShape.end()), multiDimWarpId(multiDimWarpId.begin(), multiDimWarpId.end()), perPhase(perPhase), maxPhase(maxPhase), elemBytes(elemBytes), - rewriter(rewriter), loc(loc), ctx(rewriter.getContext()) { + mmaElemBytes(mmaElemBytes), isHopper(isHopper), rewriter(rewriter), + loc(loc), ctx(rewriter.getContext()) { + // If the current elemType width is different from the MMA elemType width, + // i.e. width-changing casting is done later in DotOp Layout... then, in the + // case of Hopper, the number of bytes held by each thread after loading will + // no longer be 32B. Hence this flag is required to stipulate different logic. + bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes); + contiguousMatShape = matShape[order[0]]; stridedMatShape = matShape[order[1]]; stridedSmemOffset = smemStrides[order[1]]; smemBatchOffset = smemStrides[order[2]]; - vecWidth = 4 / elemBytes; + if (isHopperWidthChange) { + vecWidth = 4 / mmaElemBytes; + } else { + vecWidth = 4 / elemBytes; + } // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; nonKOrder = (kOrder == 2) ? 
1 : 2; canUseLdmatrix = elemBytes == 2 || (!needTrans); canUseLdmatrix = canUseLdmatrix && (kWidth == vecWidth); + canUseLdmatrix = canUseLdmatrix && !isHopperWidthChange; if (canUseLdmatrix) { // Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed, @@ -504,10 +528,28 @@ Type getSharedMemTy(Type argType) { llvm::report_fatal_error("mma16816 data type not supported"); } +std::vector unpackInt(const std::vector &inValues, Type elTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + const int inBitWidth = inValues[0].getType().getIntOrFloatBitWidth(); + std::vector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(elTy); + auto vecType = + vec_ty(eltType, inBitWidth / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < inBitWidth / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + Value composeValuesToDotOperandLayoutStruct( const ValueTable &vals, int batch, int n0, int n1, const LLVMTypeConverter *typeConverter, Location loc, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter, Type elTy, bool isHopper) { std::vector elems; for (int b = 0; b < batch; ++b) for (int m = 0; m < n0; ++m) @@ -519,6 +561,10 @@ Value composeValuesToDotOperandLayoutStruct( } assert(!elems.empty()); + if (isHopper) { + elems = unpackInt(elems, elTy, rewriter, loc, typeConverter); + } + Type elemTy = elems[0].getType(); MLIRContext *ctx = elemTy.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( @@ -544,18 +590,20 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, const int maxPhase = sharedLayout.getMaxPhase(); const int vecPhase = sharedLayout.getVec(); const int elemBytes = descTy.getElementTypeBitWidth() / 8; + const int mmaElemBytes = 4 / kWidth; + const bool isHopper = mmaLayout.getVersionMajor() == 3; auto order = sharedLayout.getOrder(); int nPerWarp = std::max(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8); - // (a, b) is the coordinate. auto load = [=, &rewriter, &vals](int batch, int a, int b) { - MMA16816SmemLoader loader( - nPerWarp, warpsPerTile, sharedLayout.getOrder(), - mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides, - shapePerCTA /*tileShape*/, instrShape, matShape, multiDimWarpId, - perPhase, maxPhase, elemBytes, rewriter, typeConverter, loc); + MMA16816SmemLoader loader(nPerWarp, warpsPerTile, sharedLayout.getOrder(), + mmaLayout.getWarpsPerCTA(), kOrder, kWidth, + smemObj.strides, shapePerCTA /*tileShape*/, + instrShape, matShape, multiDimWarpId, perPhase, + maxPhase, elemBytes, mmaElemBytes, isHopper, + rewriter, typeConverter, loc); // Offset of a slice within the original tensor in shared memory Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs = loader.computeOffsets(lane, cSwizzleOffset); @@ -573,6 +621,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, auto [ha0, ha1, ha2, ha3] = loader.loadX4( batch, (kOrder == 2) ? a : b /*mat0*/, (kOrder == 2) ? 
b : a /*mat1*/, ptrs, matTy, getSharedMemTy(eltTy)); + if (!isA) std::swap(ha1, ha2); // the following is incorrect @@ -595,17 +644,22 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, MemDescType descTy, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread, bool isA) { + auto mmaLayout = mlir::cast(encoding.getParent()); + bool isHopper = mmaLayout.getVersionMajor() == 3; auto shapePerCTA = getShapePerCTA(descTy); int bitwidth = descTy.getElementTypeBitWidth(); - auto mmaLayout = mlir::cast(encoding.getParent()); + // For Hopper WGMMA, the sum of bitwidth of the elements in each quad should + // add up to 32. We use kWidth to compute the element bitwidth of the input to + // WGMMA, which could be different from `bitwidth` due to later casting. + int mmaBitwidth = isHopper ? (32 / encoding.getKWidth()) : bitwidth; ValueTable vals; - int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; - int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; + int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / mmaBitwidth; + int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth; int kWidth = encoding.getKWidth(); - auto numRep = mmaLayout.getMMAv2RepForOperand(shapePerCTA, bitwidth, kWidth, - encoding.getOpIdx()); + auto numRep = mmaLayout.getMMAv2OrV3RepForOperand( + shapePerCTA, bitwidth, kWidth, encoding.getOpIdx()); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); auto order = triton::gpu::getOrder(mmaLayout); @@ -616,7 +670,6 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, delinearize(rewriter, loc, warp, warpsPerCTA, order); Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0])); int warpsPerTile; - auto rank = shapePerCTA.size(); Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16)); Value warpN = urem(multiDimWarpId[2], i32_val(shapePerCTA[2] / 8)); if (isA) @@ -652,7 +705,8 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, // Format the values to LLVM::Struct to passing to mma codegen. 
return composeValuesToDotOperandLayoutStruct( - vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter); + vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter, + descTy.getElementType(), /*unpack=*/isHopper); } template @@ -764,7 +818,7 @@ getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, return expandedSmemObj; } -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, @@ -785,4 +839,4 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, expandedSmemObj, typeConverter, thread, false); } } -} // namespace SharedToDotOperandMMAv2 +} // namespace SharedToDotOperandMMAv2OrV3 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index c2940a0438..b03fb0989d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -393,13 +393,15 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); auto dotOpA = cast(aTensorTy.getEncoding()); - auto repA = cast(dotOpA.getParent()) - .getMMAv2RepForOperand(aShapePerCTA, bitwidth, - dotOpA.getKWidth(), dotOpA.getOpIdx()); + auto repA = + cast(dotOpA.getParent()) + .getMMAv2OrV3RepForOperand(aShapePerCTA, bitwidth, dotOpA.getKWidth(), + dotOpA.getOpIdx()); auto dotOpB = cast(bTensorTy.getEncoding()); - auto repB = cast(dotOpB.getParent()) - .getMMAv2RepForOperand(bShapePerCTA, bitwidth, - dotOpB.getKWidth(), dotOpB.getOpIdx()); + auto repB = + cast(dotOpB.getParent()) + .getMMAv2OrV3RepForOperand(bShapePerCTA, bitwidth, dotOpB.getKWidth(), + dotOpB.getOpIdx()); assert(repA[2] == repB[1]); assert(repA[0] == repB[0]); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 1bb55373e0..2b9b4f159b 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -442,6 +442,11 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, if (aSharedLayout) { a = aLoader.smemLoad(m, k, rewriter, loc); } else { + auto aDotOpEnc = + cast(aTensorTy.getEncoding()); + assert(aDotOpEnc.getKWidth() == + 32 / aTensorTy.getElementTypeBitWidth()); + unsigned regASize = (instrShape[0] * instrShape[2]) / 32; llvm::SmallVector regA = loadReg(rewriter, loc, structA, (m * numRepK + k) * regASize, From 0591b3756bd4143b7163235c0eca4d718948e982 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Wed, 30 Oct 2024 14:56:52 +0100 Subject: [PATCH 07/13] Don't specify `-A x64` option and reuse cmake build type config for Windows (#5014) The `-A` argument is not compatible with the Ninja generator. 
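As a minimal illustration (the source path and exact error wording below are hypothetical), single-config generators such as Ninja take the configuration from `CMAKE_BUILD_TYPE` and reject the Visual Studio-style platform flag, which is why the flag is dropped here:

```python
import subprocess

cfg = "Release"
# Ninja is a single-config generator: the configuration is selected via
# CMAKE_BUILD_TYPE rather than a generator platform.
cmake_args = ["cmake", "-G", "Ninja", f"-DCMAKE_BUILD_TYPE={cfg}", "path/to/source"]
subprocess.run(cmake_args, check=True)

# Appending ["-A", "x64"] to cmake_args instead makes CMake abort with an error
# along the lines of "Generator Ninja does not support platform specification".
```

With `-DCMAKE_BUILD_TYPE` now passed unconditionally, the Windows branch only needs the per-config runtime output directory.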
Signed-off-by: Anatoly Myachev --- python/setup.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/setup.py b/python/setup.py index 714668462f..f73539de28 100644 --- a/python/setup.py +++ b/python/setup.py @@ -429,12 +429,10 @@ def build_extension(self, ext): cfg = get_build_type() build_args = ["--config", cfg] + cmake_args += [f"-DCMAKE_BUILD_TYPE={cfg}"] if platform.system() == "Windows": cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] - if sys.maxsize > 2**32: - cmake_args += ["-A", "x64"] else: - cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count())) build_args += ['-j' + max_jobs] From 23c9ec169ce99e6476e8f3ac25656472caa762f8 Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:54:06 +0000 Subject: [PATCH 08/13] [Frontend][Backend] Implement support for scale_dot(-, bf16) (#4996) In the passing we also improve a few other things: - Now `scaled_dot` accepts both uint8/uint16 fp8/bf16 as inputs (before you had to cast it to uint8, which was weird when extending it to bf16). - Add `scaled_dot` to the docs and improve the docs overall (have not render them, might need a few further tweaks) --- docs/python-api/triton.language.rst | 1 + .../Dialect/Triton/IR/TritonAttrDefs.td | 9 +-- include/triton/Dialect/Triton/IR/TritonOps.td | 16 ++--- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 2 +- lib/Dialect/TritonGPU/IR/Ops.cpp | 8 +-- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 61 ++++++++----------- python/src/ir.cc | 19 +++--- python/test/unit/language/test_core.py | 31 +++++----- python/triton/language/core.py | 14 +++-- python/triton/language/semantic.py | 49 +++++++++------ test/TritonGPU/accelerate-matmul.mlir | 8 +-- .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp | 2 +- .../AccelerateAMDMatmul.cpp | 20 +++--- .../UpcastMXFPToLLVM.cpp | 2 +- 14 files changed, 128 insertions(+), 114 deletions(-) diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index ecd0fb3b94..415091a100 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -59,6 +59,7 @@ Linear Algebra Ops :nosignatures: dot + dot_scaled Memory/Pointer Ops diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index f3159338bd..04e4c25fd6 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } -// Type for F8F6F4 kind of floats. -def TT_F8F6F4TypeAttr : I32EnumAttr< - "F8F6F4Type", "", +// Type for ScaleDotElemType kind of floats. 
+def TT_ScaleDotElemTypeAttr : I32EnumAttr< + "ScaleDotElemType", "", [ I32EnumAttrCase<"E4M3", 0, "e4m3">, I32EnumAttrCase<"E5M2", 1, "e5m2">, I32EnumAttrCase<"E2M3", 2, "e2m3">, I32EnumAttrCase<"E3M2", 3, "e3m2">, - I32EnumAttrCase<"E2M1", 4, "e2m1"> + I32EnumAttrCase<"E2M1", 4, "e2m1">, + I32EnumAttrCase<"BF16", 5, "bf16"> ]>{ let cppNamespace = "::mlir::triton"; diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index d3bb95ca95..2c3a1bf714 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -685,15 +685,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, let arguments = ( ins - // inputs are integer types as they are packed types and we currently - // don't have a representation for those. - TT_IntTensor:$lhs, - TT_IntTensor:$rhs, + // inputs are floats if we have a type for them, otherwise (fp4), + // they are packed in pairs in an I8Tensor + RankedTensorOf<[TT_Float,I8]>:$lhs, + RankedTensorOf<[TT_Float,I8]>:$rhs, TT_FloatTensor:$c, - TT_IntTensor:$lhs_scale, - Optional:$rhs_scale, - TT_F8F6F4TypeAttr:$lhs_type, - TT_F8F6F4TypeAttr:$rhs_type + RankedTensorOf<[I8]>:$lhs_scale, + Optional>:$rhs_scale, + TT_ScaleDotElemTypeAttr:$lhs_type, + TT_ScaleDotElemTypeAttr:$rhs_type ); let results = (outs TT_FloatTensor:$d); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index a290cb2031..6299ee6ed4 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -268,7 +268,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods(encoding); auto newVEncoding = DotOperandEncodingAttr::get( ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(), diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index a2d4012bf2..3ddab364d7 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -415,22 +415,12 @@ class ScaledBlockedToMMAv2 auto aType = dotOp.getLhsType(); auto bType = dotOp.getRhsType(); - auto enumToType = [&rewriter](F8F6F4Type type) { - switch (type) { - case F8F6F4Type::E4M3: - return rewriter.getFloat8E4M3FNType(); - case F8F6F4Type::E5M2: - return rewriter.getFloat8E5M2Type(); - default: - llvm_unreachable("unexpected type"); - } - }; - - assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 || - aType == F8F6F4Type::E2M1) && + assert((aType == ScaleDotElemType::E4M3 || + aType == ScaleDotElemType::E5M2 || + aType == ScaleDotElemType::E2M1) && "NYI: lhs supports fp4 or fp8"); - assert(bType == F8F6F4Type::E4M3 || - bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8"); + assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 || + bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16"); // TODO run accelerate matmul on A and B first to choose their layouts // Set return type @@ -454,11 +444,12 @@ class ScaledBlockedToMMAv2 auto newAcc = rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); - auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType]( - TypedValue v, int idx, - F8F6F4Type type) -> TypedValue { + auto toMMABf16 = + [&newRetType, &rewriter, + &ctx](TypedValue v, int idx, + ScaleDotElemType type) -> TypedValue { auto vType = v.getType(); - if (type == F8F6F4Type::E2M1) { + if (type == ScaleDotElemType::E2M1) { // A bit 
too dynamically typed... // perhaps return ints in both cases? @@ -469,23 +460,23 @@ class ScaledBlockedToMMAv2 vType.getShape(), vType.getElementType(), newVEncoding); return rewriter.create(v.getLoc(), newVType, v); } else { - assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3); + assert(type == ScaleDotElemType::E5M2 || + type == ScaleDotElemType::E4M3 || + type == ScaleDotElemType::BF16); auto newVEncoding = DotOperandEncodingAttr::get( ctx, idx, newRetType.getEncoding(), /*kWidth=*/8); auto newVType = RankedTensorType::get( vType.getShape(), vType.getElementType(), newVEncoding); v = rewriter.create(v.getLoc(), newVType, v); - // Bitcast - auto vTypeFp8 = RankedTensorType::get(vType.getShape(), - enumToType(type), newVEncoding); - v = cast>( - rewriter.create(v.getLoc(), vTypeFp8, v).getResult()); - - // Convert to bf16 - auto vTypeBf16 = RankedTensorType::get( - vType.getShape(), rewriter.getBF16Type(), newVEncoding); - return rewriter.create(v.getLoc(), vTypeBf16, v); + if (type == ScaleDotElemType::BF16) { + return v; + } else { + // Convert to bf16 + auto vTypeBf16 = RankedTensorType::get( + vType.getShape(), rewriter.getBF16Type(), newVEncoding); + return rewriter.create(v.getLoc(), vTypeBf16, v); + } } }; a = toMMABf16(a, 0, aType); @@ -515,11 +506,11 @@ class ScaledBlockedToMMAv2 auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, CTALayout); - auto newScaleType = RankedTensorType::get(scale.getType().getShape(), - scale.getType().getElementType(), - newScaleEncoding); - scale = - rewriter.create(scale.getLoc(), newScaleType, scale); + auto newScaleDotElemType = RankedTensorType::get( + scale.getType().getShape(), scale.getType().getElementType(), + newScaleEncoding); + scale = rewriter.create(scale.getLoc(), + newScaleDotElemType, scale); auto scaledA = rewriter.create( dotOp.getLoc(), a, scale, dotOp.getLhsType()); diff --git a/python/src/ir.cc b/python/src/ir.cc index 9945c61882..cce7c87e8d 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -205,12 +205,13 @@ void init_triton_ir(py::module &&m) { .value("IEEE", InputPrecision::IEEE) .export_values(); - py::enum_(m, "F8F6F4TY", py::module_local()) - .value("E4M3", F8F6F4Type::E4M3) - .value("E5M2", F8F6F4Type::E5M2) - .value("E2M3", F8F6F4Type::E2M3) - .value("E3M2", F8F6F4Type::E3M2) - .value("E2M1", F8F6F4Type::E2M1) + py::enum_(m, "ScaleDotElemTypeTY", py::module_local()) + .value("E4M3", ScaleDotElemType::E4M3) + .value("E5M2", ScaleDotElemType::E5M2) + .value("E2M3", ScaleDotElemType::E2M3) + .value("E3M2", ScaleDotElemType::E3M2) + .value("E2M1", ScaleDotElemType::E2M1) + .value("BF16", ScaleDotElemType::BF16) .export_values(); py::class_(m, "context", py::module_local()) @@ -1423,9 +1424,9 @@ void init_triton_ir(py::module &&m) { }) .def("create_dot_scaled", [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale, - F8F6F4Type lhs_format, mlir::Value &rhs, - std::optional &rhs_scale, F8F6F4Type rhs_format, - mlir::Value &c) -> mlir::Value { + ScaleDotElemType lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, + ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value { return self.create( c.getType(), lhs, rhs, c, lhs_scale, rhs_scale.value_or(Value()), lhs_format, rhs_format); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ced1a5352a..7a8debe0dd 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3330,7 +3330,7 @@ def 
kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) for col_a, col_b in itertools.product([True, False], repeat=2) for type_a in ["e2m1", "e4m3", "e5m2"] - for type_b in ["e4m3", "e5m2"] + for type_b in ["e4m3", "e5m2", "bf16"] for mma in ([32, 16] if is_hip() else [16]) for kpack in ([1, 2] if is_hip() else [1])]) def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device): @@ -3351,7 +3351,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, type_b: tl.constexpr): - tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8") + tl.static_assert((type_b == "e4m3" or type_b == "e5m2") or type_b == "bf16", "type_b must be fp8 or bf16") IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2" DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2 PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR @@ -3442,7 +3442,7 @@ def mxfp_to_bf16_kernel( def dot_scale_ref(x, scale, y, type_x, type_y): e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] - type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] + type_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y] comp_dtype = torch.bfloat16 @@ -3455,7 +3455,7 @@ def dot_scale_ref(x, scale, y, type_x, type_y): mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps) assert x_upcast.isfinite().all() - y_upcast = y.view(type_fp8_y).to(comp_dtype) + y_upcast = y.view(type_y).to(comp_dtype) class AccumulateInFp32: @@ -3467,28 +3467,30 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value with AccumulateInFp32(): - return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype)) + return torch.matmul(x_upcast, y_upcast) torch.manual_seed(0) - def create_uint8(shape, col_major=False, max_val=255): + def make_arg(shape, ty, col_major=False, max_val=255): if col_major: shape = shape[:-2] + (shape[-1], shape[-2]) - ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) + if ty == "bf16": + ret = torch.randn(shape, dtype=torch.bfloat16, device=device) + # Clamp to avoid relative error issues + ret.clamp_(-2**15, 2**15 - 1) + else: + ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) if col_major: ret = ret.mT return ret DIV_FACTOR = 2 if type_a == "e2m1" else 1 - x = create_uint8((M, K // DIV_FACTOR), col_major=col_a) - y = create_uint8((K, N), col_major=col_b) + x = make_arg((M, K // DIV_FACTOR), type_a, col_major=col_a) + y = make_arg((K, N), type_b, col_major=col_b) # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright) - # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow - m_bytes = int(type_a[1]) - bias_type_a = 1 << (m_bytes - 1) - 1 - max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a - scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64) + # Max scale= 2**15 + scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15) def make_finite(x, dtype): # e5m2 has too many non-finite values when sampled uniformly (1 / 
32) and @@ -3513,7 +3515,6 @@ def make_finite(x, dtype): z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b) - # generous rtol as we are sampling the whole range of floats torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2) # make sure ld/st are vectorized diff --git a/python/triton/language/core.py b/python/triton/language/core.py index c9f22e3991..a95b65a306 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1556,15 +1556,17 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, lhs and rhs use microscaling formats described here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf :param lhs: The first tensor to be multiplied. - :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :type lhs: 2D tensor representing fp4 or fp8 elements packed into uint8 for fp4 inputs, or in uint8 or the corresponding fp8 type for fp8 inputs. :param lhs_scale: Scale factor for lhs tensor. - :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :type lhs_scale: e8m0 type represented as an uint8 tensor. + :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code: `e5m2`}. + :type lhs_format: str :param rhs: The second tensor to be multiplied. - :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :type rhs: 2D tensor representing fp8 or bf16 elements in uint8 or the corresponding fp8 type for fp8 inputs or bf16 for bf16 inputs. :param rhs_scale: Scale factor for rhs tensor. - :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :type rhs_scale: e8m0 type represented as an uint8 tensor. + :param rhs_format: format of the rhs tensor. Available formats: {:code:`e4m3`, :code: `e5m2`, :code:`bf16`}. + :type rhs_format: str :param acc: The accumulator tensor. If not None, the result is added to this tensor. 
""" out_dtype = _constexpr_to_value(out_dtype) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index be157c5b46..a9af8c8d80 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1527,33 +1527,48 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona ret_ty) -def _str_to_fp_type(float_format: Optional[str]): - if float_format == 'e4m3': - return ir.F8F6F4TY.E4M3 - if float_format == 'e5m2': - return ir.F8F6F4TY.E5M2 - if float_format == 'e2m3': - return ir.F8F6F4TY.E2M3 - if float_format == 'e3m2': - return ir.F8F6F4TY.E3M2 - if float_format == 'e2m1': - return ir.F8F6F4TY.E2M1 - raise ValueError(f"Invalid float format: {float_format}.") - - -def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], - rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: +def _str_to_fp_type(float_format: str): + ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None) + if ty_enum is None: + raise ValueError(f"Invalid float format: {float_format}.") + return ty_enum + + +def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): + """ + If float_format is subbyte, make sure it's packed as uint8 and return it. + Otherwise, return a tensor (perhaps bitcasting) of the specified float format. + """ + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" + return bitcast(val, triton_ty, builder) + + +def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], + rhs_format: str, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() #TODO: validate types. 
lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}" - assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}" + assert rhs_format in ("e4m3", "e5m2", "bf16"), f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None assert rhs_scale_is_none, "NYI: rhs_scale not supported" + lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) + rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) M = lhs.type.shape[-2] K, N = rhs.type.shape[-2:] diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 85b37f3ed3..420a9d5c2c 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -164,21 +164,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -// Verify that dot_scaled (mxfp8 x fp8) decomposes as expected +// Verify that dot_scaled (mxfp4 x bf16) decomposes as expected #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: dot_scaled tt.func @dot_scaled( - %a: tensor<128x64xi8, #blocked2>, + %a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, - %b: tensor<64x128xi8, #blocked>) + %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> { // CHECK: triton_gpu.upcast_mxfp // CHECK: tt.dot %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e4m3 rhs = e4m3 : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xi8, #blocked> -> tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index 289ceb61a5..b35e28272c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -43,7 +43,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto fpType = op.getFpType(); - if (!(fpType == F8F6F4Type::E4M3 || fpType == F8F6F4Type::E5M2)) + if (!(fpType == ScaleDotElemType::E4M3 || fpType == ScaleDotElemType::E5M2)) return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases"); Location loc = op.getLoc(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp 
b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 3aa009c363..201a7b0212 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -504,12 +504,14 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { TensorValue a = dotOp.getLhs(); TensorValue b = dotOp.getRhs(); TensorValue aScale = dotOp.getLhsScale(); - F8F6F4Type aElemType = dotOp.getLhsType(); - F8F6F4Type bElemType = dotOp.getRhsType(); + ScaleDotElemType aElemType = dotOp.getLhsType(); + ScaleDotElemType bElemType = dotOp.getRhsType(); - if (!(aElemType == F8F6F4Type::E4M3 || aElemType == F8F6F4Type::E5M2)) + if (!(aElemType == ScaleDotElemType::E4M3 || + aElemType == ScaleDotElemType::E5M2)) return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8 LHS"); - if (!(bElemType == F8F6F4Type::E4M3 || bElemType == F8F6F4Type::E5M2)) + if (!(bElemType == ScaleDotElemType::E4M3 || + bElemType == ScaleDotElemType::E5M2)) return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8 RHS"); MLIRContext *ctx = dotOp.getContext(); @@ -553,11 +555,11 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { // OCP mxfp8 requires implementations to follow OCP fp8 elements. We are // doing software emulation using bf16 here, so we map to OCP fp8 f8E4M3FN // and f8E5M2. - auto enumToType = [&rewriter](F8F6F4Type type) { + auto enumToType = [&rewriter](ScaleDotElemType type) { switch (type) { - case F8F6F4Type::E4M3: + case ScaleDotElemType::E4M3: return rewriter.getFloat8E4M3FNType(); - case F8F6F4Type::E5M2: + case ScaleDotElemType::E5M2: return rewriter.getFloat8E5M2Type(); default: llvm_unreachable("unexpected fp type"); @@ -565,8 +567,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { }; auto toMMABf16 = [&](TensorValue v, int idx, - F8F6F4Type type) -> TensorValue { - assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3); + ScaleDotElemType type) -> TensorValue { + assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3); auto vType = v.getType(); auto newVEnc = DotOperandEncodingAttr::get( ctx, idx, newRetType.getEncoding(), kWdith); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 722bf56cd0..136b696132 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -103,7 +103,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value warpId = udiv(tid, warpSize); Value laneId = urem(tid, warpSize); - if (fpType == F8F6F4Type::E2M1) { + if (fpType == ScaleDotElemType::E2M1) { xVals = unpackFP4Elements(loc, rewriter, xVals, laneId); } From 61eb94e5cfca1f15f410694bb0034f3db7ebfb9d Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Wed, 30 Oct 2024 18:46:41 +0100 Subject: [PATCH 09/13] [INTERPRETER] Make sure interpreter works with float16 by reusing NumPy HALF-related code (#5010) Closes #4992 --- python/src/interpreter.cc | 221 ++++++++++++++++++++++++- python/test/unit/language/test_core.py | 2 - 2 files changed, 220 insertions(+), 3 deletions(-) diff --git a/python/src/interpreter.cc b/python/src/interpreter.cc index 1dc2dac8d8..747a0cc171 100644 --- a/python/src/interpreter.cc +++ b/python/src/interpreter.cc @@ -5,12 +5,17 @@ #include #include #include +#include #include namespace py = pybind11; namespace { +struct npy_half { + uint16_t value; +}; + enum class 
MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; std::mutex atomic_op_guard; @@ -83,6 +88,211 @@ template T atomic_fadd(T *loc, T value, std::memory_order order) { return old_value; } +/** Create a value of type `To` from the bits of `from`. + * + * similar to `std::bit_cast` but compatible with C++17, + * should perform similar to `*reinterpret_cast(&from)` + * or through punning without expecting any undefined behaviors. + * + * Note: taken from + * https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/utils.hpp#L32 + * with simplification. + */ +template +inline To BitCast(const From &from) noexcept { + static_assert(sizeof(To) == sizeof(From), + "both data types must have the same size"); + + static_assert(std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + "both data types must be trivially copyable"); + + To to; + memcpy(&to, &from, sizeof(from)); + return to; +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L14 +template +inline uint16_t FromFloatBits(uint32_t f) { + uint32_t f_exp, f_sig; + uint16_t h_sgn, h_exp, h_sig; + + h_sgn = (uint16_t)((f & 0x80000000u) >> 16); + f_exp = (f & 0x7f800000u); + + /* Exponent overflow/NaN converts to signed inf/NaN */ + if (f_exp >= 0x47800000u) { + if (f_exp == 0x7f800000u) { + /* Inf or NaN */ + f_sig = (f & 0x007fffffu); + if (f_sig != 0) { + /* NaN - propagate the flag in the significand... */ + uint16_t ret = (uint16_t)(0x7c00u + (f_sig >> 13)); + /* ...but make sure it stays a NaN */ + if (ret == 0x7c00u) { + ret++; + } + return h_sgn + ret; + } else { + /* signed inf */ + return (uint16_t)(h_sgn + 0x7c00u); + } + } else { + if constexpr (gen_overflow) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error("overflow to signed inf"); + } + return (uint16_t)(h_sgn + 0x7c00u); + } + } + + /* Exponent underflow converts to a subnormal half or signed zero */ + if (f_exp <= 0x38000000u) { + /* + * Signed zeros, subnormal floats, and floats with small + * exponents all convert to signed zero half-floats. + */ + if (f_exp < 0x33000000u) { + if constexpr (gen_underflow) { + /* If f != 0, it underflowed to 0 */ + if ((f & 0x7fffffff) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } + } + return h_sgn; + } + /* Make the subnormal significand */ + f_exp >>= 23; + f_sig = (0x00800000u + (f & 0x007fffffu)); + if constexpr (gen_underflow) { + /* If it's not exactly represented, it underflowed */ + if ((f_sig & (((uint32_t)1 << (126 - f_exp)) - 1)) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } + } + /* + * Usually the significand is shifted by 13. For subnormals an + * additional shift needs to occur. This shift is one for the largest + * exponent giving a subnormal `f_exp = 0x38000000 >> 23 = 112`, which + * offsets the new first bit. At most the shift can be 1+10 bits. + */ + f_sig >>= (113 - f_exp); + /* Handle rounding by adding 1 to the bit beyond half precision */ + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. However, the (113 - f_exp) + * shift can lose up to 11 bits, so the || checks them in the original. + * In all other cases, we can just add one. 
+ */ + if (((f_sig & 0x00003fffu) != 0x00001000u) || (f & 0x000007ffu)) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp from zero to one and h_sig will be zero. + * This is the correct result. + */ + return (uint16_t)(h_sgn + h_sig); + } + + /* Regular case with no overflow or underflow */ + h_exp = (uint16_t)((f_exp - 0x38000000u) >> 13); + /* Handle rounding by adding 1 to the bit beyond half precision */ + f_sig = (f & 0x007fffffu); + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. In all other cases, we do. + */ + if ((f_sig & 0x00003fffu) != 0x00001000u) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp by one and h_sig will be zero. This is the + * correct result. h_exp may increment to 15, at greatest, in + * which case the result overflows to a signed inf. + */ + if constexpr (gen_overflow) { + h_sig += h_exp; + if (h_sig == 0x7c00u) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error(""); + } + return h_sgn + h_sig; + } else { + return h_sgn + h_exp + h_sig; + } +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L269 +constexpr uint32_t ToFloatBits(uint16_t h) { + uint16_t h_exp = (h & 0x7c00u); + uint32_t f_sgn = ((uint32_t)h & 0x8000u) << 16; + switch (h_exp) { + case 0x0000u: { // 0 or subnormal + uint16_t h_sig = (h & 0x03ffu); + // Signed zero + if (h_sig == 0) { + return f_sgn; + } + // Subnormal + h_sig <<= 1; + while ((h_sig & 0x0400u) == 0) { + h_sig <<= 1; + h_exp++; + } + uint32_t f_exp = ((uint32_t)(127 - 15 - h_exp)) << 23; + uint32_t f_sig = ((uint32_t)(h_sig & 0x03ffu)) << 13; + return f_sgn + f_exp + f_sig; + } + case 0x7c00u: // inf or NaN + // All-ones exponent and a copy of the significand + return f_sgn + 0x7f800000u + (((uint32_t)(h & 0x03ffu)) << 13); + default: // normalized + // Just need to adjust the exponent and shift + return f_sgn + (((uint32_t)(h & 0x7fffu) + 0x1c000u) << 13); + } +} + +npy_half npy_float_to_half(float f) { + return {FromFloatBits(BitCast(f))}; +} + +float npy_half_to_float(npy_half h) { + return BitCast(ToFloatBits(h.value)); +} + +template <> +npy_half atomic_fadd(npy_half *loc, npy_half value, + std::memory_order order) { + npy_half old_value; + + const std::lock_guard lock(atomic_op_guard); + old_value = *loc; + *loc = npy_float_to_half(npy_half_to_float(old_value) + + npy_half_to_float(value)); + + return old_value; +} + class AtomicOp { public: AtomicOp(const uint64_t *ptr, size_t numel, std::memory_order order) @@ -370,6 +580,15 @@ template struct OpCreator { } }; +template <> template <> void OpCreator::create() { + if (!atomic_op && dtype.char_() == 'e') { // float16 + // workaround until https://github.com/pybind/pybind11/issues/4061 is + // implemented + atomic_op = std::make_unique>( + ptr, val, ret, mask, numel, order); + } +}; + template std::unique_ptr makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, @@ -476,7 +695,7 @@ void init_triton_interpreter(py::module &&m) { switch (rmw_op) { MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t) - 
MAKE_ATOMIC_RMW_OP(RMWOp::FADD, float, double) + MAKE_ATOMIC_RMW_OP(RMWOp::FADD, npy_half, float, double) MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t) MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t) MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 7a8debe0dd..65199ea8c2 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1458,8 +1458,6 @@ def kernel(X): for num_ctas in num_ctas_list for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']]) def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device): - if is_interpreter() and dtype_x_str == 'float16': - pytest.skip('float16 atomic_add does not work in the interpreter mode') shape0, shape1 = shape # triton kernel From 018c139d2b843c29c3b4d4d2e6f3e672b8ff0b3a Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Wed, 30 Oct 2024 19:13:39 +0100 Subject: [PATCH 10/13] Allow windows cuda files to be used in `setup.py` (#5015) Example: ```python # On Windows >>> sysconfig.get_config_var("EXE") '.exe' # On Linux >>> sysconfig.get_config_var("EXE") '' ``` --------- Signed-off-by: Anatoly Myachev --- python/setup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/setup.py b/python/setup.py index f73539de28..69d6830ae1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -497,8 +497,9 @@ def get_platform_dependent_src_path(subdir): if int(version_major) >= 12 and int(version_minor1) >= 5 else subdir)(*version.split('.'))) +exe_extension = sysconfig.get_config_var("EXE") download_and_copy( - name="ptxas", src_path="bin/ptxas", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH", + name="ptxas", src_path=f"bin/ptxas{exe_extension}", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH", version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version: ((lambda version_major, version_minor1, version_minor2: f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/{system}-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2" @@ -507,7 +508,7 @@ def get_platform_dependent_src_path(subdir): (*version.split('.')))) download_and_copy( name="cuobjdump", - src_path="bin/cuobjdump", + src_path=f"bin/cuobjdump{exe_extension}", dst_path="bin/cuobjdump", variable="TRITON_CUOBJDUMP_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cuobjdump"], @@ -516,7 +517,7 @@ def get_platform_dependent_src_path(subdir): ) download_and_copy( name="nvdisasm", - src_path="bin/nvdisasm", + src_path=f"bin/nvdisasm{exe_extension}", dst_path="bin/nvdisasm", variable="TRITON_NVDISASM_PATH", version=NVIDIA_TOOLCHAIN_VERSION["nvdisasm"], From ef319c8fd086ab90b63de30df7896f0a6e9a11b7 Mon Sep 17 00:00:00 2001 From: Saagar Jha Date: Wed, 30 Oct 2024 16:42:21 -0700 Subject: [PATCH 11/13] Fix formatting in docs for triton.language.dot (#5020) --- python/triton/language/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index a95b65a306..3a7544c3c4 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1523,9 +1523,9 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i where the first dimension of each block represents the batch dimension. :param input: The first tensor to be multiplied. 
- :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} :param other: The second tensor to be multiplied. - :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} :param acc: The accumulator tensor. If not None, the result is added to this tensor. :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`} :param input_precision: How to exercise the Tensor Cores for f32 x f32. If @@ -1559,13 +1559,13 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, :type lhs: 2D tensor representing fp4 or fp8 elements packed into uint8 for fp4 inputs, or in uint8 or the corresponding fp8 type for fp8 inputs. :param lhs_scale: Scale factor for lhs tensor. :type lhs_scale: e8m0 type represented as an uint8 tensor. - :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code: `e5m2`}. + :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`}. :type lhs_format: str :param rhs: The second tensor to be multiplied. :type rhs: 2D tensor representing fp8 or bf16 elements in uint8 or the corresponding fp8 type for fp8 inputs or bf16 for bf16 inputs. :param rhs_scale: Scale factor for rhs tensor. :type rhs_scale: e8m0 type represented as an uint8 tensor. - :param rhs_format: format of the rhs tensor. Available formats: {:code:`e4m3`, :code: `e5m2`, :code:`bf16`}. + :param rhs_format: format of the rhs tensor. Available formats: {:code:`e4m3`, :code:`e5m2`, :code:`bf16`}. :type rhs_format: str :param acc: The accumulator tensor. If not None, the result is added to this tensor. 
""" From 6693ddd24445378263936bd2a9414e3cd0b1fe49 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Wed, 30 Oct 2024 17:08:32 -0700 Subject: [PATCH 12/13] Ignore autotune runs failed with PTXAS error (#5017) --- python/triton/runtime/autotuner.py | 4 ++-- python/triton/runtime/errors.py | 10 ++++++++++ third_party/nvidia/backend/compiler.py | 7 ++++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index be02d61a43..d735aeb884 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -7,7 +7,7 @@ from typing import Dict from .jit import KernelInterface -from .errors import OutOfResources +from .errors import OutOfResources, PTXASError from .driver import driver @@ -157,7 +157,7 @@ def kernel_call(): try: return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) - except (OutOfResources, CompileTimeAssertionFailure): + except (OutOfResources, CompileTimeAssertionFailure, PTXASError): return [float("inf"), float("inf"), float("inf")] def run(self, *args, **kwargs): diff --git a/python/triton/runtime/errors.py b/python/triton/runtime/errors.py index 4dce917670..1a8046430e 100644 --- a/python/triton/runtime/errors.py +++ b/python/triton/runtime/errors.py @@ -24,3 +24,13 @@ def __str__(self) -> str: def __reduce__(self): # this is necessary to make CompilationError picklable return (type(self), (self.required, self.limit, self.name)) + + +class PTXASError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + error_message = self.error_message or "" + return f"PTXAS error: {error_message}" diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index f8f0486d8f..3b95b3b3f2 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -1,5 +1,6 @@ from triton.backends.compiler import BaseBackend, GPUTarget from triton._C.libtriton import ir, passes, llvm, nvidia +from triton.runtime.errors import PTXASError from dataclasses import dataclass import functools @@ -361,9 +362,9 @@ def make_cubin(src, metadata, opt, capability): else: error = f'`ptxas` failed with error code {e.returncode}' - raise RuntimeError(f'{error}\n' - f'`ptxas` stderr:\n{log}\n' - f'Repro command: {" ".join(ptxas_cmd)}\n') + raise PTXASError(f"{error}\n" + f"`ptxas` stderr:\n{log}\n" + f'Repro command: {" ".join(ptxas_cmd)}\n') with open(fbin, 'rb') as f: cubin = f.read() From 4f6f76874ff623562903d5452d499cae3d40d448 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 30 Oct 2024 21:05:03 -0500 Subject: [PATCH 13/13] [AMD] Reland sinking the 2nd tt.load after local_load's (#4935) This PR adds more restrictions about when should we apply the sched-load optimizations and un-revert https://github.com/triton-lang/triton/pull/4823. We will only apply the optimization when all of the following is satisfied: 1. pureMatmulProblem, i.e. 1 `tt.dot` in the main loop 2. two `tt.load`s in the main loop 3. 2nd `tt.load` is ahead of the `tt.dot` 4. 1st user of 2nd `tt.load` is after the `tt.dot` 5. tile size is large enough, i.e. 
nonKDim >= 128 and kDim >= 64 --- .../amd/amd-reorder-instructions.mlir | 423 ------------------ test/TritonGPU/amd/amd-sched-2nd-load.mlir | 211 +++++++++ .../ReorderInstructions.cpp | 92 +++- 3 files changed, 297 insertions(+), 429 deletions(-) create mode 100644 test/TritonGPU/amd/amd-sched-2nd-load.mlir diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 686e5a24e8..5dfd0f2a5f 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -460,429 +460,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } } -// ----- -// This test ensures that loads will not be moved across `for` loops. - -// CHECK-LABEL: tt.func public @_attn_bwd -// CHECK: tt.load -// CHECK: tt.load -// CHECK: scf.for -// CHECK: } -// CHECK: scf.for -// CHECK: } -// Moved before the independent `tt.store` ops but not before the `for` ops. -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.store -// CHECK: tt.store -// CHECK: scf.for -// CHECK: } -// CHECK: scf.for -// CHECK: } -// CHECK: tt.store - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> -#mma1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> -#shared3 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @_attn_bwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c-1_i32 = arith.constant -1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma> - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %c32_i32 = arith.constant 32 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = 
arith.constant 0 : i32 - %c16_i32 = arith.constant 16 : i32 - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_2 = arith.constant dense<0.693147182> : tensor<128x64xf32, #mma> - %0 = tt.get_program_id z : i32 - %1 = arith.muli %0, %arg14 : i32 - %2 = arith.extsi %1 : i32 to i64 - %3 = arith.remsi %0, %arg13 : i32 - %4 = arith.muli %arg11, %3 : i32 - %5 = arith.divsi %0, %arg13 : i32 - %6 = arith.muli %arg10, %5 : i32 - %7 = arith.addi %4, %6 : i32 - %8 = arith.extsi %7 : i32 to i64 - %9 = tt.get_program_id x : i32 - %10 = tt.addptr %arg0, %8 : !tt.ptr, i64 - %11 = tt.addptr %arg1, %8 : !tt.ptr, i64 - %12 = tt.addptr %arg2, %8 : !tt.ptr, i64 - %13 = tt.addptr %arg4, %8 : !tt.ptr, i64 - %14 = tt.addptr %arg5, %8 : !tt.ptr, i64 - %15 = tt.addptr %arg6, %8 : !tt.ptr, i64 - %16 = tt.addptr %arg7, %8 : !tt.ptr, i64 - %17 = tt.addptr %arg8, %2 : !tt.ptr, i64 - %18 = tt.addptr %arg9, %2 : !tt.ptr, i64 - %19 = arith.muli %9, %c128_i32 : i32 - %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %25 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %26 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %27 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %28 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %29 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %30 = arith.addi %25, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %31 = arith.addi %26, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %32 = arith.addi %27, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %33 = arith.addi %28, %23 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %34 = arith.addi %29, %24 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %35 = tt.expand_dims %30 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> - %36 = tt.expand_dims %31 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %37 = tt.expand_dims %32 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> - %38 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #mma> - %39 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #blocked> - %40 = arith.muli %35, %38 : tensor<128x1xi32, #mma> - %41 = arith.muli %36, %39 : tensor<128x1xi32, #blocked> - %42 = tt.splat %11 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %43 = tt.addptr %42, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %45 = tt.make_range {end = 64 : i32, start = 0 : i32} : 
tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %46 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %47 = tt.expand_dims %44 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> - %48 = tt.expand_dims %45 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %49 = tt.expand_dims %46 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %50 = tt.broadcast %43 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %51 = tt.broadcast %47 : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma> - %52 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> - %53 = tt.addptr %50, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %54 = tt.load %53 : tensor<128x64x!tt.ptr, #blocked> - %55 = tt.splat %12 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %56 = tt.addptr %55, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %57 = tt.broadcast %56 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %58 = tt.addptr %57, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %59 = tt.load %58 : tensor<128x64x!tt.ptr, #blocked> - %60 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %61 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %62 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %63 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %64 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %65 = arith.addi %63, %60 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %66 = arith.addi %64, %62 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %67 = tt.expand_dims %65 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi32, #blocked2> - %68 = tt.splat %arg12 : i32 -> tensor<1x16xi32, #blocked2> - %69 = arith.muli %67, %68 : tensor<1x16xi32, #blocked2> - %70 = tt.splat %10 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %71 = tt.addptr %70, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %72 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %73 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %74 = tt.expand_dims %72 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1xi32, #blocked2> - %75 = tt.expand_dims %73 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi32, #blocked3> - %76 = tt.broadcast %71 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %77 = tt.broadcast %74 : tensor<64x1xi32, #blocked2> -> tensor<64x16xi32, #blocked2> - %78 = tt.addptr %76, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %79 = tt.expand_dims %66 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> - %80 = tt.splat %arg12 : i32 -> 
tensor<16x1xi32, #blocked1> - %81 = arith.muli %79, %80 : tensor<16x1xi32, #blocked1> - %82 = tt.splat %13 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked1> - %83 = tt.addptr %82, %81 : tensor<16x1x!tt.ptr, #blocked1>, tensor<16x1xi32, #blocked1> - %84 = tt.broadcast %83 : tensor<16x1x!tt.ptr, #blocked1> -> tensor<16x64x!tt.ptr, #blocked1> - %85 = tt.broadcast %49 : tensor<1x64xi32, #blocked1> -> tensor<16x64xi32, #blocked1> - %86 = tt.addptr %84, %85 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> - %87 = tt.splat %17 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %88 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> - %89 = tt.splat %18 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %90 = arith.muli %arg12, %c16_i32 : i32 - %91 = tt.splat %90 : i32 -> tensor<64x16xi32, #blocked2> - %92 = tt.splat %90 : i32 -> tensor<16x64xi32, #blocked1> - %93:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %cst_1, %arg18 = %19, %arg19 = %78, %arg20 = %86) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1>) : i32 { - %206 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> - %207 = tt.splat %arg18 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %208 = arith.addi %207, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %209 = tt.addptr %87, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %210 = tt.load %209 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %213 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %215 = tt.dot %212, %214, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> - %217 = tt.broadcast %216 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> - %218 = arith.subf %215, %217 : tensor<128x16xf32, #mma1> - %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> - %220 = tt.expand_dims %208 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> - %221 = tt.broadcast %220 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> - %222 = arith.cmpi sge, %221, %88 : tensor<128x16xi32, #mma1> - %223 = arith.select %222, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> - %224 = tt.load %arg20 : tensor<16x64x!tt.ptr, #blocked1> - %225 = arith.truncf %223 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %226 = triton_gpu.local_alloc %225 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, 
#triton_gpu.shared_memory> - %227 = triton_gpu.local_load %226 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %228 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %229 = triton_gpu.local_load %228 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %230 = tt.dot %227, %229, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %231 = tt.addptr %89, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %232 = tt.load %231 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %233 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %234 = tt.trans %233 {order = array} : !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %235 = triton_gpu.local_load %234 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %236 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %237 = triton_gpu.local_load %236 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %238 = tt.dot %237, %235, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %239 = tt.expand_dims %232 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> - %240 = tt.broadcast %239 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> - %241 = arith.subf %238, %240 : tensor<128x16xf32, #mma1> - %242 = arith.mulf %223, %241 : tensor<128x16xf32, #mma1> - %243 = arith.truncf %242 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %244 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> - %245 = tt.trans %244 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %246 = triton_gpu.local_load %245 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %247 = triton_gpu.local_alloc %243 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %248 = triton_gpu.local_load %247 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %249 = tt.dot %248, %246, %arg17 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %250 = arith.addi %arg18, %c16_i32 : i32 - %251 = tt.addptr %arg19, %91 : tensor<64x16x!tt.ptr, #blocked2>, 
tensor<64x16xi32, #blocked2> - %252 = tt.addptr %arg20, %92 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> - scf.yield %230, %249, %250, %251, %252 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1> - } - %94 = arith.addi %19, %c128_i32 : i32 - %95 = arith.subi %arg14, %94 : i32 - %96 = arith.divsi %95, %c32_i32 : i32 - %97 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %98 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %99 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %100 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %101 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %102 = arith.addi %100, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %103 = arith.addi %101, %99 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %104 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> - %105 = tt.splat %arg12 : i32 -> tensor<1x32xi32, #blocked3> - %106 = arith.muli %104, %105 : tensor<1x32xi32, #blocked3> - %107 = tt.splat %10 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %108 = tt.addptr %107, %106 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %109 = tt.broadcast %108 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %110 = tt.broadcast %75 : tensor<64x1xi32, #blocked3> -> tensor<64x32xi32, #blocked3> - %111 = tt.addptr %109, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %112 = tt.expand_dims %103 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> - %113 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked> - %114 = arith.muli %112, %113 : tensor<32x1xi32, #blocked> - %115 = tt.splat %13 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> - %116 = tt.addptr %115, %114 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> - %117 = tt.broadcast %116 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x64x!tt.ptr, #blocked> - %118 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> - %119 = tt.addptr %117, %118 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - %120 = tt.splat %17 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %121 = tt.splat %18 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %122 = arith.muli %arg12, %c32_i32 : i32 - %123 = tt.splat %122 : i32 -> tensor<64x32xi32, #blocked3> - %124 = tt.splat %122 : i32 -> tensor<32x64xi32, #blocked> - %125:5 = scf.for %arg15 = %c0_i32 to %96 step %c1_i32 iter_args(%arg16 = %93#0, %arg17 = %93#1, %arg18 = %94, %arg19 = %111, %arg20 = %119) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked>) : i32 { - %206 = tt.load %arg19 : tensor<64x32x!tt.ptr, #blocked3> - %207 = tt.splat %arg18 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %208 = arith.addi %207, %98 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %209 = tt.addptr %120, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, 
tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %210 = tt.load %209 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %213 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %215 = tt.dot %212, %214, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> - %217 = tt.broadcast %216 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> - %218 = arith.subf %215, %217 : tensor<128x32xf32, #mma> - %219 = math.exp2 %218 : tensor<128x32xf32, #mma> - %220 = tt.load %arg20 : tensor<32x64x!tt.ptr, #blocked> - %221 = arith.truncf %219 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %222 = triton_gpu.convert_layout %221 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %223 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %224 = triton_gpu.local_load %223 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %225 = tt.dot %222, %224, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %226 = tt.addptr %121, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %227 = tt.load %226 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %228 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> - %229 = tt.trans %228 {order = array} : !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %230 = triton_gpu.local_load %229 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %231 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %232 = triton_gpu.local_load %231 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %233 = tt.dot %232, %230, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %234 = tt.expand_dims %227 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> - %235 = tt.broadcast %234 
: tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> - %236 = arith.subf %233, %235 : tensor<128x32xf32, #mma> - %237 = arith.mulf %219, %236 : tensor<128x32xf32, #mma> - %238 = arith.truncf %237 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %239 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> - %240 = tt.trans %239 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %241 = triton_gpu.local_load %240 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %242 = triton_gpu.convert_layout %238 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %243 = tt.dot %242, %241, %arg17 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %244 = arith.addi %arg18, %c32_i32 : i32 - %245 = tt.addptr %arg19, %123 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %246 = tt.addptr %arg20, %124 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - scf.yield %225, %243, %244, %245, %246 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked> - } - %126 = tt.splat %16 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %127 = tt.addptr %126, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %128 = tt.broadcast %127 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %129 = tt.addptr %128, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %130 = arith.truncf %125#0 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %129, %130 : tensor<128x64x!tt.ptr, #mma> - %131 = tt.splat %arg3 : f32 -> tensor<128x64xf32, #mma> - %132 = arith.mulf %125#1, %131 : tensor<128x64xf32, #mma> - %133 = tt.splat %15 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %134 = tt.addptr %133, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %135 = tt.broadcast %134 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %136 = tt.addptr %135, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %137 = arith.truncf %132 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %136, %137 : tensor<128x64x!tt.ptr, #mma> - %138 = tt.splat %10 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %139 = tt.addptr %138, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %140 = tt.broadcast %139 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %141 = tt.addptr %140, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %142 = tt.load %141 : tensor<128x64x!tt.ptr, #blocked> - %143 = tt.splat %13 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %144 = tt.addptr %143, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %145 = tt.broadcast %144 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %146 = tt.addptr %145, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %147 = tt.load %146 : tensor<128x64x!tt.ptr, #blocked> - %148 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %149 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> 
- %150 = tt.addptr %148, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %151 = tt.addptr %149, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %152 = tt.load %150 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %153 = tt.load %151 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %154 = tt.expand_dims %152 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> - %155 = tt.expand_dims %153 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> - %156 = tt.splat %11 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %157 = tt.addptr %156, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %158 = tt.broadcast %157 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %159 = tt.addptr %158, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %160 = tt.splat %12 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %161 = tt.addptr %160, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %162 = tt.broadcast %161 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %163 = tt.addptr %162, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %164 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %165 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %166 = tt.addptr %164, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %167 = tt.addptr %165, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %168 = tt.load %166 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %169 = tt.load %167 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %170 = tt.broadcast %154 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> - %171 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> - %172 = tt.expand_dims %168 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> - %173 = tt.broadcast %172 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> - %174 = arith.muli %arg12, %c16_i32 : i32 - %175 = tt.splat %174 : i32 -> tensor<64x16xi32, #blocked2> - %176 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %177:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %19, %arg18 = %159, %arg19 = %163, %arg20 = %c-1_i32) -> (tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32) : i32 { - %206 = arith.addi %arg20, %c1_i32 : i32 - %207 = arith.cmpi slt, %206, %c1_i32 : i32 - %208 = arith.select %207, %206, %c0_i32 : i32 - %209 = tt.load %arg18 : tensor<64x16x!tt.ptr, #blocked2> - %210 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> - %211 = triton_gpu.memdesc_subview %176[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %210, %211 : tensor<64x16xf16, 
#blocked2> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %212 = triton_gpu.local_load %211 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %215 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %216 = triton_gpu.local_load %215 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %217 = tt.dot %214, %216, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %218 = arith.subf %217, %170 : tensor<128x16xf32, #mma1> - %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> - %220 = tt.splat %arg17 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %221 = arith.addi %220, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %222 = tt.expand_dims %221 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> - %223 = tt.broadcast %222 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> - %224 = arith.cmpi sge, %171, %223 : tensor<128x16xi32, #mma1> - %225 = arith.select %224, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> - %226 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %227 = triton_gpu.local_load %226 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %228 = tt.dot %227, %212, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %229 = arith.subf %228, %173 : tensor<128x16xf32, #mma1> - %230 = arith.mulf %225, %229 : tensor<128x16xf32, #mma1> - %231 = arith.truncf %230 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %232 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> - %233 = tt.trans %232 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %234 = triton_gpu.local_load %233 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %235 = triton_gpu.local_alloc %231 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %236 = triton_gpu.local_load %235 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %237 = tt.dot %236, %234, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, 
#mma> - %238 = arith.addi %arg17, %c16_i32 : i32 - %239 = tt.addptr %arg18, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %240 = tt.addptr %arg19, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - scf.yield %237, %238, %239, %240, %208 : tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32 - } - triton_gpu.local_dealloc %176 : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %178 = arith.divsi %19, %c32_i32 : i32 - %179 = arith.muli %178, %c32_i32 : i32 - %180 = arith.subi %19, %179 : i32 - %181 = tt.splat %180 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %182 = arith.addi %181, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %183 = tt.expand_dims %182 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> - %184 = arith.muli %183, %105 : tensor<1x32xi32, #blocked3> - %185 = tt.splat %11 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %186 = tt.addptr %185, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %187 = tt.broadcast %186 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %188 = tt.addptr %187, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %189 = tt.splat %12 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %190 = tt.addptr %189, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %191 = tt.broadcast %190 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %192 = tt.addptr %191, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %193 = tt.broadcast %155 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> - %194 = tt.expand_dims %169 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> - %195 = tt.broadcast %194 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> - %196 = arith.muli %arg12, %c32_i32 : i32 - %197 = tt.splat %196 : i32 -> tensor<64x32xi32, #blocked3> - %198 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %199:4 = scf.for %arg15 = %c0_i32 to %178 step %c1_i32 iter_args(%arg16 = %177#0, %arg17 = %188, %arg18 = %192, %arg19 = %c-1_i32) -> (tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32) : i32 { - %206 = arith.addi %arg19, %c1_i32 : i32 - %207 = arith.cmpi slt, %206, %c1_i32 : i32 - %208 = arith.select %207, %206, %c0_i32 : i32 - %209 = tt.load %arg17 : tensor<64x32x!tt.ptr, #blocked3> - %210 = tt.load %arg18 : tensor<64x32x!tt.ptr, #blocked3> - %211 = triton_gpu.memdesc_subview %198[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %210, %211 : tensor<64x32xf16, #blocked3> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %212 = triton_gpu.local_load %211 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, 
parent = #mma, kWidth = 4}>> - %215 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %216 = triton_gpu.local_load %215 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %217 = tt.dot %214, %216, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %218 = arith.subf %217, %193 : tensor<128x32xf32, #mma> - %219 = math.exp2 %218 : tensor<128x32xf32, #mma> - %220 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %221 = triton_gpu.local_load %220 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %222 = tt.dot %221, %212, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %223 = arith.subf %222, %195 : tensor<128x32xf32, #mma> - %224 = arith.mulf %219, %223 : tensor<128x32xf32, #mma> - %225 = arith.truncf %224 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %226 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> - %227 = tt.trans %226 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %228 = triton_gpu.local_load %227 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %229 = triton_gpu.convert_layout %225 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %230 = tt.dot %229, %228, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %231 = tt.addptr %arg17, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %232 = tt.addptr %arg18, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - scf.yield %230, %231, %232, %208 : tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32 - } - triton_gpu.local_dealloc %198 : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %200 = tt.splat %14 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %201 = tt.addptr %200, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %202 = tt.broadcast %201 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %203 = tt.addptr %202, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %204 = arith.mulf %199#0, %cst_2 : tensor<128x64xf32, #mma> - %205 = arith.truncf %204 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %203, %205 : tensor<128x64x!tt.ptr, #mma> - tt.return - } -} - // ----- // CHECK-LABEL: sink_convert_dealloc diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir new file mode 100644 index 0000000000..5c173ffb48 --- /dev/null +++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -0,0 +1,211 @@ +// RUN: triton-opt %s -split-input-file 
-tritonamdgpu-reorder-instructions | FileCheck %s + +// Check the logic of sched-2nd-load optimizations +// + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> + +// Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough +// The following tile sizes should apply the optimization +// 256x256x128 +// 256x256x64 +// The following tile sizes should NOT apply the optimization +// 256x64x128 +// 256x256x32 +// + +// Should apply: tile size 256x256x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// Should apply: tile size 256x256x64 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x64 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + 
tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// Should NOT apply: tile size 256x64x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x64x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x64xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> + tt.return + } +} + +// Should NOT apply: tile size 256x256x32 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x32 +// CHECK: %[[tileA:.*]] = tt.load +// 
CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + + +// Category 2: single dot with two loads and tile size is large enough (128x128x128). +// We make sure the move is legal. 
+// Should NOT apply: the 2nd load has a user before the dot +// CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.store +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr, #blocked>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> + %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> + tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> + %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> + triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<128x128xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Category 3: two dots in the for loop. 
Make sure the optimization is not applied +// should NOT apply: two dots +// CHECK-LABEL: sink_2nd_load_256x256x64_two_dot +// CHECK: triton_gpu.local_load +// CHECK-NEXT: triton_gpu.local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.load +// CHECK-NEXT: tt.load +// CHECK-NEXT: triton_gpu.local_store +// CHECK-NEXT: triton_gpu.local_store +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 22349c50e3..9371c8b5f8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -20,13 +20,15 @@ namespace ttg = mlir::triton::gpu; // Return true if the given moduleOp contains a pure matmul problem; i.e., // single dot in the main loop. 
 static bool isPureMatmulProblem(ModuleOp moduleOp) {
-  for (auto forOp : moduleOp.getOps<scf::ForOp>()) {
+  bool isMatmul = true;
+  bool foundLoop = false;
+  moduleOp.walk([&](scf::ForOp forOp) -> void {
     int counter = 0;
     forOp.walk([&counter](triton::DotOp dotOp) { ++counter; });
-    if (counter != 1)
-      return false;
-  }
-  return true;
+    isMatmul = (isMatmul && (counter == 1));
+    foundLoop = true;
+  });
+  return foundLoop && isMatmul;
 }
 
 // Search through block to find earliest insertion point for move op. This can
@@ -267,6 +269,82 @@ static void scheduleGlobalLoadLocalStore(ModuleOp m) {
   }
 }
 
+/**
+ * Sched-load optimization for matmul kernels with large tile sizes
+ * The basic idea of sched-load optimization is to sink the 2nd tt.load
+ * after local_load so that global_load instructions can be interleaved with
+ * mfma's. This can help hide the issue latency of global_load instructions
+ * and improve performance on MI300X.
+ *
+ * It's assumed that the IR before this optimization has the following
+ * structure:
+ * ```mlir
+ * scf.for .. {
+ *   tileA = tt.load a_ptr
+ *   tileB = tt.load b_ptr
+ *   opA = local_load bufferA
+ *   opB = local_load bufferB
+ *   res = tt.dot opA, opB
+ *   local_store tileA, bufferA
+ *   local_store tileB, bufferB
+ * }
+ * ```
+ * After this optimization, the IR is transformed to
+ * ```mlir
+ * scf.for .. {
+ *   tileA = tt.load a_ptr
+ *   opA = local_load bufferA
+ *   opB = local_load bufferB
+ *   tileB = tt.load b_ptr   <-- 2nd tt.load is sunk here
+ *   res = tt.dot opA, opB
+ *   local_store tileA, bufferA
+ *   local_store tileB, bufferB
+ * }
+ * ```
+ * For now, we don't have a perfect heuristic for when this optimization
+ * should be applied. Therefore, we implement a simple heuristic: it is
+ * applied when the tile sizes of A and B are large enough, i.e.
+ * nonKDim >= 128 and kDim >= 64. It is also only applied to typical
+ * matmul kernels, i.e. with only two tt.load ops and one dotOp inside the
+ * loop. We are experimenting with how to better control instruction
+ * scheduling and enable such optimizations.
+ */
+static void sinkSecondLoad(ModuleOp m) {
+  m.walk([&](scf::ForOp forOp) -> void {
+    SetVector<triton::LoadOp> loadOps;
+    triton::DotOp dotOp;
+    for (Operation &op : forOp) {
+      if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
+        loadOps.insert(loadOp);
+      if (auto curOp = dyn_cast<triton::DotOp>(&op))
+        dotOp = curOp;
+    }
+    // Only apply the optimization when there are exactly two loads in the loop
+    if (loadOps.size() != 2)
+      return;
+    // Only apply the optimization when the tile size is large enough:
+    // 1. nonKDim >= 128
+    // 2. kDim >= 64
+    auto ldAOp = loadOps[0];
+    auto tileAShape = cast<RankedTensorType>(ldAOp.getType()).getShape();
+    auto ldBOp = loadOps[1];
+    auto tileBShape = cast<RankedTensorType>(ldBOp.getType()).getShape();
+    if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128))
+      return;
+    // Only apply the optimization when the move is legal:
+    // 1. Make sure the 2nd loadOp is before the dot.
+    // 2. Make sure the first user of the 2nd loadOp is after the dot.
+    bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp);
+    auto firstUser = *ldBOp.getResult().getUsers().begin();
+    bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser);
+    if (isBeforeDotOp && firstUserAfterDotOp)
+      // move ldBOp right before tt.dot
+      ldBOp->moveBefore(dotOp);
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // Pass definition
 //===----------------------------------------------------------------------===//
@@ -288,8 +366,10 @@ struct TritonAMDGPUReorderInstructionsPass
     moveUpTranspose(m);
 
-    if (isPureMatmulProblem(m))
+    if (isPureMatmulProblem(m)) {
       scheduleGlobalLoadLocalStore(m);
+      sinkSecondLoad(m);
+    }
   }
 };
 } // namespace
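
To make the tile-size guard above concrete: the check `tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128` reads as M >= 128, K >= 64, and N >= 128 on the shapes of the two global loads (tile A is MxK, tile B is KxN). Below is a minimal standalone sketch, not part of the patch, that replays this predicate against the four Category 1 tile sizes exercised by the new lit test; the `TileShapes` struct and `largeEnoughToSink` helper are illustrative names only.

```cpp
#include <cassert>
#include <cstdint>

// Shapes of the two global loads feeding the dot: tile A is M x K, tile B is K x N.
struct TileShapes {
  int64_t aRows, aCols; // tile of A: M, K
  int64_t bRows, bCols; // tile of B: K, N
};

// Mirrors the guard in sinkSecondLoad: require M >= 128, K >= 64, N >= 128.
static bool largeEnoughToSink(const TileShapes &t) {
  return t.aRows >= 128 && t.aCols >= 64 && t.bCols >= 128;
}

int main() {
  assert(largeEnoughToSink({256, 128, 128, 256}));  // 256x256x128: should apply
  assert(largeEnoughToSink({256, 64, 64, 256}));    // 256x256x64:  should apply
  assert(!largeEnoughToSink({256, 128, 128, 64}));  // 256x64x128:  N = 64 too small
  assert(!largeEnoughToSink({256, 32, 32, 256}));   // 256x256x32:  K = 32 too small
  return 0;
}
```

This matches the "Should apply" / "Should NOT apply" cases spelled out in amd-sched-2nd-load.mlir, with the legality conditions (second load before the dot, its first user after the dot) handled separately by the pass itself.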