From 6d3ed0b91116e1e238a56f8b1d0d7cdaa2141911 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Wed, 27 Nov 2024 16:49:06 -0500
Subject: [PATCH] [DIALECT] Rename `triton_gpu` to `ttg` and `triton_nvidia_gpu` to `ttng` (#5266)

This may require changes in downstream projects, but we think it is beneficial to shorten the dialect names and make them consistent. That is, we already use `tt` for the `triton` dialect.
---
 bin/triton-tensor-layout.cpp | 8 +-
 include/triton/Analysis/Allocation.h | 4 +-
 .../Conversion/TritonGPUToLLVM/Utility.h | 2 +-
 .../TritonToTritonGPU/TritonToTritonGPUPass.h | 8 +-
 .../Dialect/TritonGPU/IR/CMakeLists.txt | 8 +-
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 6 +-
 .../Dialect/TritonGPU/IR/TritonGPUDialect.td | 18 +-
 .../Dialect/TritonGPU/IR/TritonGPUOps.td | 10 +-
 .../Dialect/TritonNvidiaGPU/IR/CMakeLists.txt | 8 +-
 .../IR/TritonNvidiaGPUDialect.td | 16 +-
 .../TritonGPUToLLVM/AllocateSharedMemory.cpp | 2 +-
 .../TritonGPUToLLVM/ControlFlowOpToLLVM.cpp | 2 +-
 .../GlobalScratchMemoryAllocation.cpp | 22 +-
 .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 2 +-
 lib/Dialect/TritonGPU/IR/Dialect.cpp | 31 +-
 lib/Dialect/TritonGPU/Transforms/Prefetch.cpp | 4 +-
 python/test/unit/language/test_core.py | 68 +-
 python/test/unit/language/test_pipeliner.py | 22 +-
 python/test/unit/tools/test_aot.py | 2 +-
 python/test/unit/tools/test_irsource.py | 16 +-
 python/triton/compiler/compiler.py | 4 +-
 test/Analysis/test-alias.mlir | 108 +-
 test/Analysis/test-allocation.mlir | 350 ++--
 test/Analysis/test-membar.mlir | 624 ++++----
 test/Conversion/amd/buffer_load_store.mlir | 32 +-
 test/Conversion/amd/builtin_func_to_llvm.mlir | 4 +-
 test/Conversion/amd/compute-base-ptr.mlir | 14 +-
 ...ecompose-unsupported-conversions-cdna.mlir | 36 +-
 .../decompose-unsupported-conversions.mlir | 100 +-
 test/Conversion/amd/dedup-by-constancy.mlir | 12 +-
 test/Conversion/amd/fp_to_fp.mlir | 16 +-
 .../amd/invalid_extractslice_to_llvm.mlir | 24 +-
 test/Conversion/amd/load_store.mlir | 18 +-
 test/Conversion/amd/math-denorm-handling.mlir | 8 +-
 test/Conversion/amd/mfma-shortcut.mlir | 16 +-
 test/Conversion/amd/tritongpu_to_llvm.mlir | 52 +-
 .../amd/tritongpu_wmma_dot_to_llvm.mlir | 56 +-
 test/Conversion/dedup-by-constancy.mlir | 8 +-
 test/Conversion/divide-by-0.mlir | 8 +-
 test/Conversion/triton_to_tritongpu.mlir | 28 +-
 test/Conversion/tritongpu_to_llvm.mlir | 756 ++++----
 .../tritongpu_to_llvm_block_dot_shortcut.mlir | 26 +-
 test/Conversion/tritongpu_to_llvm_hopper.mlir | 172 +-
 .../tritongpu_to_llvm_hopper_ptx80.mlir | 12 +-
 test/Conversion/tritongpu_to_llvm_volta.mlir | 4 +-
 test/Conversion/tritonnvidiagpu_to_llvm.mlir | 56 +-
 test/Tools/tensor_layout_print.mlir | 10 +-
 test/Triton/canonicalize.mlir | 28 +-
 test/Triton/invalid.mlir | 74 +-
 test/Triton/reproducer.mlir | 2 +-
 test/Triton/vecadd.mlir | 134 +-
 test/TritonGPU/accelerate-matmul.mlir | 184 +--
 test/TritonGPU/accumulator-init.mlir | 104 +-
 .../amd/accelerate-amd-matmul-mfma.mlir | 18 +-
 .../amd/accelerate-amd-matmul-wmma-gen1.mlir | 144 +-
 .../amd/accelerate-amd-matmul-wmma-gen2.mlir | 124 +-
 .../amd/amd-canonicalize-pointers.mlir | 112 +-
 .../TritonGPU/amd/amd-convert-buffer-ops.mlir | 20 +-
 test/TritonGPU/amd/amd-extractslice-op.mlir | 6 +-
 test/TritonGPU/amd/amd-instruction-sched.mlir | 8 +-
 test/TritonGPU/amd/amd-optimize-epilogue.mlir | 32 +-
 .../amd/amd-reorder-instructions.mlir | 334 ++--
 test/TritonGPU/amd/amd-sched-2nd-load.mlir | 180 +--
 test/TritonGPU/amd/optimize-lds-usage.mlir | 158 +-
 test/TritonGPU/canonicalize.mlir | 120 +-
 test/TritonGPU/coalesce.mlir | 52 +-
 test/TritonGPU/combine.mlir | 1426 ++++++++---
 test/TritonGPU/dot-operands.mlir | 260 +--
 test/TritonGPU/fence-inserstion.mlir | 44 +-
 test/TritonGPU/global_scratch_alloc.mlir | 32 +-
 test/TritonGPU/invalid-attributes.mlir | 72 +-
 test/TritonGPU/invalid.mlir | 52 +-
 test/TritonGPU/loop-pipeline-cuda.mlir | 136 +-
 test/TritonGPU/loop-pipeline-hip.mlir | 150 +-
 test/TritonGPU/loop-pipeline-hopper.mlir | 686 ++++----
 .../loop-pipeline-indirect-load.mlir | 56 +-
 test/TritonGPU/loop-pipeline.mlir | 940 +++++------
 test/TritonGPU/loop-schedule.mlir | 22 +-
 test/TritonGPU/matmul-loop-pipeline.mlir | 6 +-
 test/TritonGPU/ops.mlir | 38 +-
 test/TritonGPU/optimize-locality.mlir | 342 ++--
 test/TritonGPU/optimize_epilogue.mlir | 20 +-
 .../pipeline-hopper-remove-wait.mlir | 130 +-
 test/TritonGPU/prefetch.mlir | 164 +-
 test/TritonGPU/reduce-data-duplication.mlir | 36 +-
 test/TritonGPU/reorder-instructions.mlir | 106 +-
 test/TritonGPU/tritongpu_ops.mlir | 8 +-
 test/TritonGPU/verify-blocked-layout.mlir | 40 +-
 test/TritonNvidiaGPU/membar.mlir | 72 +-
 test/TritonNvidiaGPU/tma_lowering.mlir | 36 +-
 third_party/amd/backend/compiler.py | 2 +-
 .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 6 +-
 .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 4 +-
 .../amd/python/test/test_extract_slice.py | 24 +-
 third_party/nvidia/backend/compiler.py | 8 +-
 unittest/Dialect/TritonGPU/DialectTest.cpp | 2 +-
 96 files changed, 4781 insertions(+), 4786 deletions(-)

diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp
index 4087ac1350..7c635dafaa 100644
--- a/bin/triton-tensor-layout.cpp
+++ b/bin/triton-tensor-layout.cpp
@@ -22,7 +22,7 @@ using namespace mlir;
 // clang-format off
 // Example usage:
 //
-// triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
+// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
 //
 // triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt
 //
@@ -30,8 +30,8 @@ using namespace mlir;
 //
 // An input file usually looks like:
 // '''
-// #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
-// #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
+// #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
+// #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
 // '''
 // clang-format on

@@ -83,7 +83,7 @@ LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
 StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace();
 // Dispatch to the corresponding dialect helper function to print the layout.
- if (dialectName == "triton_gpu") { + if (dialectName == "ttg") { os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); return success(); } diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 9d0d6684da..91bc895b20 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -180,8 +180,8 @@ class Allocation { private: /// A class that represents a shared memory buffer struct BufferT { - /// Explicit: triton_gpu.local_alloc - /// Scratch: triton_gpu.convert_layout + /// Explicit: ttg.local_alloc + /// Scratch: ttg.convert_layout /// Virtual: triton.call enum class BufferKind { Explicit, Scratch, Virtual }; diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index ba24461a1f..d9c3acbf71 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -400,7 +400,7 @@ inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, ModuleOp mod = funcOp.getOperation()->getParentOfType(); auto allocSizeAttr = mod.getOperation()->getAttrOfType( - "triton_gpu.global_scratch_memory_size"); + "ttg.global_scratch_memory_size"); if (!allocSizeAttr) { return gmemBase; } diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h index 78917fdfdd..ad8e640413 100644 --- a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -12,11 +12,11 @@ template class OperationPass; namespace triton { -constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps"; -constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas"; -constexpr static char AttrTargetName[] = "triton_gpu.target"; +constexpr static char AttrNumWarpsName[] = "ttg.num-warps"; +constexpr static char AttrNumCTAsName[] = "ttg.num-ctas"; +constexpr static char AttrTargetName[] = "ttg.target"; -constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp"; +constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp"; // Create the pass with numWarps passed from cl::opt. 
std::unique_ptr> createConvertTritonToTritonGPUPass(); diff --git a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt index 189f6d4307..a211c7bc87 100644 --- a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,12 +1,12 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu) -mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttg) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttg) add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonGPUTableGen) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index fee5e0afe3..b900c3d2e3 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -616,7 +616,7 @@ Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warp for -#triton_gpu.blocked_layout<{ +#ttg.blocked_layout<{ sizePerThread = {2, 2} threadsPerWarp = {8, 4} warpsPerCTA = {1, 2} @@ -642,7 +642,7 @@ Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warp [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] for -#triton_gpu.blocked_layout<{ +#ttg.blocked_layout<{ sizePerThread = {2, 2} threadsPerWarp = {8, 4} warpsPerCTA = {1, 2} @@ -672,7 +672,7 @@ CTA [1,0] CTA [1,1] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] for -#triton_gpu.blocked_layout<{ +#ttg.blocked_layout<{ sizePerThread = {2, 2} threadsPerWarp = {8, 4} warpsPerCTA = {1, 2} diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index 10f2c8c688..1ab22c1a02 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -4,7 +4,7 @@ include "mlir/IR/OpBase.td" def TritonGPU_Dialect : Dialect { - let name = "triton_gpu"; + let name = "ttg"; let cppNamespace = "::mlir::triton::gpu"; @@ -21,24 +21,24 @@ def TritonGPU_Dialect : Dialect { ]; let extraClassDeclaration = [{ - static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static std::string getNumWarpsAttrName() { return "ttg.num-warps"; } static int getNumWarps(ModuleOp mod) { - if (!mod->hasAttr("triton_gpu.num-warps")) + if (!mod->hasAttr("ttg.num-warps")) llvm::report_fatal_error( - "TritonGPU module should contain a triton_gpu.num-warps attribute"); - return cast(mod->getAttr("triton_gpu.num-warps")).getInt(); + "TritonGPU module should contain a ttg.num-warps attribute"); + return cast(mod->getAttr("ttg.num-warps")).getInt(); } static int getNumCTAs(ModuleOp mod) { - if (!mod->hasAttr("triton_gpu.num-ctas")) + if 
(!mod->hasAttr("ttg.num-ctas")) return 1; - return cast(mod->getAttr("triton_gpu.num-ctas")).getInt(); + return cast(mod->getAttr("ttg.num-ctas")).getInt(); } void registerTypes(); - static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; } + static std::string getThreadsPerWarpAttrName() { return "ttg.threads-per-warp"; } static int getThreadsPerWarp(ModuleOp mod) { - Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp"); + Attribute threadsPerWarp = mod->getDiscardableAttr("ttg.threads-per-warp"); if(!threadsPerWarp) { return 32; } diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index b747fddde6..9aa3e0b626 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -188,13 +188,13 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree" is printed as "". + // Use qualified() otherwise "!ttg.memdesc" is printed as "". let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}]; } @@ -215,7 +215,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> { let arguments = ( ins TTG_MemDescType:$src, Variadic:$offsets); - // Use qualified() otherwise "!triton_gpu.memdesc" is printed as "". + // Use qualified() otherwise "!ttg.memdesc" is printed as "". let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}]; let results = (outs TTG_MemDescType:$result); @@ -262,7 +262,7 @@ def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods(nullptr)); }]>]; - // Use qualified() otherwise "!triton_gpu.memdesc" is printed as "". + // Use qualified() otherwise "!ttg.memdesc" is printed as "". let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}]; let results = (outs TT_Tensor:$result); @@ -277,7 +277,7 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods" is printed as "". + // Use qualified() otherwise "!ttg.memdesc" is printed as "". 
let assemblyFormat = [{ $src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst)) }]; diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt index b7ce83fe7e..45c70e15c2 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -1,12 +1,12 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttng) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttng) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu) -mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttng) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttng) add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc) add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonNvidiaGPUTableGen) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td index 67ece715d2..951409d63e 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td @@ -25,7 +25,7 @@ include "mlir/IR/OpBase.td" def TritonNvidiaGPU_Dialect : Dialect { - let name = "triton_nvidia_gpu"; + let name = "ttng"; let cppNamespace = "::mlir::triton::nvidia_gpu"; @@ -43,18 +43,18 @@ def TritonNvidiaGPU_Dialect : Dialect { ]; let extraClassDeclaration = [{ - static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static std::string getNumWarpsAttrName() { return "ttg.num-warps"; } static int getNumWarps(ModuleOp mod) { - if(!mod->hasAttr("triton_gpu.num-warps")) + if(!mod->hasAttr("ttg.num-warps")) llvm::report_fatal_error( - "TritonGPU module should contain a triton_gpu.num-warps attribute"); - return cast(mod->getAttr("triton_gpu.num-warps")).getInt(); + "TritonGPU module should contain a ttg.num-warps attribute"); + return cast(mod->getAttr("ttg.num-warps")).getInt(); } static int getNumCTAs(ModuleOp mod) { - if(!mod->hasAttr("triton_gpu.num-ctas")) + if(!mod->hasAttr("ttg.num-ctas")) llvm::report_fatal_error( - "TritonGPU module should contain a triton_gpu.num-ctas attribute"); - return cast(mod->getAttr("triton_gpu.num-ctas")).getInt(); + "TritonGPU module should contain a ttg.num-ctas attribute"); + return cast(mod->getAttr("ttg.num-ctas")).getInt(); } void registerTypes(); }]; diff --git a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp index aae9faf0ee..0115383947 100644 --- a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -44,7 +44,7 @@ struct AllocateSharedMemory IntegerAttr::get(IntegerType::get(ctx, 32), offset)); }); }); - mod->setAttr("triton_gpu.shared", + mod->setAttr("ttg.shared", mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), allocation.getSharedMemorySize())); } diff --git 
a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp index b1ec521d5e..06e19029eb 100644 --- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -91,7 +91,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { } auto opOffsetAttr = caller->getAttrOfType( - "triton_gpu.global_scratch_memory_offset"); + "ttg.global_scratch_memory_offset"); Value opOffsetVal; if (opOffsetAttr) { auto opOffset = opOffsetAttr.getValue().getZExtValue(); diff --git a/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp b/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp index 318cb7524c..3fcaf4197c 100644 --- a/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp @@ -21,7 +21,7 @@ static void allocateGMem(Operation *parentOp, // Recursively visit any dependency functions parentOp->walk([&](triton::CallOp call) { auto callable = call.resolveCallable(); - if (!callable->hasAttr("triton_gpu.global_scratch_memory_size")) { + if (!callable->hasAttr("ttg.global_scratch_memory_size")) { auto inserted = callStack.insert(parentOp); assert(inserted && "call cycle detected"); allocateGMem(callable, callStack); @@ -46,9 +46,9 @@ static void allocateGMem(Operation *parentOp, } else if (auto callOp = dyn_cast(op)) { auto callable = callOp.resolveCallable(); auto nbytes_attr = callable->getAttrOfType( - "triton_gpu.global_scratch_memory_size"); + "ttg.global_scratch_memory_size"); auto align_attr = callable->getAttrOfType( - "triton_gpu.global_scratch_memory_alignment"); + "ttg.global_scratch_memory_alignment"); assert(nbytes_attr); assert(align_attr); @@ -57,16 +57,16 @@ static void allocateGMem(Operation *parentOp, } if (nbytes > 0) { offset = roundUp(offset, align); - op->setAttr("triton_gpu.global_scratch_memory_offset", + op->setAttr("ttg.global_scratch_memory_offset", builder.getI32IntegerAttr(offset)); offset += nbytes; largestAlignment = std::max(largestAlignment, align); } }); int32_t totalMemorySize = roundUp(offset, largestAlignment); - parentOp->setAttr("triton_gpu.global_scratch_memory_size", + parentOp->setAttr("ttg.global_scratch_memory_size", builder.getI32IntegerAttr(totalMemorySize)); - parentOp->setAttr("triton_gpu.global_scratch_memory_alignment", + parentOp->setAttr("ttg.global_scratch_memory_alignment", builder.getI32IntegerAttr(largestAlignment)); } @@ -86,14 +86,14 @@ class TritonGPUGlobalScratchAllocationPass if (func.getVisibility() == SymbolTable::Visibility::Public) { assert(!seenKernel); seenKernel = true; - auto size = func->getAttrOfType( - "triton_gpu.global_scratch_memory_size"); + auto size = + func->getAttrOfType("ttg.global_scratch_memory_size"); auto align = func->getAttrOfType( - "triton_gpu.global_scratch_memory_alignment"); + "ttg.global_scratch_memory_alignment"); assert(size); assert(align); - mod->setAttr("triton_gpu.global_scratch_memory_size", size); - mod->setAttr("triton_gpu.global_scratch_memory_alignment", align); + mod->setAttr("ttg.global_scratch_memory_size", size); + mod->setAttr("ttg.global_scratch_memory_alignment", align); } }); assert(seenKernel); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 27fd26800b..3488a68613 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -47,7 +47,7 @@ 
struct GlobalScratchAllocOpConversion Location loc = op.getLoc(); auto opOffsetAttr = op->getAttrOfType( - "triton_gpu.global_scratch_memory_offset"); + "ttg.global_scratch_memory_offset"); assert(opOffsetAttr); auto opOffset = opOffsetAttr.getValue().getZExtValue(); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 1017a36c28..0799bd6df1 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1173,24 +1173,22 @@ LogicalResult DotOperandEncodingAttr::verify( ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned opIdx, Attribute parent, unsigned kWidth) { if (opIdx != 0 && opIdx != 1) { - return emitError() - << "triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: " - << opIdx; + return emitError() << "ttg.dot_op opIdx paramenter can be 0 or 1, got: " + << opIdx; } if (!parent) { - return emitError() << "triton_gpu.dot_op parent paramenter cannot be null"; + return emitError() << "ttg.dot_op parent paramenter cannot be null"; } if (auto parentAttr = mlir::dyn_cast(parent)) { if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) - return emitError() << "triton_gpu.dot_op kWidth parameter can only be " + return emitError() << "ttg.dot_op kWidth parameter can only be " "non-zero for Ampere or Hopper MMA parent"; if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) - return emitError() - << "triton_gpu.dot_op kWidth parameter is mandatory for " - "Ampere or Hopper MMA parent"; + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "Ampere or Hopper MMA parent"; if (opIdx != 0 && parentAttr.isHopper()) return emitError() - << "triton_gpu.dot_op opIdx parameter must be 0 for " + << "ttg.dot_op opIdx parameter must be 0 for " "Hopper MMA parent, since Hopper WGMMA only allows first " "operand to be in registers"; return success(); @@ -1199,29 +1197,26 @@ LogicalResult DotOperandEncodingAttr::verify( if (auto parentAttr = mlir::dyn_cast(parent)) { if (kWidth != 16 && parentAttr.getVersion() == 1 || kWidth != 8 && parentAttr.getVersion() == 2) - return emitError() << "triton_gpu.dot_op kWidth parameter must be 16 for " + return emitError() << "ttg.dot_op kWidth parameter must be 16 for " "gfx11 and 8 for gfx12"; return success(); } if (auto parentAttr = mlir::dyn_cast(parent)) { if (kWidth == 0) - return emitError() - << "triton_gpu.dot_op kWidth parameter is mandatory for " - "MFMA parent"; + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "MFMA parent"; return success(); } if (auto parentAttr = mlir::dyn_cast(parent)) { if (kWidth != 0) - return emitError() - << "triton_gpu.dot_op kWidth parameter is not supported " - "when the parent is a blocked layout"; + return emitError() << "ttg.dot_op kWidth parameter is not supported " + "when the parent is a blocked layout"; return success(); } - return emitError() << "triton_gpu.dot_op unexpected parent layout: " - << parent; + return emitError() << "ttg.dot_op unexpected parent layout: " << parent; } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 46a55d550d..9ad71270c2 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -15,12 +15,12 @@ // // %a: tensor<128x32xf16, #enc> // %a_tmp = tensor.subview %a[0, 0] [128, 16] -// %a_prefetch = triton_gpu.local_load %a_tmp +// 
%a_prefetch = ttg.local_load %a_tmp // scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch) // { // %x = tt.dot %a_prefetch_arg, %b, %c // %a_tmp_rem = tensor.subview %a_buf[0, 16] [128, 16] -// %a_prefetch_next = triton_gpu.local_load %a_tmp_rem +// %a_prefetch_next = ttg.local_load %a_tmp_rem // ... // scf.yield %next_a, ..., %a_prefetch_next // } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 7606896b37..9329ff311c 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -58,7 +58,7 @@ def promotion_numpy_2_0(): # num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] num_ctas_list = [1] -GPU_DIALECT = "triton_gpu" +GPU_DIALECT = "ttg" if is_interpreter(): THREADS_PER_WARP = 1 elif is_hip(): @@ -1201,7 +1201,7 @@ def make_ptr_str(name, shape): return f"{name} + {' + '.join(offsets)}" -# TODO: handle `%4 = triton_gpu.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` +# TODO: handle `%4 = ttg.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>`` @pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] for d in ['int32', 'uint32', 'uint16']]) @@ -2584,16 +2584,16 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa ir = f""" #blocked = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> - %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> %3 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> - %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> - %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> @@ -2737,7 +2737,7 @@ def 
test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) - %16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> + %16 = ttg.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> %17 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<1x!tt.ptr<{ty}>, #one_d_layout> tt.store %17, %16 : tensor<1x!tt.ptr<{ty}>, #one_d_layout> tt.return @@ -2750,7 +2750,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov #blocked = {blocked} #src = {src_layout} #one_d_layout = {one_d_layout} - module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> @@ -2859,7 +2859,7 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path ir = f""" #dst = {dst_layout} #src = {src_layout} - module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> @@ -2928,7 +2928,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli tt.reduce.return %14 : i32""" ir = f""" #src = {src_layout} - module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> @@ -3774,9 +3774,9 @@ def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.co # value is in rowmajor. But MMAv3 requires its second operand is in colmajor # because transpose is not supported for MMAv3 with float32 input. 
if capability[0] >= 9: - assert re.search(r"triton_gpu.async_wait %.* {num = 1 : i32}", h.asm["ttgir"]) is not None + assert re.search(r"ttg.async_wait %.* {num = 1 : i32}", h.asm["ttgir"]) is not None else: - assert re.search(r"triton_gpu.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None + assert re.search(r"ttg.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None @pytest.mark.interpreter @@ -4524,7 +4524,7 @@ def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.con k = kernel[(1, )](input, actual, shape[0], shape[1]) assert k.asm['ttgir'].count( - 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + 'ttg.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) @@ -5335,28 +5335,28 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t """ conversion = f""" - %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> - %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> """ if interm_layout is None else f""" - %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !triton_gpu.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> - %16 = triton_gpu.local_load %15 : !triton_gpu.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xi32, #src> - %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !triton_gpu.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> - %18 = triton_gpu.local_load %17 : !triton_gpu.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xf16, #src> + %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #ttg.shared_memory> + %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #ttg.shared_memory> -> tensor<{M}x{N}xi32, #src> + %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #ttg.shared_memory> + %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #ttg.shared_memory> -> tensor<{M}x{N}xf16, #src> - %12 = triton_gpu.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> - %13 = triton_gpu.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + %12 = ttg.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> """ ir = layouts + f""" - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %0 = tt.make_range {{end = {M} : i32, start = 0 : 
i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> - %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> - %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> @@ -5449,23 +5449,23 @@ def do_test(src_layout, dst_layout): """ ir = layouts + f""" - module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32}} {{ tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> - %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> - %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> - %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> - %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> tt.store %14, %13 : 
tensor<{M}x{N}x!tt.ptr, #dst> tt.return diff --git a/python/test/unit/language/test_pipeliner.py b/python/test/unit/language/test_pipeliner.py index f2c79fcd5e..d4ac263af0 100644 --- a/python/test/unit/language/test_pipeliner.py +++ b/python/test/unit/language/test_pipeliner.py @@ -277,26 +277,26 @@ def test_pipeline_matmul(scale, device): if is_cuda(): ttgir = handler.asm["ttgir"] if use_tma: - assert ttgir.count("triton_nvidia_gpu.async_tma_copy_global_to_local") != 0, "async tma copy not found" + assert ttgir.count("ttng.async_tma_copy_global_to_local") != 0, "async tma copy not found" assert ttgir.count(f"num = {NUM_STAGES} : i32") == 0, "num_stages not match" # a_tma, b_tma, output_tma, barriar - assert ttgir.count("triton_gpu.local_alloc") == 4, "alloc number not match" - assert ttgir.count("triton_nvidia_gpu.barrier_expect") != 0, "barrier_expect not found" - assert ttgir.count("triton_nvidia_gpu.wait_barrier") != 0, "wait_barrier not found" - assert ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found" + assert ttgir.count("ttg.local_alloc") == 4, "alloc number not match" + assert ttgir.count("ttng.barrier_expect") != 0, "barrier_expect not found" + assert ttgir.count("ttng.wait_barrier") != 0, "wait_barrier not found" + assert ttgir.count("ttng.warp_group_dot") != 0, "warp_group_dot not found" else: # 1. check async - assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found" + assert ttgir.count("ttg.async_copy_global_to_local") != 0, "async copy not found" # 2. check number of stages assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match" # 3. check alloc - assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match" + assert ttgir.count("ttg.local_alloc") == 2, "alloc number not match" # 4. check dot cc = torch.cuda.get_device_capability() if cc[0] >= 9: - ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found" + ttgir.count("ttng.warp_group_dot") != 0, "warp_group_dot not found" else: - ttgir.count("triton_gpu.dot") != 0, "dot not found" + ttgir.count("ttg.dot") != 0, "dot not found" def test_pipeline_vecadd(device): @@ -315,11 +315,11 @@ def test_pipeline_vecadd(device): if is_cuda(): ttgir = handler.asm["ttgir"] # 1. check async - assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found" + assert ttgir.count("ttg.async_copy_global_to_local") != 0, "async copy not found" # 2. check number of stages assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match" # 3. 
check alloc - assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match" + assert ttgir.count("ttg.local_alloc") == 2, "alloc number not match" @pytest.mark.parametrize("ROW_COUNT", [0, 1, 2, 3]) diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 935c495fb3..d80c79cf61 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -426,7 +426,7 @@ def test_compile_link_autotune_matmul(): def test_ttgir_to_ptx(): src = """ -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { tt.return } diff --git a/python/test/unit/tools/test_irsource.py b/python/test/unit/tools/test_irsource.py index 0e7c67cb03..48ed90d8b2 100644 --- a/python/test/unit/tools/test_irsource.py +++ b/python/test/unit/tools/test_irsource.py @@ -18,12 +18,12 @@ def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None: ''' sample_ttgir = r""" -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -54,8 +54,8 @@ def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None: assert src.parse_options()['num_warps'] == 8 sample_ttgir_vector_add = r""" - #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> - module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr 
{tt.divisibility = 16 : i32}, diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index a76cb132ce..f70c46a9d4 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -124,8 +124,8 @@ def make_ir(self, options, codegen_fns, module_map, context): def parse_options(self): if self.ext == "ttgir": - num_warps = self.module.get_int_attr("triton_gpu.num-warps") - assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute" + num_warps = self.module.get_int_attr("ttg.num-warps") + assert num_warps is not None, "Unable to parse ttg.num-warps attribute" return {'num_warps': num_warps} return dict() diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index e67e55fb1c..adea6e3b66 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -1,15 +1,15 @@ // RUN: triton-opt %s --mlir-disable-threading -test-print-alias -split-input-file 2>&1 | FileCheck %s -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED_T = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#B_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: matmul_loop // CHECK-NOT: -> @@ -26,9 +26,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : 
tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> @@ -41,7 +41,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: alloc tt.func @alloc(%A : !tt.ptr) { // CHECK: %0 -> %0 - %cst2 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } @@ -49,40 +49,40 @@ tt.func @alloc(%A : !tt.ptr) { tt.func @alloc_init(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: %0 -> %0 - %cst1 = triton_gpu.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst1 = ttg.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> tt.return } // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: %0 -> %0 - %tensor = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: %1 -> %0 - %b = triton_gpu.memdesc_trans %tensor {order=array} : !triton_gpu.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> + %b = ttg.memdesc_trans %tensor {order=array} : !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x16xf16, #A_SHARED_T, #ttg.shared_memory, mutable> tt.return } // CHECK-LABEL: subview -tt.func @subview(%A : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { +tt.func @subview(%A : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory>) { %index = arith.constant 0 : i32 // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %1 -> %0 - %cst1 = triton_gpu.memdesc_subview %a[%index, %index, %index] : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.memdesc_subview %a[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } // CHECK-LABEL: if_alias tt.func @if_alias(%i1 : i1) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %2 -> %0,%1 - %cst2 = scf.if %i1 -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> { - scf.yield %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = scf.if %i1 -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> { + scf.yield %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { - scf.yield %b : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + 
scf.yield %b : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -90,11 +90,11 @@ tt.func @if_alias(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: %2 -> %2 - %c = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %arg6 -> %0 // CHECK-NEXT: %arg7 -> %1 // CHECK-NEXT: %arg8 -> %2 @@ -102,8 +102,8 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a, %b_shared = %b, %c_shared = %c) -> - (!triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - scf.yield %b_shared, %a_shared, %a_shared : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + (!ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -111,11 +111,11 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -123,14 +123,14 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = 
scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { scf.if %i1 { %index = arith.constant 8 : i32 // CHECK-NEXT: %4 -> %0,%1 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.memdesc_subview %a_shared[%index, %index] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -138,11 +138,11 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -150,23 +150,23 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-NEXT: %3#1 -> %1 // CHECK-NEXT: %3#2 -> %2,%6,%6 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { // CHECK-NEXT: %arg11 -> %2,%6,%6 // CHECK-NEXT: %4 -> %2,%6,%6 - %c_shared_next = 
scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { // CHECK-NEXT: %5 -> %6,%6 - %c_shared_next_next = scf.if %i1 -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %cst0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %cst0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } - scf.yield %c_shared_next_next : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } - scf.yield %a_shared, %b_shared, %c_shared_next : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %a_shared, %b_shared, %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -175,29 +175,29 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr) { %idx = arith.constant 0 : i32 // CHECK: %0 -> %0 - %cst = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %1 -> %1 - %cst_0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %2 -> %0 - %0 = triton_gpu.memdesc_subview %cst[%idx, %idx] : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %0 = ttg.memdesc_subview %cst[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> gpu.barrier // CHECK-NEXT: %3 -> %3 - %cst_1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, 
#ttg.shared_memory, mutable> // CHECK-NEXT: %5 -> %0,%1,%3 // CHECK-NEXT: %6 -> %0,%1,%3 // CHECK-NEXT: %7 -> %0,%1,%3 - cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) -^bb1(%1: index, %2: !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, %3: !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, %4: !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>): // 2 preds: ^bb0, ^bb2 + cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) +^bb1(%1: index, %2: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, %3: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, %4: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>): // 2 preds: ^bb0, ^bb2 %5 = arith.cmpi slt, %1, %arg1 : index cf.cond_br %5, ^bb2, ^bb3 ^bb2: // pred: ^bb1 gpu.barrier %8 = arith.addi %1, %arg2 : index - cf.br ^bb1(%8, %4, %2, %3 : index, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) + cf.br ^bb1(%8, %4, %2, %3 : index, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) ^bb3: // pred: ^bb1 gpu.barrier // CHECK-NEXT: %10 -> %0 - %9 = triton_gpu.memdesc_subview %0[%idx, %idx] : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %9 = ttg.memdesc_subview %0[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index fe4da43ca9..a4dfb20bcb 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -7,22 +7,22 @@ // CHECK-128: scratch offset = {{.*}}, size = 128 // CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}} -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> - -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#AL = #ttg.blocked<{sizePerThread = [1, 4], 
threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#sliceAd0 = #ttg.slice<{dim = 0, parent = #AL}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED_T = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#B_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: empty tt.func @empty(%A : !tt.ptr) { %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> - %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> + %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> tt.return // CHECK: size = 0 } @@ -44,10 +44,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> // CHECK: offset = 0, size = 4608 - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> // CHECK-NEXT: offset = 0, size = 4352 - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> @@ -72,17 +72,17 @@ tt.func @reusable(%A : !tt.ptr) { %b_ptr = tt.splat %A : !tt.ptr -> tensor<32x128x!tt.ptr, #AL> %a1_ = tt.load %a_ptr, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> // CHECK-NEXT: offset = 0, size = 4608 - %a1 = triton_gpu.convert_layout %a1_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a1 = ttg.convert_layout %a1_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %a2_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr, #AL> // CHECK-NEXT: offset = 0, size = 1088 - %a2 = triton_gpu.convert_layout %a2_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> + %a2 = ttg.convert_layout %a2_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> %a3_ = tt.load %a_ptr, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> // CHECK-NEXT: offset = 0, size = 4608 - %a3 = triton_gpu.convert_layout %a3_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a3 = ttg.convert_layout %a3_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %c = tt.dot %a1, %a2, %c_init : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr, #AL> // CHECK-NEXT: offset = 0, size = 1088 - %a4 = triton_gpu.convert_layout %a4_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> + %a4 = ttg.convert_layout %a4_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> %c1 = tt.dot %a3, %a4, %c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> 
tt.return // CHECK-NEXT: size = 4608 @@ -95,47 +95,47 @@ tt.func @reusable(%A : !tt.ptr) { // CHECK-LABEL: preallocate tt.func @preallocate(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 4096, size = 1024 - %b = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 1024 - %c = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - triton_gpu.local_dealloc %cst1 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst2 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 1024 - %cst4 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst4 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 6144, size = 2048 - %e = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %a : !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %e = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %a : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 2048 - %d = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %b : !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %d = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %b : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 10240, size = 2048 - %f = triton_gpu.local_alloc : 
() -> !triton_gpu.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst4 : !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %c : !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %f = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst4 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %c : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 2048 - %cst5 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst5 = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 4096 - %g = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %e : !triton_gpu.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %g = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %e : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 4096 - %h = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %d : !triton_gpu.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %h = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %d : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 4096 - %i = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %f : !triton_gpu.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst5 : !triton_gpu.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %i = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %f : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst5 : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 12288 } @@ -145,11 +145,11 @@ tt.func @preallocate(%A : !tt.ptr) { tt.func @unused(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK-NEXT: offset = 0, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK: size = 1024 } @@ -158,33 +158,33 @@ tt.func @unused(%A : !tt.ptr) { // CHECK-LABEL: longlive tt.func @longlive(%A : !tt.ptr) 
{ // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst1 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst2 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst3 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 512 - %cst4 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %b = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 512 - %cst5 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst5 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 512 - %cst6 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst6 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %c = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst3 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst4 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst3 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst4 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 1024 - %d = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, 
#triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %d = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 4096 } @@ -193,43 +193,43 @@ tt.func @longlive(%A : !tt.ptr) { // CHECK-LABEL: multi_color tt.func @multi_color(%A : !tt.ptr) { // CHECK: offset = 0, size = 64 - %cst = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1536, size = 32 - %cst_0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1664, size = 128 - %cst_1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 - %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> - %1 = triton_gpu.local_load %cst : !triton_gpu.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x8xf16, #AL> + %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = ttg.local_load %cst : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 0, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %2 = triton_gpu.local_load %cst_0 : !triton_gpu.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> + %cst_3 = ttg.local_alloc : () -> !ttg.memdesc<4x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %2 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 - %3 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %3 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> // CHECK-NEXT: offset = 0, size = 256 - %cst_4 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 256, size = 64 - %cst_5 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %4 = triton_gpu.local_load %cst_5 : !triton_gpu.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x8xf16, #AL> - %5 = triton_gpu.local_load %cst_5 : !triton_gpu.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x8xf16, #AL> + %cst_5 = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> + %4 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL> + %5 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL> // 
CHECK-NEXT: offset = 1024, size = 512 - %cst_6 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<8x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_6 = ttg.local_alloc : () -> !ttg.memdesc<8x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1792, size = 128 - %cst_7 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %6 = triton_gpu.local_load %cst_0 : !triton_gpu.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> + %cst_7 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %6 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 - %cst_8 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_8 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 256, size = 32 - %cst_9 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_9 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst_10 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %7 = triton_gpu.local_load %cst_1 : !triton_gpu.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x4xf16, #AL> - %8 = triton_gpu.local_load %cst_4 : !triton_gpu.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x32xf16, #AL> + %cst_10 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %7 = ttg.local_load %cst_1 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL> + %8 = ttg.local_load %cst_4 : !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x32xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 - %9 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %9 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> %cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL> - %10 = triton_gpu.local_load %cst_7 : !triton_gpu.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<2x32xf16, #AL> + %10 = ttg.local_load %cst_7 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL> %cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL> %cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL> // CHECK-NEXT: size = 1920 @@ -240,25 +240,25 @@ tt.func @multi_color(%A : !tt.ptr) { // CHECK-LABEL: multi_color_multi_rounds tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK: offset = 0, size = 32 - %cst = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1280, size = 128 - %cst_0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 8192 - %cst_1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1024x4xf16, 
#A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 - %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> - %1 = triton_gpu.local_load %cst : !triton_gpu.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> + %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1152, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %2 = triton_gpu.local_load %cst : !triton_gpu.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> + %cst_3 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %2 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst_4 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %3 = triton_gpu.local_load %cst_0 : !triton_gpu.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x4xf16, #AL> - %4 = triton_gpu.local_load %cst_1 : !triton_gpu.memdesc<1024x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<1024x4xf16, #AL> + %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %3 = ttg.local_load %cst_0 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL> + %4 = ttg.local_load %cst_1 : !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<1024x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 - %5 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> - %6 = triton_gpu.local_load %cst_3 : !triton_gpu.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<2x32xf16, #AL> + %5 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %6 = ttg.local_load %cst_3 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL> // CHECK-NEXT: size = 10240 tt.return } @@ -267,10 +267,10 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK-LABEL: alloc tt.func @alloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 512 } @@ -279,10 +279,10 @@ tt.func @alloc(%A : !tt.ptr) { // CHECK-LABEL: dealloc tt.func @dealloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> 
!ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: offset = 1024, size = 1024 - %cst1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 2048 } @@ -303,8 +303,8 @@ tt.func @scratch() { // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %tensor = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %b = triton_gpu.memdesc_trans %tensor {order=array} : !triton_gpu.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %b = ttg.memdesc_trans %tensor {order=array} : !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x16xf16, #A_SHARED_T, #ttg.shared_memory, mutable> tt.return } @@ -312,9 +312,9 @@ tt.func @trans(%A : !tt.ptr) { // CHECK-LABEL: extract_slice tt.func @extract_slice(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %index = arith.constant 0 : i32 - %cst1 = triton_gpu.memdesc_subview %cst0[%index, %index, %index] : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.memdesc_subview %cst0[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 512 } @@ -326,9 +326,9 @@ tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { // CHECK: size = 8196 %c0_i32 = arith.constant 0 : i32 %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 - %3 = triton_gpu.local_load %2 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return %4 : i32 } @@ -338,9 +338,9 @@ tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { // CHECK: size = 8192 %c0_i32 = arith.constant 0 : i32 %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 - %3 = triton_gpu.local_load %2 : !triton_gpu.memdesc<128x32xf16, 
#A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -349,25 +349,25 @@ tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { // CHECK-LABEL: if tt.func @if(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst1 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst3 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst2 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst3 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst3 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 3072 } @@ -377,28 +377,28 @@ tt.func @if(%i1 : i1) { // CHECK-LABEL: if_else tt.func @if_else(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> 
!ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst3 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 4096, size = 1024 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst2 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst3 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst3 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst1 : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 5120 } @@ -408,13 +408,13 @@ tt.func @if_else(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> 
(!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - scf.yield %b_shared, %a_shared, %a_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 24576 @@ -423,18 +423,18 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if_slice tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { scf.if %i1 { %index = arith.constant 8 : i32 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.memdesc_subview %a_shared[%index, %index] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, 
#A_SHARED, #ttg.shared_memory, mutable> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 24576 @@ -444,16 +444,16 @@ tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr // CHECK-LABEL: for_use_ancestor tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - %c0 = triton_gpu.memdesc_trans %c_shared_init {order=array} : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<32x128xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + %c0 = ttg.memdesc_trans %c_shared_init {order=array} : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #A_SHARED_T, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 24576, size = 8192 - %c1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %b_shared, %a_shared: !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %b_shared, %a_shared: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 32768 @@ -464,40 +464,40 @@ tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr< // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // 
CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - %c_shared_next_next = scf.if %i1 -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> { + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> { // CHECK-NEXT: offset = 24576, size = 8192 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %cst0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { // CHECK-NEXT: offset = 32768, size = 8192 - %cst1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %cst1 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst1 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } - scf.yield %c_shared_next_next : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } - scf.yield %a_shared, %b_shared, %c_shared_next : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, 
#triton_gpu.shared_memory, mutable> + scf.yield %a_shared, %b_shared, %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } // CHECK-NEXT: offset = 0, size = 8192 - %cst2 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 40960 } } -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-warps" = 4 : i32} { // CHECK-LABEL: alloc1 tt.func @alloc1(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 512 } @@ -505,7 +505,7 @@ tt.func @alloc1(%A : !tt.ptr) { // CHECK-LABEL: alloc2 tt.func @alloc2(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 1024 } @@ -514,10 +514,10 @@ tt.func @alloc2(%A : !tt.ptr) { tt.func @alloc3(%cond : i1) { scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { // CHECK-NEXT: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 1024 @@ -539,7 +539,7 @@ tt.func @alloc4(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: single_call tt.func @single_call(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () @@ -550,7 +550,7 @@ tt.func @single_call(%A : !tt.ptr) { // CHECK-LABEL: multiple_calls tt.func @multiple_calls(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> @@ -565,9 +565,9 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, 
mutable> // CHECK-NEXT: offset = 0, size = 1024 - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !triton_gpu.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () } else { @@ -582,7 +582,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: for_calls tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> %lb = arith.constant 0 : index %ub = arith.constant 10 : index @@ -598,7 +598,7 @@ tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_1 tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc3(%cond) : (i1) -> () tt.return @@ -608,7 +608,7 @@ tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_2 tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc4(%A, %cond) : (!tt.ptr, i1) -> () tt.return diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index a2711ba98f..29e0b253be 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -1,15 +1,15 @@ // RUN: triton-opt %s -split-input-file --mlir-disable-threading --convert-scf-to-cf --allocate-shared-memory -test-print-membar 2>&1 | FileCheck %s -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#sliceAd0 = #ttg.slice<{dim = 0, parent = #AL}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED_T = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, 
kWidth = 2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: matmul_loop // There shouldn't be any membar with the dot op encoding. @@ -28,9 +28,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> @@ -46,10 +46,10 @@ tt.func @raw_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -59,14 +59,14 @@ tt.func @war_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - // CHECK: triton_gpu.local_alloc + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: ttg.local_alloc // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier - // CHECK-NEXT: %4 = triton_gpu.local_alloc - %4 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: %4 = ttg.local_alloc + %4 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> tt.return } @@ -76,25 +76,25 @@ tt.func @war_single_block_local_store(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : 
tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - // CHECK: triton_gpu.local_alloc + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_alloc // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_store - triton_gpu.local_store %1, %2 : tensor<128x32xf16, #AL> -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttg.local_store + ttg.local_store %1, %2 : tensor<128x32xf16, #AL> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } // CHECK-LABEL: scratch tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load + // CHECK-NEXT: ttg.local_load // CHECK: gpu.barrier // CHECK: tt.reduce - %1 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> %2 = "tt.reduce" (%1) ({ ^bb0(%arg1: f16, %arg2: f16): %add = arith.addf %arg1, %arg2 : f16 @@ -105,34 +105,34 @@ tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { // CHECK-LABEL: async_wait tt.func @async_wait(%arg: tensor<32x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> - // CHECK: triton_gpu.async_wait - triton_gpu.async_wait {num = 4 : i32} + %cst0 = ttg.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: ttg.async_wait + ttg.async_wait {num = 4 : i32} // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<32x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<32x16xf16, #AL> tt.return } // CHECK-LABEL: subview tt.func @subview() { %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> - %a = triton_gpu.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = ttg.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> %index = arith.constant 0 : i32 - %0 = triton_gpu.memdesc_subview %a[%index, %index] : !triton_gpu.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %0 = ttg.memdesc_subview %a[%index, %index] : !ttg.memdesc<32x16xf16, #A_SHARED, 
#ttg.shared_memory> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> tt.return } // CHECK-LABEL: trans -tt.func @trans(%a: !triton_gpu.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { +tt.func @trans(%a: !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK-NOT: gpu.barrier - %b = triton_gpu.memdesc_trans %a {order=array} : !triton_gpu.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> + %b = ttg.memdesc_trans %a {order=array} : !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory> -> !ttg.memdesc<32x16xf16, #A_SHARED_T, #ttg.shared_memory> tt.return } @@ -142,31 +142,31 @@ tt.func @async_copy_global_to_local(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.splat %A : !tt.ptr -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : i1 -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %alloc = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %subview = triton_gpu.memdesc_subview %alloc[%index, %index, %index] : !triton_gpu.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %subview = ttg.memdesc_subview %alloc[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %subview : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %4 = ttg.local_load %subview : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // If branch inserted a barrier for %cst0, but else didn't, then the barrier should be inserted in the parent region // CHECK-LABEL: multi_blocks tt.func @multi_blocks(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> scf.if %i1 { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = 
triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> scf.yield } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -174,21 +174,21 @@ tt.func @multi_blocks(%i1 : i1) { // CHECK-LABEL: multi_blocks_join_barrier tt.func @multi_blocks_join_barrier(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> scf.if %i1 { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK-NOT: gpu.barrier // CHECK: tt.return - %a_ = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %a_ = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -196,25 +196,25 @@ tt.func @multi_blocks_join_barrier(%i1 : i1) { // CHECK-LABEL: multi_blocks_yield tt.func @multi_blocks_yield(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - %a = scf.if %i1 -> (!triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %a = scf.if %i1 -> (!ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %1 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> 
tensor<16x16xf16, #AL> + %1 = ttg.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> } else { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %3 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %3 = ttg.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %3 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> } - %a_ = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - // CHECK: triton_gpu.local_load + %a_ = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK: ttg.local_load // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %4 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -222,27 +222,27 @@ tt.func @multi_blocks_yield(%i1 : i1) { // CHECK-LABEL: multi_blocks_entry_no_shared tt.func @multi_blocks_entry_no_shared(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - %a = scf.if %i1 -> (!triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %a = scf.if %i1 -> (!ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc + // CHECK-NEXT: ttg.local_alloc // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load + // CHECK-NEXT: ttg.local_load // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - %0 = triton_gpu.local_load %cst1 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %1 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %cst1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %0 = ttg.local_load %cst1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> } else { // CHECK-NOT: gpu.barrier - // CHECK: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc 
%cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %cst1 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK: ttg.local_alloc + %cst1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -250,16 +250,16 @@ tt.func @multi_blocks_entry_no_shared(%i1 : i1) { // CHECK-LABEL: multi_blocks_noelse tt.func @multi_blocks_noelse(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> scf.if %i1 { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -267,39 +267,39 @@ tt.func @multi_blocks_noelse(%i1 : i1) { // CHECK-LABEL: multi_blocks_nested_scf tt.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> scf.if %i1 { scf.if %i2 { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> scf.yield } scf.yield } else { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> scf.yield } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } // CHECK-LABEL: for tt.func @for(%lb 
: index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - scf.yield %b_shared, %a_shared, %a_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } tt.return } @@ -309,24 +309,24 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_alias tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, 
#AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %0 = ttg.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a1 = triton_gpu.local_load %a_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - scf.yield %c_shared, %a_shared, %b_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %a1 = ttg.local_load %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = ttg.local_load %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared, %a_shared, %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %r = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -335,63 +335,63 @@ tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, % // CHECK-LABEL: for_reuse tt.func @for_reuse(%lb : index, %ub : index, %step : 
index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %0 = ttg.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %a1 = triton_gpu.local_load %a_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %a1 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %1 = ttg.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %a2 = triton_gpu.local_load %a_shared_init : !triton_gpu.memdesc<128x32xf16, 
#A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b2 = triton_gpu.local_load %b_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %c_shared, %a_shared, %b_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %a2 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b2 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %2 = ttg.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %c_shared, %a_shared, %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %r = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } // CHECK-LABEL: for_reuse_nested tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, 
#ttg.shared_memory> -> tensor<128x32xf16, #AL> + %0 = ttg.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %a1 = triton_gpu.local_load %a_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + // CHECK-NEXT: ttg.local_alloc + %a1 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %1 = ttg.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %a2 = triton_gpu.local_load %a_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b2 = triton_gpu.local_load %b_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %a2 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b2 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %2 = ttg.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : 
!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } - scf.yield %c_shared, %a_shared, %b_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared, %a_shared, %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %r = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -399,25 +399,25 @@ tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr< // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { - %c_shared_next_next = scf.if %i1 -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, 
#triton_gpu.shared_memory> - scf.yield %cst0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } else { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %cst0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } - scf.yield %c_shared_next_next : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } - scf.yield %a_shared, %b_shared, %c_shared_next : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %a_shared, %b_shared, %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } tt.return } @@ -426,30 +426,30 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: for_if_for tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - %c_blocked = triton_gpu.local_load %c_shared_init : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %c_blocked = ttg.local_load %c_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { - %c_shared_next_next = scf.if %i1 -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step 
iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %cst0 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } else { - %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %c_blocked_next = triton_gpu.local_load %c_shared_next : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - scf.yield %c_shared : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %c_blocked_next = ttg.local_load %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } - scf.yield %c_shared_ : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared_ : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } // CHECK-NOT: gpu.barrier - %b_blocked_next = triton_gpu.local_load %b_shared: !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - scf.yield %a_shared, %b_shared, %c_shared_next_next : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_blocked_next = ttg.local_load %b_shared: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %a_shared, %b_shared, %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } tt.return } @@ -457,65 +457,65 @@ tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: cf_if tt.func @cf_if(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, 
#ttg.shared_memory> -> tensor<16x16xf16, #AL> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: cf_if_else tt.func @cf_if_else(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - cf.br ^bb3(%1 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + cf.br ^bb3(%1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) ^bb2: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - cf.br ^bb3(%3 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) -^bb3(%arg: !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>): // 2 preds: ^bb1, ^bb2 + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %3 = ttg.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + cf.br ^bb3(%3 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) +^bb3(%arg: !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>): // 2 preds: ^bb1, ^bb2 cf.br ^bb4 ^bb4: // pred: ^bb3 - // CHECK: triton_gpu.local_load - %4 = triton_gpu.local_load %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK: ttg.local_load + %4 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %5 = triton_gpu.local_load %arg : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %5 = ttg.local_load %arg : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: cf_if_else_return tt.func @cf_if_else_return(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - %b = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> 
!triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %b = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_load %b : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_load %b : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return ^bb2: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_load %b : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %3 = ttg.local_load %b : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -524,9 +524,9 @@ tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { // CHECK-NOT: gpu.barrier %c0_i32 = arith.constant 0 : i32 %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 - %3 = triton_gpu.local_load %2 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return %4 : i32 } @@ -534,53 +534,53 @@ tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { %c0_i32 = arith.constant 0 : i32 %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !triton_gpu.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } } -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: convert_layout1 tt.func @convert_layout1(%A : !tt.ptr) { // CHECK-NOT: gpu.barrier - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, 
#triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: convert_layout2 tt.func @convert_layout2(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - // CHECK: triton_gpu.local_load + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load // CHECK-NEXT: gpu.barrier - // CHECK: triton_gpu.local_load - %3 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> - %4 = triton_gpu.local_load %1 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> + // CHECK: ttg.local_load + %3 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %4 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: convert_layout3 tt.func @convert_layout3(%cond : i1) { scf.if %cond { - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - // CHECK: triton_gpu.local_load + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load // CHECK-NOT: gpu.barrier - %1 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x64xf16, #AL> + %1 = ttg.local_load %0 : !ttg.memdesc<16x64xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x64xf16, #AL> } else { - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - // CHECK: triton_gpu.local_load + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %1 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttg.local_alloc + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -602,7 +602,7 @@ tt.func @single_call_sync(%A : !tt.ptr) { // CHECK: tt.call // CHECK-NEXT: gpu.barrier tt.call @convert_layout1(%A) : (!tt.ptr) -> () - %1 = triton_gpu.convert_layout %0 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = ttg.convert_layout %0 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, 
#BL> tt.return } @@ -612,14 +612,14 @@ tt.func @single_call_no_sync(%A : !tt.ptr) { // CHECK-NOT: gpu.barrier %0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> tt.call @convert_layout2(%A) : (!tt.ptr) -> () - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #BL> + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #BL> tt.return } // CHECK-LABEL: multiple_calls tt.func @multiple_calls(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> tt.call @convert_layout1(%A) : (!tt.ptr) -> () %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> tt.call @convert_layout2(%A) : (!tt.ptr) -> () @@ -631,12 +631,12 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { scf.if %cond { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> %cst_ = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: tt.call // CHECK-NEXT: gpu.barrier tt.call @convert_layout1(%A) : (!tt.ptr) -> () - %cst1 = triton_gpu.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !triton_gpu.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst1 = ttg.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory> } else { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK: tt.call @@ -649,7 +649,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: for_calls tt.func @for_calls(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> %lb = arith.constant 0 : index %ub = arith.constant 10 : index @@ -665,7 +665,7 @@ tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_1 tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: tt.call tt.call @convert_layout3(%cond) : (i1) -> () tt.return @@ -677,7 +677,7 @@ tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { tt.call @convert_layout4(%A, %cond) : (!tt.ptr, i1) -> () // CHECK: tt.call // CHECK-NEXT: gpu.barrier - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !triton_gpu.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> tt.return } @@ -685,28 +685,28 @@ tt.func 
@call_graph_2(%A : !tt.ptr, %cond : i1) { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { tt.func public @kernel(%arg3: !tt.ptr, %arg4: !tt.ptr, %arg12: tensor<32x128xf16, #blocked>, %arg13: tensor<32x128xf32, #blocked>, %arg14: tensor<32x32xf16, #blocked1>) { %c0_i32 = arith.constant 0 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #blocked> - %37 = triton_gpu.local_alloc %arg14 {allocation.offset = 0 : i32} : (tensor<32x32xf16, #blocked1>) -> !triton_gpu.memdesc<32x32xf16, #shared, #triton_gpu.shared_memory> - %58 = triton_gpu.local_alloc %arg12 : (tensor<32x128xf16, #blocked>) -> !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory> + %37 = ttg.local_alloc %arg14 {allocation.offset = 0 : i32} : (tensor<32x32xf16, #blocked1>) -> !ttg.memdesc<32x32xf16, #shared, #ttg.shared_memory> + %58 = ttg.local_alloc %arg12 : (tensor<32x128xf16, #blocked>) -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> cf.br ^bb1 ^bb1: // 2 preds: ^bb0, ^bb1 %59 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 %60 = arith.cmpi eq, %59, %c0_i32 : i32 cf.cond_br %60, ^bb1, ^bb2 ^bb2: // pred: ^bb1 - %72 = triton_gpu.convert_layout %arg13 : tensor<32x128xf32, #blocked> -> tensor<32x128xf32, #mma> - %73 = triton_gpu.local_load %37 : !triton_gpu.memdesc<32x32xf16, #shared, #triton_gpu.shared_memory> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %74 = triton_gpu.local_load %58 : !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %75 = tt.dot %73, %74, %72, inputPrecision = tf32 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x128xf32, #mma> - %76 = triton_gpu.convert_layout %75 {allocation.offset = 0 : i32} : tensor<32x128xf32, #mma> -> tensor<32x128xf32, #blocked> + %72 = ttg.convert_layout %arg13 : tensor<32x128xf32, #blocked> -> tensor<32x128xf32, #mma> + %73 = ttg.local_load %37 : !ttg.memdesc<32x32xf16, #shared, #ttg.shared_memory> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, 
kWidth = 2}>> + %74 = ttg.local_load %58 : !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %75 = tt.dot %73, %74, %72, inputPrecision = tf32 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x128xf32, #mma> + %76 = ttg.convert_layout %75 {allocation.offset = 0 : i32} : tensor<32x128xf32, #mma> -> tensor<32x128xf32, #blocked> %77 = arith.truncf %76 : tensor<32x128xf32, #blocked> to tensor<32x128xf16, #blocked> %78 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> tt.store %78, %77 : tensor<32x128x!tt.ptr, #blocked> @@ -716,54 +716,54 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { // CHECK-LABEL: tma_special_cases tt.func @tma_special_cases(%arg1: !tt.ptr) -> (tensor<256x64xf16, #blocked>){ %true = arith.constant 1 : i1 %c0 = arith.constant 0 : i32 - %barrier = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable> - %alloc = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - // CHECK: triton_nvidia_gpu.init_barrier - // CHECK-NEXT: triton_nvidia_gpu.init_barrier - triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + // CHECK: ttng.init_barrier + // CHECK-NEXT: ttng.init_barrier + ttng.init_barrier %barrier, 1 : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.init_barrier %barrier, 1 : <1xi64, #shared1, #ttg.shared_memory, mutable> // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_nvidia_gpu.barrier_expect - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local - // CHECK-NEXT: triton_nvidia_gpu.wait_barrier - triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local - // CHECK-NEXT: 
triton_nvidia_gpu.barrier_expect + // CHECK-NEXT: ttng.barrier_expect + // CHECK-NEXT: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.wait_barrier + ttng.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #ttg.shared_memory, mutable> -> <256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : <1xi64, #shared1, #ttg.shared_memory, mutable> + + // CHECK-NEXT: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.barrier_expect // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_nvidia_gpu.wait_barrier - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttng.wait_barrier + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #ttg.shared_memory, mutable> -> <256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : <1xi64, #shared1, #ttg.shared_memory, mutable> - // CHECK-NEXT: triton_gpu.local_load - %t = triton_gpu.local_load %alloc : !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked> + // CHECK-NEXT: ttg.local_load + %t = ttg.local_load %alloc : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #blocked> - // CHECK-NEXT: triton_nvidia_gpu.barrier_expect + // CHECK-NEXT: ttng.barrier_expect // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local - // CHECK-NEXT: triton_nvidia_gpu.wait_barrier - triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.wait_barrier + ttng.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #ttg.shared_memory, mutable> -> <256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : <1xi64, #shared1, #ttg.shared_memory, mutable> // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_nvidia_gpu.inval_barrier - // CHECK-NEXT: triton_nvidia_gpu.inval_barrier - triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttng.inval_barrier + // CHECK-NEXT: ttng.inval_barrier + ttng.inval_barrier %barrier : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.inval_barrier %barrier : <1xi64, #shared1, #ttg.shared_memory, mutable> tt.return %t : tensor<256x64xf16, #blocked> } 
@@ -771,38 +771,38 @@ tt.func @tma_special_cases(%arg1: !tt.ptr) -> (tensor<256x64xf16, #blocke // ----- -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { // CHECK-LABEL: tma_special_cases_cf tt.func @tma_special_cases_cf(%arg1: !tt.ptr, %i1 : i1, %arg2: tensor<256x64xf16, #blocked>) -> (tensor<256x64xf16, #blocked>){ %true = arith.constant 1 : i1 %c0 = arith.constant 0 : i32 - %barrier = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable> - %alloc = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> // CHECK: cf.cond_br scf.if %i1 { // CHECK-NOT: gpu.barrier - // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local - // CHECK-NEXT: triton_nvidia_gpu.barrier_expect - // CHECK-NEXT: triton_nvidia_gpu.wait_barrier + // CHECK: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.barrier_expect + // CHECK-NEXT: ttng.wait_barrier // CHECK-NEXT: cf.br - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #ttg.shared_memory, mutable> -> <256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : <1xi64, #shared1, #ttg.shared_memory, mutable> scf.yield } else { // CHECK-NOT: gpu.barrier - // CHECK: triton_gpu.local_store + // CHECK: ttg.local_store // CHECK-NEXT: cf.br - triton_gpu.local_store %arg2, %alloc : tensor<256x64xf16, #blocked> -> !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + ttg.local_store %arg2, %alloc : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> scf.yield } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %t = triton_gpu.local_load %alloc : !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked> + // CHECK-NEXT: ttg.local_load + %t = ttg.local_load %alloc : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> 
tensor<256x64xf16, #blocked> tt.return %t : tensor<256x64xf16, #blocked> } } diff --git a/test/Conversion/amd/buffer_load_store.mlir b/test/Conversion/amd/buffer_load_store.mlir index 209c7065d8..70abc55594 100644 --- a/test/Conversion/amd/buffer_load_store.mlir +++ b/test/Conversion/amd/buffer_load_store.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load tt.func @buffer_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 @@ -14,8 +14,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_mask tt.func @buffer_load_mask(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { %c256_i32 = arith.constant 256 : i32 @@ -36,8 +36,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_mask_other tt.func @buffer_load_mask_other(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { %c256_i32 = arith.constant 256 : i32 @@ -60,8 +60,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_store tt.func @buffer_store(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : 
tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 @@ -74,8 +74,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_store_mask tt.func @buffer_store_mask(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { %c256_i32 = arith.constant 256 : i32 @@ -97,8 +97,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_store_vec4 tt.func @buffer_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -123,8 +123,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_store_vec1 tt.func @buffer_load_store_vec1(%arg0: !tt.ptr , %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -151,8 +151,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_store_vec2 tt.func @buffer_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr{tt.divisibility = 4 : i32}, %arg2: !tt.ptr{tt.divisibility = 4: i32}, %arg3: i32{tt.divisibility = 4: i32}) { %c256_i32 = arith.constant 256 : i32 diff --git a/test/Conversion/amd/builtin_func_to_llvm.mlir 
b/test/Conversion/amd/builtin_func_to_llvm.mlir index 06ef06c542..6458817302 100644 --- a/test/Conversion/amd/builtin_func_to_llvm.mlir +++ b/test/Conversion/amd/builtin_func_to_llvm.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm="ftz=True" | FileCheck %s --check-prefix=LLVM_FTZ // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm="ftz=False" | FileCheck %s --check-prefix=LLVM_NO_FTZ -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @test_fast_expf(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { // LLVM_FTZ: llvm.amdgcn.exp2.f32 // LLVM_NO_FTZ: llvm.exp2.f32 diff --git a/test/Conversion/amd/compute-base-ptr.mlir b/test/Conversion/amd/compute-base-ptr.mlir index c62f7bfb6c..84b0ffce2e 100644 --- a/test/Conversion/amd/compute-base-ptr.mlir +++ b/test/Conversion/amd/compute-base-ptr.mlir @@ -1,16 +1,16 @@ // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 544 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}> +#shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @local_load_offset tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) { - %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1) - %1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> loc(#loc2) + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1) + %1 = ttg.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> loc(#loc2) // This catches base ptr calculation in the computeBasePtr, checks if the gep has correct element type. 
// CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0 - %2 = triton_gpu.local_load %1 : !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3) + %2 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3) tt.return } } diff --git a/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir b/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir index f30e0aa6d9..848e13118e 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir @@ -1,33 +1,33 @@ // RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx942 | FileCheck %s -// CHECK-DAG: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK-DAG: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> -// CHECK-DAG: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK-DAG: #[[DST_ENC:.+]] = #ttg.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #ttg.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #ttg.amd_mfma<{{.*}}> // CHECK: large_tensor_conversion -#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = false}> -#dst = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#src = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = false}> +#dst = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func @large_tensor_conversion(%arg0: tensor<128x128xf32, #src>) { - // CHECK: %[[TMP:.*]] = triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #[[SRC_ENC]]> -> tensor<128x128xf32, #[[TMP_ENC]]> - // CHECK: {{.*}} = triton_gpu.convert_layout %[[TMP]] : tensor<128x128xf32, #[[TMP_ENC]]> -> tensor<128x128xf32, #[[DST_ENC]]> - %0 = triton_gpu.convert_layout %arg0 : tensor<128x128xf32, #src> -> tensor<128x128xf32, #dst> + // CHECK: %[[TMP:.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, #[[SRC_ENC]]> -> tensor<128x128xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = ttg.convert_layout %[[TMP]] : tensor<128x128xf32, #[[TMP_ENC]]> -> tensor<128x128xf32, #[[DST_ENC]]> + %0 = ttg.convert_layout %arg0 : tensor<128x128xf32, #src> -> tensor<128x128xf32, #dst> tt.return } } // ----- -// CHECK-DAG: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK-DAG: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> -// CHECK-DAG: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK-DAG: #[[DST_ENC:.+]] = #ttg.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #ttg.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #ttg.amd_mfma<{{.*}}> // CHECK: large_tensor_3d_conversion -#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 1, 2], instrShape = [32, 32], isTransposed = false}> -#dst = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 64, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, 
"triton_gpu.threads-per-warp" = 64 : i32} { +#src = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 1, 2], instrShape = [32, 32], isTransposed = false}> +#dst = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 64, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func @large_tensor_3d_conversion(%arg0: tensor<2x128x64xf32, #src>) { - // CHECK: %[[TMP:.*]] = triton_gpu.convert_layout {{.*}} : tensor<2x128x64xf32, #[[SRC_ENC]]> -> tensor<2x128x64xf32, #[[TMP_ENC]]> - // CHECK: {{.*}} = triton_gpu.convert_layout %[[TMP]] : tensor<2x128x64xf32, #[[TMP_ENC]]> -> tensor<2x128x64xf32, #[[DST_ENC]]> - %0 = triton_gpu.convert_layout %arg0 : tensor<2x128x64xf32, #src> -> tensor<2x128x64xf32, #dst> + // CHECK: %[[TMP:.*]] = ttg.convert_layout {{.*}} : tensor<2x128x64xf32, #[[SRC_ENC]]> -> tensor<2x128x64xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = ttg.convert_layout %[[TMP]] : tensor<2x128x64xf32, #[[TMP_ENC]]> -> tensor<2x128x64xf32, #[[DST_ENC]]> + %0 = ttg.convert_layout %arg0 : tensor<2x128x64xf32, #src> -> tensor<2x128x64xf32, #dst> tt.return } } diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir index 9e6acf2e4b..c0d4ea1edb 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -1,33 +1,33 @@ // RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s -// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> -// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK: #[[$BLOCKED:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #ttg.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}> // CHECK-LABEL: wmma_to_wmma_dot_op -#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1130", "ttg.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) { - // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !triton_gpu.memdesc<16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> - // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<16x16xf16, #[[$SHARED]], #ttg.shared_memory> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> + %0 = ttg.convert_layout %arg0 : 
tensor<16x16xf16, #mma> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return } } // ----- -// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> -// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK: #[[$BLOCKED:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #ttg.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}> // CHECK-LABEL: wmma_to_wmma_dot3d_op -#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) { - // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !triton_gpu.memdesc<2x16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> - // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> - %0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<2x16x16xf16, #[[$SHARED]], #ttg.shared_memory> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> + %0 = ttg.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return } } @@ -35,13 +35,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- // CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1130", "ttg.threads-per-warp" = 32 : i32} { tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.local_alloc - // CHECK: triton_gpu.convert_layout - // CHECK-NOT: triton_gpu.local_alloc - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + // CHECK-NOT: ttg.local_alloc + // CHECK: ttg.convert_layout + // CHECK-NOT: ttg.local_alloc + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -49,13 +49,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: 
blocked_to_dot_op_shortcut_gfx940 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @blocked_to_dot_op_shortcut_gfx940(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.local_alloc - // CHECK: triton_gpu.convert_layout - // CHECK-NOT: triton_gpu.local_alloc - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + // CHECK-NOT: ttg.local_alloc + // CHECK: ttg.convert_layout + // CHECK-NOT: ttg.local_alloc + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -63,13 +63,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx940 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx940(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.convert_layout - // CHECK: triton_gpu.local_alloc - // CHECK: triton_gpu.local_load - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: ttg.local_alloc + // CHECK: ttg.local_load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -77,14 +77,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx940 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @neg_blocked_to_dot_op_incompatible_threads_gfx940(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.convert_layout - // CHECK: 
triton_gpu.local_alloc - // CHECK: triton_gpu.local_load - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: ttg.local_alloc + // CHECK: ttg.local_load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> tt.return } } @@ -92,14 +92,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx940 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.convert_layout - // CHECK: triton_gpu.local_alloc - // CHECK: triton_gpu.local_load - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: ttg.local_alloc + // CHECK: ttg.local_load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> tt.return } } diff --git a/test/Conversion/amd/dedup-by-constancy.mlir b/test/Conversion/amd/dedup-by-constancy.mlir index 8340cce6d1..66a224bcef 100644 --- a/test/Conversion/amd/dedup-by-constancy.mlir +++ b/test/Conversion/amd/dedup-by-constancy.mlir @@ -13,13 +13,13 @@ // only allows duplication within each group of 4 elemnets. Therefore, we expect 4 icmp, one // for each group of 4 elements. // In the future, we can reduce the icmp to 2 in such case. 
-#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @dedup_by_constancy_mfma(%arg0: i32 {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %1 = tt.splat %arg0 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %2 = arith.cmpi slt, %0, %1 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi1, #mma> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %1 = tt.splat %arg0 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %2 = arith.cmpi slt, %0, %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi1, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi1, #mma> %4 = tt.broadcast %3 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma> %cst = arith.constant dense<0.100000e+00> : tensor<32x32xf16, #mma> %5 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #mma> diff --git a/test/Conversion/amd/fp_to_fp.mlir b/test/Conversion/amd/fp_to_fp.mlir index aaa70564fd..959158ab49 100644 --- a/test/Conversion/amd/fp_to_fp.mlir +++ b/test/Conversion/amd/fp_to_fp.mlir @@ -1,11 +1,11 @@ // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s // CHECK-LABEL: f16_to_f32 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func @f16_to_f32(%arg0: tensor<8x8xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @f16_to_f32(%arg0: tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>) { // CHECK-COUNT-8: llvm.inline_asm asm_dialect {{.*}}v_cvt_f32_f16 {{.*}}: (f16) -> f32 - %0 = tt.fp_to_fp %arg0 : tensor<8x8xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<8x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %0 = tt.fp_to_fp %arg0 : tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -13,11 +13,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: bf16_to_f32 -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, 
"triton_gpu.threads-per-warp" = 32 : i32} { - tt.func @bf16_to_f32(%arg0: tensor<8x8xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>) { +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @bf16_to_f32(%arg0: tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) { // CHECK-COUNT-8: llvm.bitcast - %0 = tt.fp_to_fp %arg0 : tensor<8x8xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> + %0 = tt.fp_to_fp %arg0 : tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> tt.return } } diff --git a/test/Conversion/amd/invalid_extractslice_to_llvm.mlir b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir index e561dfb269..9730f9eace 100644 --- a/test/Conversion/amd/invalid_extractslice_to_llvm.mlir +++ b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics // Invalid size -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTATile [256, 16]}} %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x2xi32, #blocked1> @@ -11,7 +11,7 @@ tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibili // ----- // Invalid zero source dimension -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_size_input(%arg0: tensor<256x0xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{source tensor dimension size zero at dimension 1}} %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x0xi32, #blocked1> to tensor<256x16xi32, #blocked1> @@ -21,7 +21,7 @@ tt.func @invalid_size_input(%arg0: tensor<256x0xi32, #blocked1> {tt.divisibility // ----- // Invalid zero result dimension -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{result tensor dimension size zero at dimension 1}} %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x0xi32, #blocked1> @@ -31,7 +31,7 @@ tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> 
{tt.divisibili // ----- // Invalid offset, not multiple of shapePerTile -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTATile [256, 16]}} %1 = amdgpu.extract_slice %arg0 [0,5] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> @@ -41,7 +41,7 @@ tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibi // ----- // Invalid offset, out of bounds for dimension -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{invalid offset 128 at dimension 1}} %1 = amdgpu.extract_slice %arg0 [0,128] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> @@ -51,8 +51,8 @@ tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibi // ----- // Invalid result layout -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{result layout must match source layout}} %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2> @@ -62,7 +62,7 @@ tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisib // ----- // Invalid result element type -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{result element type must match source element type}} %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1> @@ -72,7 +72,7 @@ tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.d // ----- // Invalid 
result rank -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{result rank must be equal to source rank}} %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> @@ -82,7 +82,7 @@ tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibil // ----- // Invalid result shape -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{result shape cannot be larger than input shape at dimension 1}} %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x256xi32, #blocked1> @@ -92,7 +92,7 @@ tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibil // ----- // Invalid rank -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) { // expected-error @+1 {{currently only 2D tensors are supported}} %1 = amdgpu.extract_slice %arg0 [0,0,0] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> @@ -102,7 +102,7 @@ tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = // ----- // Invalid non static offset -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { // expected-error @+2 {{expected ']'}} // expected-error @+1 {{expected integer value}} diff --git a/test/Conversion/amd/load_store.mlir b/test/Conversion/amd/load_store.mlir index 57b7c500d2..25336d2552 100644 --- a/test/Conversion/amd/load_store.mlir +++ b/test/Conversion/amd/load_store.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], 
threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec8 tt.func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -30,21 +30,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: global_store_mfma_vec16 tt.func public @global_store_mfma_vec16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> %1 = math.exp2 %0 : tensor<32x32xf32, #mma> %2 = arith.truncf %1 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> %c32_i32 = arith.constant 32 : i32 %100 = tt.get_program_id x : i32 %101 = arith.muli %100, %c32_i32 : i32 - %102 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %300 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma> + %102 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %300 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma> %200 = tt.broadcast %300 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma> %103 = tt.splat %101 : i32 -> tensor<32x32xi32, #mma> %104 = arith.addi %103, %200 : tensor<32x32xi32, #mma> diff --git a/test/Conversion/amd/math-denorm-handling.mlir b/test/Conversion/amd/math-denorm-handling.mlir index 520f44db93..86c08ca2ae 100644 --- a/test/Conversion/amd/math-denorm-handling.mlir +++ b/test/Conversion/amd/math-denorm-handling.mlir @@ -2,8 +2,8 @@ // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_NO_FTZ -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], 
warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { // LLVM_FTZ: llvm.amdgcn.exp2.f32 // LLVM_NO_FTZ: llvm.exp2.f32 @@ -14,8 +14,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { // LLVM_FTZ: llvm.exp2.f32 // LLVM_NO_FTZ: llvm.exp2.f32 diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index a2c8f48718..9a9764d992 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -1,29 +1,29 @@ // RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> -#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: shortcut_mfma16 tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK-NOT: store // CHECK-NOT: load // CHECK: llvm.return - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } } // ----- -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> -#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: no_shortcut_mfma16 tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK: store // CHECK: load // CHECK: llvm.return - 
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } } diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 98d97f5cce..bd0e86bc1a 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // CHECK: llvm.cond_br @@ -18,8 +18,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.cond_br @@ -36,35 +36,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // Smoke test to check that mfma 32 and dot operand layouts can work with small tensors, for example with shape 16x16 -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> -#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> -#dotop1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> +#dotop1 = #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: small_mfma_tensor_conversions tt.func public @small_mfma_tensor_conversions(%arg0: tensor<16x16xf16, #mfma>, %arg1: tensor<16x16x!tt.ptr, #mfma>) { - // CHECK-NOT: triton_gpu.convert_layout - %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + // CHECK-NOT: ttg.convert_layout + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> // CHECK-4: store {{.*}} vector<4xf16> - %1 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop0> + 
%1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #dotop0> // CHECK-2: load {{.*}} vector<4xf16> - %2 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop1> + %2 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #dotop1> // CHECK-8: load {{.*}} vector<1xf16> - %3 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #mfma> + %3 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #mfma> // CHECK-4: load {{.*}} vector<4xf16> %4 = tt.fp_to_fp %3 : tensor<16x16xf16, #mfma> -> tensor<16x16xf32, #mfma> %5 = tt.dot %1, %2, %4 : tensor<16x16xf16, #dotop0> * tensor<16x16xf16, #dotop1> -> tensor<16x16xf32, #mfma> // Store result to prevent DCE from removing all conversion related code - %6 = triton_gpu.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !triton_gpu.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> + %6 = ttg.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !ttg.memdesc<16x16xf32, #shared, #ttg.shared_memory> tt.return } } // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f16x2 tt.func @atomic_add_f16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> @@ -81,8 +81,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_bf16x2 tt.func @atomic_add_bf16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> @@ -99,8 +99,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f16_dpp tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> @@ -117,8 +117,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp 
= [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_bf16_dpp tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> @@ -135,8 +135,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked3 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: reduce_dpp_max tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) { // CHECK: rocdl.update.dpp @@ -175,8 +175,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked4 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: reduce_xor_max tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) { // CHECK: rocdl.ds_swizzle diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir index e7dcb873d0..032b1e6fe0 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir @@ -1,37 +1,37 @@ // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -#mma2 = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#mma2 = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma1_dot_operand - tt.func @wmma1_dot_operand(%arg0: !triton_gpu.memdesc<64x64xf16, #shared>) { + tt.func @wmma1_dot_operand(%arg0: !ttg.memdesc<64x64xf16, #shared>) { // 2 CTA * 4 rep * load_per_thread_per_instr // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = triton_gpu.local_load %arg0 : !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + %0 = ttg.local_load %arg0 : 
!ttg.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-128: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> + %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> tt.return } // CHECK-LABEL: wmma2_dot_operand - tt.func @wmma2_dot_operand(%arg0: !triton_gpu.memdesc<64x64xf16, #shared>) { + tt.func @wmma2_dot_operand(%arg0: !ttg.memdesc<64x64xf16, #shared>) { // 2 CTA * 4 rep * load_per_thread_per_instr // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> - %0 = triton_gpu.local_load %arg0 : !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> + %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> // CHECK-COUNT-64: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> + %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> tt.return } // CHECK-LABEL: wmma1_dot - tt.func @wmma1_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) { + tt.func @wmma1_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) { // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK: llvm.mlir.undef : vector<16xf16> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16> // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1> // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16> // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> @@ -39,7 +39,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // CHECK-LABEL: wmma1_dot_bf16 - tt.func @wmma1_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma1>) { + tt.func @wmma1_dot_bf16(%arg0: 
tensor<16x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> // CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16> // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> @@ -48,12 +48,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.undef : vector<16xbf16> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16> // CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1> tt.return } // CHECK-LABEL: wmma1_dot_int8_32 - tt.func @wmma1_dot_int8_32(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { + tt.func @wmma1_dot_int8_32(%arg0: tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8> // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> @@ -62,13 +62,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> - %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> tt.return } // CHECK-LABEL: wmma1_dot_int4_32 - tt.func @wmma1_dot_int4_32(%arg0: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { + tt.func @wmma1_dot_int4_32(%arg0: tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, 
parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)> // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4> // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> @@ -77,13 +77,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> - %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> tt.return } // CHECK-LABEL: wmma2_dot - tt.func @wmma2_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>>, %arg2: tensor<16x16xf16, #mma2>) { + tt.func @wmma2_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>>, %arg2: tensor<16x16xf16, #mma2>) { // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK: llvm.mlir.undef : vector<8xf16> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> @@ -91,7 +91,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK: llvm.mlir.undef : vector<8xf16> // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v8f16"{{.*}} : (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> -> tensor<16x16xf16, #mma2> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> -> tensor<16x16xf16, #mma2> // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf16> // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> @@ -101,20 +101,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> -#mma1 = #triton_gpu.amd_wmma<{version = 1, 
warpsPerCTA = [2, 1, 4]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> +#mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_dot_operand3d - tt.func @wmma_dot_operand3d(%arg0: !triton_gpu.memdesc<4x16x32xf16, #shared>) { + tt.func @wmma_dot_operand3d(%arg0: !ttg.memdesc<4x16x32xf16, #shared>) { // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = triton_gpu.local_load %arg0 : !triton_gpu.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + %0 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !triton_gpu.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> + %1 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> tt.return } // CHECK-LABEL: wmma_dot3d - tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma1>) { + tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma1>) { // CHECK-COUNT-32: llvm.extractvalue %arg0 // CHECK-COUNT-32: llvm.insertelement // CHECK-COUNT-32: llvm.extractvalue %arg1 @@ -122,7 +122,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-COUNT-8: llvm.extractvalue %arg2 // CHECK-COUNT-8: llvm.insertelement // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1> // CHECK-COUNT-8: llvm.extractelement // CHECK-COUNT-8: llvm.insertvalue tt.return diff --git a/test/Conversion/dedup-by-constancy.mlir b/test/Conversion/dedup-by-constancy.mlir index 96131eae87..dc2cda84a7 100644 --- a/test/Conversion/dedup-by-constancy.mlir +++ b/test/Conversion/dedup-by-constancy.mlir @@ -10,8 +10,8 @@ // CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]] // CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]] // CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]] -#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<256> : tensor<1024xi32, #blocked> %c1024_i32 = arith.constant 1024 : i32 @@ -48,8 +48,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]] // CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]] // CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]] -#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<4> : tensor<1024xi32, #blocked> %c1024_i32 = arith.constant 1024 : i32 diff --git a/test/Conversion/divide-by-0.mlir b/test/Conversion/divide-by-0.mlir index 8f920fcc05..f12fd1bc78 100644 --- a/test/Conversion/divide-by-0.mlir +++ b/test/Conversion/divide-by-0.mlir @@ -3,12 +3,12 @@ // CHECK-LABEL: dont_divide_0 // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-NOT: llvm.urem %{{.*}}, %[[C0]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @dont_divide_0() attributes {noinline = false} { %zero = arith.constant dense<0.000000e+00> : tensor<16x1xf32, #mma> - %cvt = triton_gpu.convert_layout %zero : tensor<16x1xf32, #mma> -> tensor<16x1xf32, #blocked> + %cvt = ttg.convert_layout %zero : tensor<16x1xf32, 
#mma> -> tensor<16x1xf32, #blocked> tt.return } } diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index 96482b2298..ffa74eba0e 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=cuda:80 num-warps=2' | FileCheck %s -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func @ops() { - // CHECK: module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {{.*}} + // CHECK: module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {{.*}} %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> %c = arith.constant dense<3.00e+00> : tensor<128x128xf32> @@ -13,7 +13,7 @@ tt.func @ops() { // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // Test if LoadOp is lowered properly (see #771) %ptrs = tt.splat %ptr : !tt.ptr -> tensor<128x!tt.ptr> @@ -34,36 +34,36 @@ tt.func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // Test if the total number of threadsPerWarp is 32 // Test if the total number of warps is 2 - // CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> - // CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> - // CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> - // CHECK: module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {{.*}} + // CHECK: #[[blocked0:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> + // CHECK: #[[blocked1:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> + // CHECK: #[[blocked2:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> + // CHECK: module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {{.*}} %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32> %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32> %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32> - // CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>> + // CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #ttg.slice<{dim = 0, parent = #[[blocked0]]}>> %c0_ = "tt.reduce" (%c0) ({ ^bb0(%arg1: f32, %arg2: f32): %add = arith.addf %arg1, %arg2 : f32 tt.reduce.return %add : f32 }) {axis 
= 0 : i32} : (tensor<4x4xf32>) -> tensor<4xf32> - // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}> + // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #ttg.slice<{dim = 0, parent = #[[blocked1]]}> %c1_ = "tt.reduce" (%c1) ({ ^bb0(%arg3: f32, %arg4: f32): %add = arith.addf %arg3, %arg4 : f32 tt.reduce.return %add : f32 }) {axis = 0 : i32} : (tensor<8x2xf32>) -> tensor<2xf32> - // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>> + // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #ttg.slice<{dim = 1, parent = #[[blocked1]]}>> %c2_ = "tt.reduce" (%c1) ({ ^bb0(%arg5: f32, %arg6: f32): %add = arith.addf %arg5, %arg6 : f32 tt.reduce.return %add : f32 }) {axis = 1 : i32} : (tensor<8x2xf32>) -> tensor<8xf32> - // CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>> + // CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[blocked2]]}>> %c3_ = "tt.reduce" (%c2) ({ ^bb0(%arg7: f32, %arg8: f32): %add = arith.addf %arg7, %arg8 : f32 @@ -77,7 +77,7 @@ tt.func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func public @select_op(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i1) attributes {noinline = false} { // CHECK-LABEL: select_op %cst = arith.constant dense<0.000000e+00> : tensor<128xf32> @@ -98,7 +98,7 @@ tt.func public @select_op(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func @arith_splat_bool(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // CHECK-LABEL: arith_splat_bool diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 3f2fd578da..954b1d349e 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>) // Here the 128 comes from the 4 in module attribute multiples 32 // CHECK: nvvm.kernel = 1 : ui1, nvvm.reqntid = array @@ -12,8 +12,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_load tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: 
llvm.inline_asm @@ -27,8 +27,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: vectorized_load tt.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -42,8 +42,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: vectorized_load_f16 tt.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { // CHECK: llvm.inline_asm @@ -58,8 +58,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- // TODO: masked load with vectorization is pending on TODO -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: masked_load_const_other tt.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> @@ -71,8 +71,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- // TODO: masked load with vectorization is pending on TODO -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: masked_load_const_other_vec tt.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> @@ -83,8 +83,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], 
threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: store_with_cache_attr tt.func @store_with_cache_attr(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -98,8 +98,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { // CHECK-LABEL: global_load_store_no_vec tt.func @global_load_store_no_vec(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -150,8 +150,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { // CHECK-LABEL: global_load_store_vec4 tt.func @global_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -187,9 +187,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- // This test verifies the vectorization of Load and Store Ops. -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> // Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes its divisibility is 1; this should affect the mask's alignment and further restrict the load/store ops' vector width to be 1.
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func @vecadd_masked_vec1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { %c64_i32 = arith.constant 64 : i32 %0 = tt.get_program_id x : i32 @@ -217,8 +217,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec2 tt.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -262,8 +262,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec2 tt.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -307,8 +307,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec8 tt.func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -349,9 +349,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = 
[1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_view_broadcast tt.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { // CHECK: llvm.mlir.undef @@ -374,8 +374,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: basic_make_range tt.func @basic_make_range() { // CHECK: nvvm.read.ptx.sreg.tid.x @@ -389,8 +389,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addf tt.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { // CHECK: llvm.fadd @@ -402,8 +402,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addi tt.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.add @@ -415,7 +415,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_program_id tt.func @basic_program_id() { // CHECK: llvm.inline_asm asm_dialect = att operand_attrs = [] "mov.u32 $0, %ctaid.x;", "=r" : () -> i32 @@ -426,8 +426,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], 
order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addptr tt.func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.getelementptr @@ -439,23 +439,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_alloc_tensor tt.func @basic_alloc_tensor() { // CHECK: llvm.mlir.addressof @global_smem // CHECK-NEXT: llvm.getelementptr // CHECK-NEXT: llvm.mlir.constant - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory, mutable> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_subview tt.func @basic_subview() { @@ -477,30 +477,30 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: llvm.getelementptr %index = arith.constant 1 : i32 %zero = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !triton_gpu.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16x32xf32, #shared0, #triton_gpu.shared_memory, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16x32xf32, #shared0, #ttg.shared_memory, mutable> + %1 = ttg.memdesc_subview %0[%index, %zero, %zero] : !ttg.memdesc<128x16x32xf32, #shared0, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x32xf32, #shared0, #ttg.shared_memory, mutable> tt.return } } // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_async_wait tt.func @basic_async_wait() { // CHECK: cp.async.wait_group 0x4 - triton_gpu.async_wait {num = 4: i32} + ttg.async_wait {num = 4: i32} tt.return } } // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#slice1d0 = #triton_gpu.slice<{dim = 0, parent = #blocked1}> -#shared1D = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}> -#shared2D = #triton_gpu.shared<{vec = 2, perPhase = 1, 
maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#slice1d0 = #ttg.slice<{dim = 0, parent = #blocked1}> +#shared1D = #ttg.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}> +#shared2D = #ttg.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: basic_insert_slice_async_1d tt.func @basic_insert_slice_async_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { %c0_i32 = arith.constant 0 : i32 @@ -509,10 +509,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0> %59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> %66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> - %71 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<2x64xi64, #shared2D, #triton_gpu.shared_memory, mutable> - %subview = triton_gpu.memdesc_subview %71[%c0_i32, %c0_i32] : - !triton_gpu.memdesc<2x64xi64, #shared2D, #triton_gpu.shared_memory, mutable> -> - !triton_gpu.memdesc<64xi64, #shared1D, #triton_gpu.shared_memory, mutable> + %71 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared2D, #ttg.shared_memory, mutable> + %subview = ttg.memdesc_subview %71[%c0_i32, %c0_i32] : + !ttg.memdesc<2x64xi64, #shared2D, #ttg.shared_memory, mutable> -> + !ttg.memdesc<64xi64, #shared1D, #ttg.shared_memory, mutable> // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 @@ -523,23 +523,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.commit_group - %73 = triton_gpu.async_copy_global_to_local %66, %subview : tensor<64x!tt.ptr, #slice1d0> -> !triton_gpu.memdesc<64xi64, #shared1D, #triton_gpu.shared_memory, mutable> - triton_gpu.async_commit_group %73 + %73 = ttg.async_copy_global_to_local %66, %subview : tensor<64x!tt.ptr, #slice1d0> -> !ttg.memdesc<64xi64, #shared1D, #ttg.shared_memory, mutable> + ttg.async_commit_group %73 tt.return } } // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], 
CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> -#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#block0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice2d1 = #ttg.slice<{dim = 1, parent=#block2}> +#slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> +#AL = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v4 tt.func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 32 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> @@ -551,35 +551,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst = tt.splat %cst_scalar : i32 -> tensor<16x64xi32, #block2> %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2> %broadcast_off1_ = tt.broadcast %off1 : tensor<1x64xi32, #block3> -> tensor<16x64xi32, #block3> - %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : tensor<16x64xi32, #block2> -> tensor<16x64xi32, #AL> - %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : tensor<16x64xi32, #block3> -> tensor<16x64xi32, #AL> + %broadcast_off0 = ttg.convert_layout %broadcast_off0_ : tensor<16x64xi32, #block2> -> tensor<16x64xi32, #AL> + %broadcast_off1 = ttg.convert_layout %broadcast_off1_ : tensor<16x64xi32, #block3> -> tensor<16x64xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x64x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL>, tensor<16x64xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x64xf32, #A, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x64xf32, #A, #ttg.shared_memory, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;" // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;" // CHECK: llvm.inline_asm 
has_side_effects asm_dialect = att // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !triton_gpu.memdesc<16x64xf32, #A, #triton_gpu.shared_memory, mutable> - triton_gpu.async_commit_group + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !ttg.memdesc<16x64xf32, #A, #ttg.shared_memory, mutable> + ttg.async_commit_group tt.return } } // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> -#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#block0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice2d1 = #ttg.slice<{dim = 1, parent=#block2}> +#slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1 tt.func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> @@ -591,12 +591,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst = tt.splat %cst_scalar : i32 -> tensor<16x32xi32, #block2> %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x32xi32, #block2> %broadcast_off1_ = tt.broadcast %off1 : tensor<1x32xi32, #block3> -> tensor<16x32xi32, #block3> - %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : tensor<16x32xi32, #block2> -> tensor<16x32xi32, #AL> - %broadcast_off1 = triton_gpu.convert_layout 
%broadcast_off1_ : tensor<16x32xi32, #block3> -> tensor<16x32xi32, #AL> + %broadcast_off0 = ttg.convert_layout %broadcast_off0_ : tensor<16x32xi32, #block2> -> tensor<16x32xi32, #AL> + %broadcast_off1 = ttg.convert_layout %broadcast_off1_ : tensor<16x32xi32, #block3> -> tensor<16x32xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr, #AL>, tensor<16x32xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<16x32xf32, #A, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf32, #A, #ttg.shared_memory, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm @@ -609,22 +609,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !triton_gpu.memdesc<16x32xf32, #A, #triton_gpu.shared_memory, mutable> - triton_gpu.async_commit_group + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !ttg.memdesc<16x32xf32, #A, #ttg.shared_memory, mutable> + ttg.async_commit_group tt.return } } // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> -#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#block0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice2d1 = #ttg.slice<{dim = 1, parent=#block2}> +#slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1_multictas tt.func 
@basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1> @@ -636,12 +636,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst = tt.splat %cst_scalar : i32 -> tensor<32x32xi32, #block2> %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<32x32xi32, #block2> %broadcast_off1_ = tt.broadcast %off1 : tensor<1x32xi32, #block3> -> tensor<32x32xi32, #block3> - %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : tensor<32x32xi32, #block2> -> tensor<32x32xi32, #AL> - %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : tensor<32x32xi32, #block3> -> tensor<32x32xi32, #AL> + %broadcast_off0 = ttg.convert_layout %broadcast_off0_ : tensor<32x32xi32, #block2> -> tensor<32x32xi32, #AL> + %broadcast_off1 = ttg.convert_layout %broadcast_off1_ : tensor<32x32xi32, #block3> -> tensor<32x32xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x32xf32, #A, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #A, #ttg.shared_memory, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.mlir.constant(0 : i32) : i32 @@ -665,16 +665,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !triton_gpu.memdesc<32x32xf32, #A, #triton_gpu.shared_memory, mutable> - triton_gpu.async_commit_group + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !ttg.memdesc<32x32xf32, #A, #ttg.shared_memory, mutable> + ttg.async_commit_group tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: basic_splat tt.func @basic_splat(%ptr: !tt.ptr) { // CHECK: llvm.mlir.undef @@ -687,8 +687,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_store tt.func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { // CHECK: llvm.inline_asm @@ -702,9 +702,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_blocked_blocked tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { @@ -712,16 +712,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared // CHECK-: nvvm.barrier0 // CHECK-COUNT-8: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_blocked_blocked_vec tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { @@ -733,16 +733,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: nvvm.barrier0 // CHECK: llvm.load // CHECK: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], 
CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { @@ -758,29 +758,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: nvvm.barrier0 // CHECK: llvm.load // CHECK: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { - %AA = triton_gpu.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !triton_gpu.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> - %BB = triton_gpu.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !triton_gpu.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> + %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> + %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> // CHECK: llvm.inline_asm // CHECK: ldmatrix.sync.aligned.m8n8.x4 // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - %AA_DOT = triton_gpu.local_load %AA : !triton_gpu.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_a> - %BB_DOT = triton_gpu.local_load %BB : !triton_gpu.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_b> + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> -> tensor<16x16xf16, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> 
-> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> // CHECK: llvm.inline_asm @@ -794,30 +794,30 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : } // TODO: problems in MLIR's parser on slice layout -// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -// module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +// #blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +// module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // tt.func @make_range_sliced_layout() { -// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> +// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked0}>> // tt.return // } // } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=4}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=4}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=4}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot_fp8 tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) { - %AA = triton_gpu.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !triton_gpu.memdesc<16x16xf8E5M2, #shared0, #triton_gpu.shared_memory> - %BB = triton_gpu.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !triton_gpu.memdesc<16x16xf8E5M2, #shared0, #triton_gpu.shared_memory> + %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #ttg.shared_memory> + %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #ttg.shared_memory> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - %AA_DOT = triton_gpu.local_load %AA : !triton_gpu.memdesc<16x16xf8E5M2, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf8E5M2, #dot_operand_a> - %BB_DOT = triton_gpu.local_load %BB : !triton_gpu.memdesc<16x16xf8E5M2, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf8E5M2, 
#dot_operand_b> + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #ttg.shared_memory> -> tensor<16x16xf8E5M2, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #ttg.shared_memory> -> tensor<16x16xf8E5M2, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> // CHECK: llvm.inline_asm @@ -832,9 +832,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_mmav2_block tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { @@ -844,209 +844,209 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: st.shared // CHECK: nvvm.barrier0 // CHECK: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0> + %0 = ttg.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_layout_mmav2_dot_reg tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_layout_mmav2_dot_reg tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) { // CHECK-NOT: st.shared // 
CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1> + %0 = ttg.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#slice = #triton_gpu.slice<{dim = 0, parent = #mma}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#slice = #ttg.slice<{dim = 0, parent = #mma}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked> + %0 = ttg.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked> tt.return } } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: convert_layout_mmav3_mmav3_0 tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> tt.return } } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: convert_layout_mmav3_mmav3_1 tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> tt.return } } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#mma0 = 
#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: convert_layout_mmav3_mmav3_2 tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> tt.return } } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: convert_layout_mmav3_mmav3_3 tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_layout_mmav2_dot_reg tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> tt.return } } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: convert_layout_mmav3_mmav3_0 tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, 
#mma0> -> tensor<64x64xf16, #mma1> + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> tt.return } } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: convert_layout_mmav3_mmav3_1 tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> tt.return } } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: convert_layout_mmav3_mmav3_2 tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> tt.return } } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: convert_layout_mmav3_mmav3_3 tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { // CHECK-NOT: st.shared // CHECK-NOT: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = 
[1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_mmav3_transpose tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) { // CHECK-COUNT-128: st.shared.b8 // CHECK: nvvm.barrier0 // CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32> - %0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked> + %0 = ttg.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_blocked_shared tt.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { @@ -1054,42 +1054,42 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-SAME: !llvm.ptr<3> // CHECK: llvm.store // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !triton_gpu.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice0 tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { // CHECK: llvm.load {{.*}} -> vector<4xi32> - %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %cvt = ttg.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], 
warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice1 tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { // CHECK-COUNT-8: llvm.load {{.*}} -> i32 - %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %cvt = ttg.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked_to_blocked_ptr tt.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr, #blocked0>) { // CHECK: llvm.ptrtoint @@ -1097,28 +1097,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: nvvm.barrier0 // CHECK: llvm.inttoptr // CHECK-COUNT-4: llvm.insertvalue - %cvt = triton_gpu.convert_layout %src : tensor<32x!tt.ptr, #blocked0> -> tensor<32x!tt.ptr, #blocked1> + %cvt = ttg.convert_layout %src : tensor<32x!tt.ptr, #blocked0> -> tensor<32x!tt.ptr, #blocked1> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], 
CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory>, %b:!triton_gpu.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory>) { + %a:!ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory>, %b:!ttg.memdesc<32x256xf16, #shared, #ttg.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - %a_mat = triton_gpu.local_load %a : !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x32xf16, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !triton_gpu.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<32x256xf16, #dot_operand_b> + %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> -> tensor<128x32xf16, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #ttg.shared_memory> -> tensor<32x256xf16, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> - %38 = triton_gpu.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> + %38 = ttg.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> @@ -1129,17 +1129,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#blocked}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!triton_gpu.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!triton_gpu.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) { + %a:!ttg.memdesc<32x16xf32, #shared, #ttg.shared_memory>, %b:!ttg.memdesc<16x32xf32, #shared, #ttg.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> // CHECK: llvm.intr.fmuladd - %a_mat = triton_gpu.local_load %a : 
!triton_gpu.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !triton_gpu.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #ttg.shared_memory> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #ttg.shared_memory> -> tensor<16x32xf32, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> @@ -1151,15 +1151,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=1}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_tf32dot tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!triton_gpu.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!triton_gpu.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) { + %a:!ttg.memdesc<32x16xf32, #shared, #ttg.shared_memory>, %b:!ttg.memdesc<16x32xf32, #shared, #ttg.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 @@ -1167,8 +1167,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 // CHECK-SAME: (i32, i32, i32, i32) - %a_mat = triton_gpu.local_load %a : !triton_gpu.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !triton_gpu.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #ttg.shared_memory> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #ttg.shared_memory> -> tensor<16x32xf32, #dot_operand_b> // CHECK: llvm.inline_asm // CHECK-SAME: 
mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 @@ -1179,7 +1179,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> - %38 = triton_gpu.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %38 = ttg.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> @@ -1190,8 +1190,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1205,7 +1205,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // CHECK: llvm.icmp "eq" @@ -1218,8 +1218,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1233,8 +1233,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 
: i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f16_nomask // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 @@ -1246,8 +1246,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f16_withmask // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 @@ -1261,8 +1261,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32 tt.func @store_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1276,7 +1276,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32_scalar tt.func @store_f32_scalar(%arg0 : !tt.ptr, %arg1 : f32) { // CHECK: llvm.icmp "eq" @@ -1289,8 +1289,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_get_program_id tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { %blockidx = tt.get_program_id x: i32 @@ -1311,8 +1311,8 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], 
CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_get_program_id tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { %blockidx = tt.get_program_id x: i32 @@ -1333,8 +1333,8 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_get_num_program tt.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { %blockdimx = tt.get_num_programs x : i32 @@ -1354,8 +1354,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { tt.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { %blockdimx = tt.get_num_programs x : i32 %blockdimy = tt.get_num_programs y : i32 @@ -1373,8 +1373,8 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_index_cache tt.func @test_index_cache() { // CHECK: nvvm.read.ptx.sreg.tid.x @@ -1385,29 +1385,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, 
"ttg.num-warps" = 8 : i32} { // CHECK-LABEL: test_base_index_cache tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !triton_gpu.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !triton_gpu.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> + %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: test_index_cache_different_block tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !triton_gpu.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> cf.cond_br %arg1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !triton_gpu.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 tt.return @@ -1416,12 +1416,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], 
order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=1}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_tf32_cst_b tt.func @matmul_tf32_cst_b(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a: tensor<32x16xf32, #dot_operand_a>, %c: tensor<32x32xf32, #mma>) { @@ -1431,7 +1431,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> %b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b> %28 = tt.dot %a, %b_mat, %c, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> - %38 = triton_gpu.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %38 = ttg.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> tt.store %36, %38 : tensor<32x32x!tt.ptr, #blocked> @@ -1441,9 +1441,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: matmul_f16_cst_operands tt.func public @matmul_f16_cst_operands(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> @@ -1453,18 +1453,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[V1:.+]] = llvm.insertelement %{{.*}}, %[[V0]][%[[C1]] : i32] : vector<2xf16> // CHECK: %[[BC:.+]] = llvm.bitcast %[[V1]] : vector<2xf16> to i32 - %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked> - %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * 
tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %4 = arith.muli %3, %cst_2 : tensor<32x1xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %6 = tt.addptr %5, %4 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> - %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %9 = tt.broadcast %6 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> %10 = tt.broadcast %8 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> %11 = tt.addptr %9, %10 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> @@ -1475,8 +1475,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: test_s8_to_bf16_conversion tt.func @test_s8_to_bf16_conversion(%in: tensor<32xi8, #blocked>) { // We can't vectorize if we only process @@ -1489,9 +1489,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: test_s8_to_bf16_vectorized_conversion tt.func 
@test_s8_to_bf16_vectorized_conversion(%in: tensor<16x16xi8, #mma>) { // CHECK-NOT: llvm.sitofp @@ -1513,19 +1513,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: nvvm.shfl.sync bfly // CHECK: nvvm.shfl.sync bfly // CHECK: nvvm.barrier0 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sum_reduction(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<1024> : tensor<1x1xi32, #blocked> %0 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1> - %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi32, #blocked> + %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi32, #blocked> %3 = arith.muli %2, %cst : tensor<1x1xi32, #blocked> %4 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> %5 = tt.addptr %4, %3 : tensor<1x1x!tt.ptr, #blocked>, tensor<1x1xi32, #blocked> - %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<1024xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x1024xi32, #blocked> + %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<1024xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x1024xi32, #blocked> %8 = tt.broadcast %5 : tensor<1x1x!tt.ptr, #blocked> -> tensor<1x1024x!tt.ptr, #blocked> %9 = tt.addptr %8, %7 : tensor<1x1024x!tt.ptr, #blocked>, tensor<1x1024xi32, #blocked> %10 = tt.load %9 : tensor<1x1024x!tt.ptr, #blocked> @@ -1533,8 +1533,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg2: i32, %arg3: i32): %15 = arith.addi %arg2, %arg3 : i32 tt.reduce.return %15 : i32 - }) : (tensor<1x1024xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %12 = triton_gpu.convert_layout %11 : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + }) : (tensor<1x1024xi32, #blocked>) -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %12 = ttg.convert_layout 
%11 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> %13 = tt.splat %arg1 : !tt.ptr -> tensor<1x!tt.ptr, #blocked1> %14 = tt.addptr %13, %0 : tensor<1x!tt.ptr, #blocked1>, tensor<1xi32, #blocked1> tt.store %14, %12 : tensor<1x!tt.ptr, #blocked1> @@ -1543,9 +1543,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#slice = #triton_gpu.slice<{dim = 1, parent = #blocked}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice = #ttg.slice<{dim = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { // CHECK-LABEL: reduce_bools tt.func public @reduce_bools(%arg: tensor<256x2xi1, #blocked>) { // CHECK: llvm.mlir.addressof @global_smem @@ -1561,8 +1561,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: inline_asm tt.func public @inline_asm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> @@ -1580,8 +1580,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: inline_asm_pack_16bit tt.func public @inline_asm_pack_16bit(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> @@ -1602,16 +1602,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-LABEL: reduce_slice // CHECK-NOT: st.shared // CHECK-NOT: ld.shared -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> 
-#sliced2 = #triton_gpu.slice<{dim = 2, parent = #blocked}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> +#sliced2 = #ttg.slice<{dim = 2, parent = #blocked}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @reduce_slice() attributes {noinline = false} { %cst = arith.constant dense : tensor<4x1xi1, #sliced2> %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ ^bb0(%arg0: i1, %arg1: i1): %1 = arith.ori %arg0, %arg1 : i1 tt.reduce.return %1 : i1 - }) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #sliced2}>> + }) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #ttg.slice<{dim = 1, parent = #sliced2}>> tt.return } } @@ -1623,41 +1623,41 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: st.shared // CHECK: ld.shared // CHECK: st.shared -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 2, 2], order = [2, 1, 0]}> -#sliced = #triton_gpu.slice<{dim = 2, parent = #blocked}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 2, 2], order = [2, 1, 0]}> +#sliced = #ttg.slice<{dim = 2, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func public @reduce_md_slice(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<2x128xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<2x128xf32, #ttg.slice<{dim = 2, parent = #blocked}>> %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %18 = arith.maxnumf %arg1, %arg2 : f32 tt.reduce.return %18 : f32 - }) {allocation.offset = 0 : i32} : (tensor<2x128xf32, #sliced>) -> tensor<2xf32, #triton_gpu.slice<{dim = 1, parent = #sliced}>> + }) {allocation.offset = 0 : i32} : (tensor<2x128xf32, #sliced>) -> tensor<2xf32, #ttg.slice<{dim = 1, parent = #sliced}>> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = 
[1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @i16_mma_layout(%f16_inp: tensor<16x16xf16, #blocked0>, %i16_inp: tensor<16x16xi16, #blocked0>) { // CHECK-LABEL: @i16_mma_layout - %f16_shared = triton_gpu.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !triton_gpu.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> - %i16_shared = triton_gpu.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !triton_gpu.memdesc<16x16xi16, #shared0, #triton_gpu.shared_memory> + %f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> + %i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #ttg.shared_memory> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - %f16_dot = triton_gpu.local_load %f16_shared : !triton_gpu.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_a> - %i16_dot = triton_gpu.local_load %i16_shared : !triton_gpu.memdesc<16x16xi16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xi16, #dot_operand_b> + %f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> -> tensor<16x16xf16, #dot_operand_a> + %i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #ttg.shared_memory> -> tensor<16x16xi16, #dot_operand_b> // CHECK: llvm.sitofp %{{.*}} : i16 to f16 @@ -1677,9 +1677,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: convert_single_element // CHECK-NOT: llvm.store // CHECK-NOT: llvm.load @@ -1687,16 +1687,16 @@ module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : // CHECK: llvm.extractvalue tt.func public @convert_single_element() attributes {noinline = false} { %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> - %0 = triton_gpu.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> + %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -#blocked1 = 
#triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: convert_single_element_and_add // CHECK-NOT: llvm.store // CHECK-NOT: llvm.load @@ -1705,7 +1705,7 @@ module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : tt.func public @convert_single_element_and_add() attributes {noinline = false} { %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> %cst2 = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked> - %0 = triton_gpu.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> + %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> %1 = arith.addf %0, %cst2 : tensor<1xf32, #blocked> tt.return } @@ -1713,38 +1713,38 @@ module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : // ----- -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @vectorize_shmem_load // CHECK: llvm.load // CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<8xi8> // CHECK-NOT: llvm.load - tt.func public @vectorize_shmem_load(%shmem : !triton_gpu.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory>) { - %0 = triton_gpu.local_load %shmem : !triton_gpu.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory> -> tensor<16x16xi8, #blocked> + tt.func public @vectorize_shmem_load(%shmem : !ttg.memdesc<16x16xi8, #shared, #ttg.shared_memory>) { + %0 = ttg.local_load %shmem : !ttg.memdesc<16x16xi8, #shared, #ttg.shared_memory> -> tensor<16x16xi8, #blocked> tt.return } } // ----- -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], 
CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @vectorize_shmem_store // CHECK: llvm.store // CHECK-SAME: {alignment = 64 : i64} : vector<16xi32>, !llvm.ptr<3> // CHECK-NOT: llvm.store tt.func public @vectorize_shmem_store(%block : tensor<64x64xi32, #blocked>) { - %0 = triton_gpu.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !triton_gpu.memdesc<64x64xi32, #shared, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !ttg.memdesc<64x64xi32, #shared, #ttg.shared_memory> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: abs_is_int_min_poison // CHECK: %{{.*}} = "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i32) -> i32 tt.func @abs_is_int_min_poison(%arg0 : tensor<256xi32, #blocked0>) { @@ -1754,54 +1754,54 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_load_bf16 // CHECK: llvm.extractelement {{.*}} : vector<8xbf16> tt.func public @test_local_load_bf16() { %c0_i32 = arith.constant 0 : i32 - %19 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> - %22 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !triton_gpu.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> - %39 = triton_gpu.local_load %22 : !triton_gpu.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<1x2048xbf16, #blocked> + %19 = ttg.local_alloc : () -> !ttg.memdesc<1x1x2048xbf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x1x2048xbf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<1x2048xbf16, #shared, #ttg.shared_memory, mutable> + %39 = ttg.local_load %22 : !ttg.memdesc<1x2048xbf16, #shared, #ttg.shared_memory, mutable> -> tensor<1x2048xbf16, #blocked> %40 = arith.extf %39 : tensor<1x2048xbf16, #blocked> to 
tensor<1x2048xf32, #blocked> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_store // CHECK: llvm.store tt.func public @test_local_store(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !triton_gpu.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !triton_gpu.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> + ttg.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_store_subview // CHECK: llvm.store tt.func public @test_local_store_subview(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !triton_gpu.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> - %sv = triton_gpu.memdesc_subview %0[%c0_i32] : !triton_gpu.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !triton_gpu.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> + %sv = ttg.memdesc_subview %0[%c0_i32] : !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> + ttg.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = 
[32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: print_ptr // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 tt.func @print_ptr(%arg0 : tensor<256x!tt.ptr, #blocked0>) { @@ -1811,8 +1811,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // Test that %u format specifier is used if isSigned is false // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %u{{.*}}") // CHECK-LABEL: print_int32_tensor_issigned_off @@ -1824,8 +1824,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // Test that %i format specifier is used if isSigned is true // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %i{{.*}}") // CHECK-LABEL: print_int32_tensor_issigned_on @@ -1838,8 +1838,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @int32_to_bf16(%arg0: tensor<256xi32, #blocked>) attributes {noinline = false} { // CHECK-LABEL: @int32_to_bf16 // CHECK: llvm.sitofp %{{.*}} : i32 to bf16 @@ -1850,8 +1850,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @bf16_to_int32(%arg0: tensor<256xbf16, #blocked>) attributes {noinline = false} { // CHECK-LABEL: @bf16_to_int32 // CHECK: llvm.fptosi %{{.*}} : bf16 to i32 @@ -1862,13 +1862,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 
4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32} // CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32} // CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32} // CHECK: llvm.call @__assertfail // CHECK: nvvm.barrier0 -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) { tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5) tt.return @@ -1882,8 +1882,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @log1pf_scan(%39: tensor<32x16xf32, #blocked>) attributes {noinline = false} { // CHECK: log1pf_scan // non-speculatable ops will introduce a cond_br; extern_elementwise with pure = true should be considered speculatable. 
diff --git a/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir index 49128064a8..f45143678c 100644 --- a/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir +++ b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir @@ -1,10 +1,10 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s // CHECK-LABEL: blocked_to_dot_op_shortcut_warp32 -#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> +#blocked = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>) { + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> // CHECK-NOT: load tt.return } @@ -13,10 +13,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: blocked_to_dot_op_shortcut_warp64 -#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @blocked_to_dot_op_shortcut_warp64(%arg0: tensor<32x32xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> // CHECK-NOT: load tt.return } @@ -25,10 +25,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp32 -#blocked = #triton_gpu.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func @blocked_to_dot3d_op_shortcut_warp32(%arg0: tensor<8x32x32xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : 
tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    %0 = ttg.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
     // CHECK-NOT: load
     tt.return
   }
@@ -37,10 +37,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // -----
 // CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp64
-#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+#blocked = #ttg.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} {
   tt.func @blocked_to_dot3d_op_shortcut_warp64(%arg0: tensor<8x32x32xf16, #blocked>) {
-    %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    %0 = ttg.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
     // CHECK-NOT: load
     tt.return
   }
diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir
index 83eacfa843..1b64ee7005 100644
--- a/test/Conversion/tritongpu_to_llvm_hopper.mlir
+++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -1,11 +1,11 @@
 // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' 2>&1 | FileCheck %s
-#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
-#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
+#shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
   // CHECK-LABEL: @dot_high_precision_acc
-  tt.func @dot_high_precision_acc(%a: !triton_gpu.memdesc<128x128xf8E5M2, #shared>, %b: !triton_gpu.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
+  tt.func @dot_high_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
     // CHECK: nvgpu.wgmma
     // CHECK-COUNT-128: llvm.fadd
     // CHECK: nvgpu.wgmma
@@ -14,21 +14,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     // CHECK-COUNT-128: llvm.fadd
     // CHECK: nvgpu.wgmma
     // CHECK-COUNT-128: llvm.fadd
-    %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c
+    %m = ttng.warp_group_dot %a, %b, %c
       {maxNumImpreciseAcc = 32 : i32, inputPrecision = 0 : i32} :
-      !triton_gpu.memdesc<128x128xf8E5M2, #shared> * !triton_gpu.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
+      !ttg.memdesc<128x128xf8E5M2, #shared> * !ttg.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
     tt.return
   }
 }
 // -----
-#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
-#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
+#shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
   // CHECK-LABEL: @dot_low_precision_acc
-  tt.func @dot_low_precision_acc(%a: !triton_gpu.memdesc<128x128xf8E5M2, #shared>, %b: !triton_gpu.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
+  tt.func @dot_low_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
     // CHECK: nvgpu.wgmma
     // CHECK-NOT: llvm.fadd
     // CHECK: nvgpu.wgmma
@@ -38,21 +38,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     // CHECK: nvgpu.wgmma
     // CHECK-NOT: llvm.fadd
     // CHECK: llvm.return
-    %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c
+    %m = ttng.warp_group_dot %a, %b, %c
       {maxNumImpreciseAcc = 129 : i32, inputPrecision = 0 : i32} :
-      !triton_gpu.memdesc<128x128xf8E5M2, #shared> * !triton_gpu.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
+      !ttg.memdesc<128x128xf8E5M2, #shared> * !ttg.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
     tt.return
   }
 }
 // -----
-#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
-#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
+#shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
   // CHECK-LABEL: @dot_mix_precision_acc
-  tt.func @dot_mix_precision_acc(%a: !triton_gpu.memdesc<128x128xf8E5M2, #shared>, %b: !triton_gpu.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
+  tt.func @dot_mix_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) {
     // CHECK: nvgpu.wgmma
     // CHECK-NOT: llvm.fadd
     // CHECK: nvgpu.wgmma
@@ -62,86 +62,86 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     // CHECK: nvgpu.wgmma
     // CHECK-COUNT-128: llvm.fadd
     // CHECK: llvm.return
-    %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c
+    %m = ttng.warp_group_dot %a, %b, %c
      {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} :
-      !triton_gpu.memdesc<128x128xf8E5M2, #shared> * !triton_gpu.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
+      !ttg.memdesc<128x128xf8E5M2, #shared> * !ttg.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
     tt.return
   }
 }
 // -----
-#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
-#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
+#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: @dot_zero_acc
   // Generate a wgmma with 2 sources.
// CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} { - tt.func @dot_zero_acc(%a: !triton_gpu.memdesc<128x64xf16, #shared>, %b: !triton_gpu.memdesc<64x64xf16, #shared1>) { + tt.func @dot_zero_acc(%a: !ttg.memdesc<128x64xf16, #shared>, %b: !ttg.memdesc<64x64xf16, #shared1>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : - !triton_gpu.memdesc<128x64xf16, #shared> * !triton_gpu.memdesc<64x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %m = ttng.warp_group_dot %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : + !ttg.memdesc<128x64xf16, #shared> * !ttg.memdesc<64x64xf16, #shared1> -> tensor<128x64xf32, #mma> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A // Generate a wgmma where the first operand is a struct. // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> - tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !triton_gpu.memdesc<64x64xf16, #shared>) { + tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !ttg.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: - tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %opA = ttg.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %m = ttng.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: + tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> -#mma1 = 
#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A_fp8 // Generate a wgmma where the first operand is a struct. // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} - tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !triton_gpu.memdesc<128x256xf8E5M2, #shared>) { + tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !ttg.memdesc<128x256xf8E5M2, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> - %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : - tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !triton_gpu.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> + %m = ttng.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : + tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !ttg.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], 
CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: dot_reg_operand_upcast - tt.func @dot_reg_operand_upcast(%a_desc: !triton_gpu.memdesc<128x64xi8, #shared>, %b: !triton_gpu.memdesc<64x64xf16, #shared>, %acc: tensor<128x64xf32, #mma>) { - %a_dotop = triton_gpu.local_load %a_desc : !triton_gpu.memdesc<128x64xi8, #shared> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %res = triton_nvidia_gpu.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.func @dot_reg_operand_upcast(%a_desc: !ttg.memdesc<128x64xi8, #shared>, %b: !ttg.memdesc<64x64xf16, #shared>, %acc: tensor<128x64xf32, #mma>) { + %a_dotop = ttg.local_load %a_desc : !ttg.memdesc<128x64xi8, #shared> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %res = ttng.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: test_fp8_to_f16_conversion tt.func @test_fp8_to_f16_conversion( %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>, @@ -168,9 +168,9 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> // CHECK-LABEL: clamp -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @clamp(%x : tensor<1024xf32, #blocked>, %limit : tensor<1024xf32, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked> 
     %neg_limit = arith.subf %cst, %limit : tensor<1024xf32, #blocked>
@@ -183,23 +183,23 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 // -----
-#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 16]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 16]}>
 // CHECK-LABEL: convert_mma_to_blocked
-module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func @convert_mma_to_blocked(%a: tensor<128x256xf16, #mma>) {
     // CHECK-COUNT-16: nvgpu.stmatrix
    // CHECK: nvvm.barrier0
-    %c = triton_gpu.convert_layout %a : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked>
+    %c = ttg.convert_layout %a : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked>
     tt.return
   }
 }
 // -----
-#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
-#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
+#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: cvt_mma_to_dot_fp8
   // CHECK: prmt.b32
   // CHECK: prmt.b32
@@ -208,23 +208,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   // CHECK: prmt.b32
   // CHECK: prmt.b32
   tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) {
-    %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    %opA = ttg.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
     tt.return
   }
 }
 // -----
-#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}>
-#shared = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
-#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}>
+#shared = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: dot_zero_acc_operand // CHECK-COUNT-128: llvm.fadd - tt.func @dot_zero_acc_operand(%a: !triton_gpu.memdesc<128x128xf8E5M2, #shared>, %b: !triton_gpu.memdesc<128x128xf8E5M2, #shared1>) { + tt.func @dot_zero_acc_operand(%a: !ttg.memdesc<128x128xf8E5M2, #shared>, %b: !ttg.memdesc<128x128xf8E5M2, #shared1>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : - !triton_gpu.memdesc<128x128xf8E5M2, #shared> * !triton_gpu.memdesc<128x128xf8E5M2, #shared1> -> tensor<128x128xf32, #mma> + %m = ttng.warp_group_dot %a, %b, %cst {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : + !ttg.memdesc<128x128xf8E5M2, #shared> * !ttg.memdesc<128x128xf8E5M2, #shared1> -> tensor<128x128xf32, #mma> tt.return } } @@ -232,22 +232,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> // CHECK-LABEL: distribute_to_shared_st_matrix -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @distribute_to_shared_st_matrix(%a: tensor<128x128xf16, #mma>) { // CHECK-COUNT-16: nvgpu.stmatrix // CHECK: llvm.return - %b = triton_gpu.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !triton_gpu.memdesc<128x128xf16, #shared, mutable> + %b = ttg.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !ttg.memdesc<128x128xf16, #shared, mutable> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} { // CHECK-LABEL: @fp8_const // CHECK: llvm.mlir.constant(0.000000e+00 : f8E4M3FNUZ) : i8 @@ -259,8 +259,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : 
i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f32_nomask // CHECK: atom.global.gpu.acq_rel.add.v4.f32 @@ -271,8 +271,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f32_withmask // CHECK: atom.global.gpu.acq_rel.add.v2.f32 @@ -284,8 +284,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f16_withmask // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 @@ -297,12 +297,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: test_fp8_to_fp16_dot_operand // CHECK-COUNT-16: cvt.rn.f16x2.e5m2x2 - tt.func @test_fp8_to_fp16_dot_operand(%arg: tensor<128x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, 
parent = #mma, kWidth = 2}>>) { - %r = tt.fp_to_fp %arg : tensor<128x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + tt.func @test_fp8_to_fp16_dot_operand(%arg: tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>) { + %r = tt.fp_to_fp %arg : tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> tt.return } } diff --git a/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir b/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir index 906c610023..1003f321d6 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=80' 2>&1 | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f32_nomask // CHECK: atom.global.gpu.acq_rel.add.f32 @@ -15,8 +15,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f32_withmask // CHECK: atom.global.gpu.acq_rel.add.f32 @@ -30,8 +30,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes 
{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f16_withmask // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 diff --git a/test/Conversion/tritongpu_to_llvm_volta.mlir b/test/Conversion/tritongpu_to_llvm_volta.mlir index 26010b88bd..a5a4281294 100644 --- a/test/Conversion/tritongpu_to_llvm_volta.mlir +++ b/test/Conversion/tritongpu_to_llvm_volta.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=70 2>&1 | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> // CHECK-LABEL: clamp -module attributes {"triton_gpu.target" = "cuda:70", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:70", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @clamp(%x : tensor<1024xf32, #blocked>, %limit : tensor<1024xf32, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked> %neg_limit = arith.subf %cst, %limit : tensor<1024xf32, #blocked> diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index c7cc5fa5db..52c30b28fc 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -1,25 +1,25 @@ // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: init_barrier - tt.func @init_barrier(%alloc: !triton_gpu.memdesc<1xi64, #shared0>) { + tt.func @init_barrier(%alloc: !ttg.memdesc<1xi64, #shared0>) { // CHECK: "@$0 mbarrier.init.shared::cta.b64 [$1], 1;", "b,r" %{{.*}}, %{{.*}} : (i1, !llvm.ptr<3>) -> !llvm.void - triton_nvidia_gpu.init_barrier %alloc, 1 : !triton_gpu.memdesc<1xi64, #shared0> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: wait_barrier - tt.func @wait_barrier(%alloc: !triton_gpu.memdesc<1xi64, #shared0>, %phase: i32) { + tt.func @wait_barrier(%alloc: !ttg.memdesc<1xi64, 
#shared0>, %phase: i32) { // CHECK: waitLoop: // CHECK: mbarrier.try_wait.parity.shared.b64 // CHECK: @!P1 bra.uni waitLoop - triton_nvidia_gpu.wait_barrier %alloc, %phase : !triton_gpu.memdesc<1xi64, #shared0> + ttng.wait_barrier %alloc, %phase : !ttg.memdesc<1xi64, #shared0> tt.return } } @@ -27,62 +27,62 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tma_copy_global_to_local // CHECK: elect.sync // CHECK: "@$0 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4}], [$5];", "b,r,l,r,r,r" {{.*}} : (i1, !llvm.ptr<3>, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void // CHECK-NOT: cp.async.bulk.tensor.2d.shared // CHECK: return - tt.func @tma_copy_global_to_local(%tma: !tt.ptr, %alloc: !triton_gpu.memdesc<128x128xf32, #shared1, mutable>, %x: i32, %barrier: !triton_gpu.memdesc<1xi64, #shared0>, %pred: i1) { - triton_nvidia_gpu.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.ptr, !triton_gpu.memdesc<1xi64, #shared0> -> !triton_gpu.memdesc<128x128xf32, #shared1, mutable> + tt.func @tma_copy_global_to_local(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0>, %pred: i1) { + ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.ptr, !ttg.memdesc<1xi64, #shared0> -> !ttg.memdesc<128x128xf32, #shared1, mutable> tt.return } } // ----- -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tma_copy_local_to_global // CHECK: elect.sync // CHECK: "@$0 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$1, {$2, $3}], [$4];", "b,l,r,r,r" {{.*}} : (i1, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void // CHECK-NOT: cp.async.bulk.tensor.2d.global.shared::cta.bulk_group // CHECK: cp.async.bulk.commit_group - tt.func @tma_copy_local_to_global(%tma: !tt.ptr, %alloc: !triton_gpu.memdesc<128x128xf32, #shared1>, %x: i32) { - triton_nvidia_gpu.async_tma_copy_local_to_global %tma[%x, %x] %alloc : , <128x128xf32, #shared1> + tt.func @tma_copy_local_to_global(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1>, %x: i32) { + ttng.async_tma_copy_local_to_global %tma[%x, %x] %alloc : , <128x128xf32, #shared1> tt.return } } // ----- -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: async_tma_store_wait // CHECK: "cp.async.bulk.wait_group.read 0x0;", "" : () -> !llvm.void tt.func @async_tma_store_wait() { - triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} + ttng.async_tma_store_wait {pendings = 0 : i32} tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: expect_barrier // CHECK: @$0 mbarrier.arrive.expect_tx.shared.b64 _, [$1], 16384; - tt.func @expect_barrier(%barrier: !triton_gpu.memdesc<1xi64, #shared0, mutable>, %pred: i1) { - triton_nvidia_gpu.barrier_expect %barrier, 16384, %pred : <1xi64, #shared0, mutable> + tt.func @expect_barrier(%barrier: !ttg.memdesc<1xi64, #shared0, mutable>, %pred: i1) { + ttng.barrier_expect %barrier, 16384, %pred : <1xi64, #shared0, mutable> tt.return } } // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: byval_tma_desc // CHECK: llvm.align = 64 // CHECK: llvm.byval = !llvm.array<128 x i8> @@ -95,7 +95,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- // CHECK-LABEL: device_tensormap_create1d -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @device_tensormap_create1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c256_i32 = arith.constant 256 : i32 %c1_i32 = arith.constant 1 : i32 @@ -120,7 +120,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: device_tensormap_create2d -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @device_tensormap_create2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c256_i32 = arith.constant 256 : i32 %c1_i32 = arith.constant 1 : i32 @@ -150,7 +150,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: tensormap_fenceproxy_acquire -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, 
"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @tensormap_fenceproxy_acquire(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { // CHECK: fence.proxy.tensormap::generic.acquire.gpu [ $0 + 0 ], 0x80; tt.experimental_tensormap_fenceproxy_acquire %arg0 : !tt.ptr diff --git a/test/Tools/tensor_layout_print.mlir b/test/Tools/tensor_layout_print.mlir index 80c0195934..9f802d2e3b 100644 --- a/test/Tools/tensor_layout_print.mlir +++ b/test/Tools/tensor_layout_print.mlir @@ -2,19 +2,19 @@ // RUN: triton-tensor-layout -i %s -alias-names="mfma" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA -// RUN: triton-tensor-layout -l "#triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA +// RUN: triton-tensor-layout -l "#ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA // RUN: triton-tensor-layout -i %s -alias-names="mfma" -t "tensor<16x16xf16>" -use-hw-view | FileCheck %s --check-prefix=CHECK-HW -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> tt.func @print(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #blocked> %cst1 = arith.constant dense<0.00e+00> : tensor<16x16xf16, #mfma> tt.return } -// CHECK-BLOCKED: Print layout attribute: #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-BLOCKED: Print layout attribute: #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> // CHECK-BLOCKED: T0:0| T4:0, T0:1| T4:1, T0:2| T4:2, T0:3| T4:3, T1:0| T5:0, T1:1| T5:1, T1:2| T5:2, T1:3| T5:3, T2:0| T6:0, T2:1| T6:1, T2:2| T6:2, T2:3| T6:3, T3:0| T7:0, T3:1| T7:1, T3:2| T7:2, T3:3| T7:3 // CHECK-BLOCKED: T8:0| T12:0, T8:1| T12:1, T8:2| T12:2, T8:3| T12:3, T9:0| T13:0, T9:1| T13:1, T9:2| T13:2, T9:3| T13:3, T10:0| T14:0, T10:1| T14:1, T10:2| T14:2, T10:3| T14:3, T11:0| T15:0, T11:1| T15:1, T11:2| T15:2, T11:3| T15:3 // CHECK-BLOCKED: T16:0| T20:0, T16:1| T20:1, T16:2| T20:2, T16:3| T20:3, T17:0| T21:0, T17:1| T21:1, T17:2| T21:2, T17:3| T21:3, T18:0| T22:0, T18:1| T22:1, T18:2| T22:2, T18:3| T22:3, T19:0| T23:0, T19:1| T23:1, T19:2| T23:2, T19:3| T23:3 @@ -33,7 +33,7 @@ tt.func @print(%A : !tt.ptr) { // CHECK-BLOCKED: T120:0|T124:0, T120:1|T124:1, T120:2|T124:2, T120:3|T124:3, T121:0|T125:0, T121:1|T125:1, T121:2|T125:2, T121:3|T125:3, T122:0|T126:0, T122:1|T126:1, T122:2|T126:2, T122:3|T126:3, T123:0|T127:0, T123:1|T127:1, T123:2|T127:2, T123:3|T127:3 -// CHECK-MFMA: Print layout attribute: {{.*}}#triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +// CHECK-MFMA: Print layout attribute: {{.*}}#ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 
16], isTransposed = true}> // CHECK-MFMA: T0:0| T64:0|T128:0|T192:0, T0:1| T64:1|T128:1|T192:1, T0:2| T64:2|T128:2|T192:2, T0:3| T64:3|T128:3|T192:3, T16:0| T80:0|T144:0|T208:0, T16:1| T80:1|T144:1|T208:1, T16:2| T80:2|T144:2|T208:2, T16:3| T80:3|T144:3|T208:3, T32:0| T96:0|T160:0|T224:0, T32:1| T96:1|T160:1|T224:1, T32:2| T96:2|T160:2|T224:2, T32:3| T96:3|T160:3|T224:3, T48:0|T112:0|T176:0|T240:0, T48:1|T112:1|T176:1|T240:1, T48:2|T112:2|T176:2|T240:2, T48:3|T112:3|T176:3|T240:3 // CHECK-MFMA: T1:0| T65:0|T129:0|T193:0, T1:1| T65:1|T129:1|T193:1, T1:2| T65:2|T129:2|T193:2, T1:3| T65:3|T129:3|T193:3, T17:0| T81:0|T145:0|T209:0, T17:1| T81:1|T145:1|T209:1, T17:2| T81:2|T145:2|T209:2, T17:3| T81:3|T145:3|T209:3, T33:0| T97:0|T161:0|T225:0, T33:1| T97:1|T161:1|T225:1, T33:2| T97:2|T161:2|T225:2, T33:3| T97:3|T161:3|T225:3, T49:0|T113:0|T177:0|T241:0, T49:1|T113:1|T177:1|T241:1, T49:2|T113:2|T177:2|T241:2, T49:3|T113:3|T177:3|T241:3 // CHECK-MFMA: T2:0| T66:0|T130:0|T194:0, T2:1| T66:1|T130:1|T194:1, T2:2| T66:2|T130:2|T194:2, T2:3| T66:3|T130:3|T194:3, T18:0| T82:0|T146:0|T210:0, T18:1| T82:1|T146:1|T210:1, T18:2| T82:2|T146:2|T210:2, T18:3| T82:3|T146:3|T210:3, T34:0| T98:0|T162:0|T226:0, T34:1| T98:1|T162:1|T226:1, T34:2| T98:2|T162:2|T226:2, T34:3| T98:3|T162:3|T226:3, T50:0|T114:0|T178:0|T242:0, T50:1|T114:1|T178:1|T242:1, T50:2|T114:2|T178:2|T242:2, T50:3|T114:3|T178:3|T242:3 diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir index d325ed2395..ef448d500e 100644 --- a/test/Triton/canonicalize.mlir +++ b/test/Triton/canonicalize.mlir @@ -64,11 +64,11 @@ tt.func @fold_advance(%arg: !tt.ptr>) -> (!tt.ptr -#sliced0 = #triton_gpu.slice<{dim = 1, parent = #blocked0}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#sliced0 = #ttg.slice<{dim = 1, parent = #blocked0}> // CHECK-LABEL: fn -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){ // CHECK: %[[a:.*]] = tt.expand_dims // CHECK: tt.broadcast %[[a]] @@ -80,8 +80,8 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){ // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @fp_to_fp_pos_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> { // CHECK-LABEL: fp_to_fp_pos_zero_fold // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<0.000000e+00> : tensor<32x128xf8E4M3FNUZ, #blocked> @@ -94,7 +94,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @fp_to_fp_pos_zero_fold_scalar() -> f8E4M3FNUZ { // CHECK-LABEL: fp_to_fp_pos_zero_fold_scalar // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant 
0.000000e+00 : f8E4M3FNUZ @@ -107,8 +107,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FN, #blocked> { // CHECK-LABEL: fp_to_fp_neg_zero_fold // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<-0.000000e+00> : tensor<32x128xf8E4M3FN, #blocked> @@ -121,8 +121,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> { // CHECK-LABEL: fp_to_fp_neg_zero_fold // We fold to the positive zero here given by definition f8E4M3FNUZ does not have negative zero encoding. @@ -136,8 +136,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @fold_fp_to_fp_non_zero_nofold() -> tensor<32x128xf8E4M3FNUZ, #blocked> { // CHECK-LABEL: fold_fp_to_fp_non_zero_nofold // CHECK-NEXT: %[[cst:.+]] = arith.constant dense<0xFF800000> : tensor<32x128xf32, #blocked> @@ -151,8 +151,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @fold_fp_to_fp_non_constant_nofold(%arg0: tensor<32x128xf32, #blocked>) -> tensor<32x128xf8E4M3FNUZ, #blocked> { // CHECK-LABEL: fold_fp_to_fp_non_constant_nofold // CHECK-NEXT: %[[arg_cvt:.+]] = tt.fp_to_fp %arg0 diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index c7fb41707e..07a7686651 100644 --- 
a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -157,8 +157,8 @@ tt.func public @fn(%v1: tensor<4x128xf32>, %v2: tensor<4x128xi64>) { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<32xf32, #blocked>) { // expected-error @+2 {{op failed to infer returned types}} // expected-error @+1 {{incompatible with return type}} @@ -170,9 +170,9 @@ tt.func public @fn(%arg0: tensor<32xf32, #blocked>) { // ----- // Bad order; should be [1,0] -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [0,1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [0,1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<32xf32, #blocked>) { // expected-error @+2 {{order}} // expected-error @+1 {{op failed to infer returned types}} @@ -215,11 +215,11 @@ tt.func public @fn(%arg0: tensor<2xf32>) { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> +#blocked = #ttg.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> // Bad order, should be [1,0]. -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) { // expected-error @+2 {{op inferred type}} // expected-error @+1 {{op failed to infer returned types}} @@ -230,11 +230,11 @@ tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> +#blocked = #ttg.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> // bad sizePerThread; should be [1,1]. 
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [0,1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [0,1]}>
-module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
 tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) {
   // expected-error @+2 {{op inferred type}}
   // expected-error @+1 {{op failed to infer returned types}}
@@ -246,7 +246,7 @@ tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) {
 // -----
 // Valid ops.
-module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
 tt.func public @fn(%arg0: tensor<16x32x64xf32>) {
   %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32> -> tensor<16x32x64xf32>
   %b = tt.trans %arg0 {order = array} : tensor<16x32x64xf32> -> tensor<32x16x64xf32>
@@ -257,11 +257,11 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32>) {
 // -----
 // Valid op with blocked encoding.
-#blocked = #triton_gpu.blocked<{sizePerThread = [1,2,3,4], threadsPerWarp = [2,4,2,2], warpsPerCTA = [4,2,4,2], order = [3,2,1,0], CTAsPerCGA = [1,2,2,2], CTASplitNum = [1,2,4,8], CTAOrder = [3,2,1,0]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [2,4,3,1], threadsPerWarp = [4,2,2,2], warpsPerCTA = [2,2,4,4], order = [1,2,0,3], CTAsPerCGA = [2,2,2,1], CTASplitNum = [2,8,4,1], CTAOrder = [1,2,0,3]}>
-#blocked2 = #triton_gpu.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}>
-#blocked3 = #triton_gpu.blocked<{sizePerThread = [2,1,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}>
-module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+#blocked = #ttg.blocked<{sizePerThread = [1,2,3,4], threadsPerWarp = [2,4,2,2], warpsPerCTA = [4,2,4,2], order = [3,2,1,0], CTAsPerCGA = [1,2,2,2], CTASplitNum = [1,2,4,8], CTAOrder = [3,2,1,0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2,4,3,1], threadsPerWarp = [4,2,2,2], warpsPerCTA = [2,2,4,4], order = [1,2,0,3], CTAsPerCGA = [2,2,2,1], CTASplitNum = [2,8,4,1], CTAOrder = [1,2,0,3]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [2,1,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}>
+module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} {
 tt.func public @fn(%arg0: tensor<2x4x8x16xf32, #blocked>, %arg1: tensor<16x32x64xf32, #blocked2>) {
   %a = tt.trans %arg0 {order = array} : tensor<2x4x8x16xf32, #blocked> -> tensor<4x16x8x2xf32, #blocked1>
   %b = tt.trans %arg1 {order = array} : tensor<16x32x64xf32, #blocked2> -> tensor<32x16x64xf32, #blocked3>
@@ -272,14 +272,14 @@ tt.func public @fn(%arg0: tensor<2x4x8x16xf32, #blocked>, %arg1: tensor<16x32x64
 // -----
 // Valid op with shared encoding.
-#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [3, 2, 1, 0]}>
-#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0, 3]}>
-#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 2], CTASplitNum = [2, 4], CTAOrder = [0, 1], hasLeadingOffset = true}>
-#shared3 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [4, 2], CTAOrder = [1, 0], hasLeadingOffset = true}>
-module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
-tt.func public @fn(%arg0: !triton_gpu.memdesc<2x4x8x16xf32, #shared>, %arg1: !triton_gpu.memdesc<16x32xf32, #shared2>) {
-  %a = triton_gpu.memdesc_trans %arg0 {order = array} : !triton_gpu.memdesc<2x4x8x16xf32, #shared> -> !triton_gpu.memdesc<4x16x8x2xf32, #shared1>
-  %b = triton_gpu.memdesc_trans %arg1 {order = array} : !triton_gpu.memdesc<16x32xf32, #shared2> -> !triton_gpu.memdesc<32x16xf32, #shared3>
+#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [3, 2, 1, 0]}>
+#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0, 3]}>
+#shared2 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 2], CTASplitNum = [2, 4], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared3 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [4, 2], CTAOrder = [1, 0], hasLeadingOffset = true}>
+module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} {
+tt.func public @fn(%arg0: !ttg.memdesc<2x4x8x16xf32, #shared>, %arg1: !ttg.memdesc<16x32xf32, #shared2>) {
+  %a = ttg.memdesc_trans %arg0 {order = array} : !ttg.memdesc<2x4x8x16xf32, #shared> -> !ttg.memdesc<4x16x8x2xf32, #shared1>
+  %b = ttg.memdesc_trans %arg1 {order = array} : !ttg.memdesc<16x32xf32, #shared2> -> !ttg.memdesc<32x16xf32, #shared3>
   tt.return
 }
 } // end module
@@ -287,9 +287,9 @@ tt.func public @fn(%arg0: !triton_gpu.memdesc<2x4x8x16xf32, #shared>, %arg1: !tr
 // -----
 // Invalid blocked encoding.
-#blocked = #triton_gpu.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32x64xf32, #blocked>) { // expected-error @+1 {{type}} %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #blocked> -> tensor<32x16x64xf32, #blocked1> @@ -300,9 +300,9 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32, #blocked>) { // ----- // Invalid shared encoding. -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { // expected-error @+1 {{type}} %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #shared> -> tensor<32x16x64xf32, #shared1> @@ -312,7 +312,7 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { // ----- -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32xf32>) { // expected-error @+1 {{order}} %a = tt.trans %arg0 {order = array} : tensor<16x32xf32> -> tensor<32x16xf32> @@ -322,7 +322,7 @@ tt.func public @fn(%arg0: tensor<16x32xf32>) { // ----- -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32xf32>) { // expected-error @+1 {{order}} %a = tt.trans %arg0 {order = array} : tensor<16x32xf32> -> tensor<32x16xf32> @@ -332,7 +332,7 @@ tt.func public @fn(%arg0: tensor<16x32xf32>) { // ----- -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = 
"cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32xf32>) { // expected-error @+1 {{order must be a permutation}} %a = tt.trans %arg0 {order = array} : tensor<16x32xf32> -> tensor<32x16xf32> @@ -343,9 +343,9 @@ tt.func public @fn(%arg0: tensor<16x32xf32>) { // ----- // Invalid tensor with shared encoding. -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { // expected-error @+1 {{has an invalid layout: Shared layout is not allowed on tensor type.}} %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #shared> -> tensor<32x16x64xf32, #shared1> diff --git a/test/Triton/reproducer.mlir b/test/Triton/reproducer.mlir index f2c3a0f8e8..5a6747d217 100644 --- a/test/Triton/reproducer.mlir +++ b/test/Triton/reproducer.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt --verify-diagnostics --dump-pass-pipeline --run-reproducer %s 2>&1 | FileCheck %s -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @triton__() attributes {noinline = false} { tt.return } diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir index 551c1f67b5..1f7de7d6d9 100644 --- a/test/Triton/vecadd.mlir +++ b/test/Triton/vecadd.mlir @@ -51,80 +51,80 @@ module { // %c256_i32 = arith.constant 256 : i32 // %0 = tt.get_program_id x : i32 // %1 = arith.muli %0, %c256_i32 : i32 -// %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %3 = tt.broadcast %1 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %5 = tt.broadcast %arg3 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %6 = arith.cmpi "slt", %4, %5 : (tensor<256xi32, #triton_gpu<"coalesced encoding">>, tensor<256xi32, #triton_gpu<"coalesced encoding">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %7 = tt.broadcast %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %9 = tt.broadcast %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %11 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> +// %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg<"coalesced encoding">> +// %3 = tt.broadcast %1 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %4 = arith.addi %3, %2 : tensor<256xi32, #ttg<"coalesced encoding">> 
+// %5 = tt.broadcast %arg3 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %6 = arith.cmpi "slt", %4, %5 : (tensor<256xi32, #ttg<"coalesced encoding">>, tensor<256xi32, #ttg<"coalesced encoding">>) -> tensor<256xi1, #ttg<"coalesced encoding">> +// %7 = tt.broadcast %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #ttg<"coalesced encoding">> +// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %9 = tt.broadcast %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #ttg<"coalesced encoding">> +// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %11 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> // %12 = arith.index_cast %arg4 : i32 to index // %13 = arith.cmpi slt, %c0, %12 : index -// %14 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %15 = tt.broadcast %13 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %16 = arith.andi %6, %15 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %17 = triton_gpu.copy_async %8, %16, %14 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %18 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %19 = tt.broadcast %13 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %21 = triton_gpu.copy_async %10, %20, %18 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %22 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %24 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> +// %14 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %15 = tt.broadcast %13 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %16 = arith.andi %6, %15 : tensor<256xi1, #ttg<"coalesced encoding">> +// %17 = ttg.copy_async %8, %16, %14 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %18 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %19 = tt.broadcast %13 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %20 = arith.andi %6, %19 : tensor<256xi1, #ttg<"coalesced encoding">> +// %21 = ttg.copy_async %10, %20, %18 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %22 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %24 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> // %26 = arith.cmpi slt, %c32, %12 : index -// %27 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %28 = tt.broadcast %26 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %29 = arith.andi %6, %28 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %30 = triton_gpu.copy_async %23, %29, %27 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %31 = 
tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %32 = tt.broadcast %26 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %34 = triton_gpu.copy_async %25, %33, %31 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %35 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %37 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> +// %27 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %28 = tt.broadcast %26 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %29 = arith.andi %6, %28 : tensor<256xi1, #ttg<"coalesced encoding">> +// %30 = ttg.copy_async %23, %29, %27 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %31 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %32 = tt.broadcast %26 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %33 = arith.andi %6, %32 : tensor<256xi1, #ttg<"coalesced encoding">> +// %34 = ttg.copy_async %25, %33, %31 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %35 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %37 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> // %39 = arith.cmpi slt, %c64, %12 : index -// %40 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %41 = tt.broadcast %39 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %42 = arith.andi %6, %41 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %43 = triton_gpu.copy_async %36, %42, %40 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %44 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %45 = tt.broadcast %39 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %47 = triton_gpu.copy_async %38, %46, %44 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %48 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %50 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xf32, 
#triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, index) { -// %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %57 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %59 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> +// %40 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %41 = tt.broadcast %39 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %42 = arith.andi %6, %41 : tensor<256xi1, #ttg<"coalesced encoding">> +// %43 = ttg.copy_async %36, %42, %40 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %44 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %45 = tt.broadcast %39 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %46 = arith.andi %6, %45 : tensor<256xi1, #ttg<"coalesced encoding">> +// %47 = ttg.copy_async %38, %46, %44 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %48 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %50 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, index) { +// %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #ttg<"coalesced encoding">> +// %56 = arith.addf %arg7, %55 : tensor<256xf32, #ttg<"coalesced encoding">> +// %57 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %59 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> // %61 = arith.addi %arg18, %c32 : index // %62 = arith.cmpi slt, %61, %12 : index -// %63 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %64 = tt.broadcast %62 : i1 -> 
tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %65 = arith.andi %64, %6 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %66 = triton_gpu.copy_async %arg17, %65, %63 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %67 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %68 = triton_gpu.copy_async %arg16, %65, %67 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %69 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %71 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, index +// %63 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %64 = tt.broadcast %62 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %65 = arith.andi %64, %6 : tensor<256xi1, #ttg<"coalesced encoding">> +// %66 = ttg.copy_async %arg17, %65, %63 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %67 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %68 = ttg.copy_async %arg16, %65, %67 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %69 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %71 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, index // } -// %53 = tt.broadcast %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding">> +// %53 = tt.broadcast %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #ttg<"coalesced encoding">> +// 
%54 = tt.addptr %53, %4, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// tt.store %54, %52#0, %6 : tensor<256xf32, #ttg<"coalesced encoding">> // tt.return // } // } diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 1d564ff9db..15704bd0c2 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -1,20 +1,20 @@ // RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul | FileCheck %s -// CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -// CHECK: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -// CHECK: #[[MMA2:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +// CHECK: #[[MMA1:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +// CHECK: #[[MMA2:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: mma_chain_loop tt.func public @mma_chain_loop( - %170: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %171: tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, - %179: tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>, - %164: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>, - %165: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>, - %173: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>, + %170: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %171: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %179: tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>, + %164: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>, + %165: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>>, + %173: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>, %153: tensor<128x64x!tt.ptr, #blocked1>) { %c0_i32 = arith.constant 0 : i32 %c8_i32 = arith.constant 8 : i32 @@ -23,21 +23,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %cst_0 = arith.constant dense<0.000000e+00> : 
tensor<128x64xf16, #blocked1> %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2> // CHECK: scf.for - // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> - // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 { - %172 = tt.dot %170, %171, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> - %178 = triton_gpu.convert_layout %172 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - %180 = tt.dot %178, %179, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + %172 = tt.dot %170, %171, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> + %178 = ttg.convert_layout %172 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %180 = tt.dot %178, %179, %arg16 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> scf.yield %180 : tensor<128x64xf16, #blocked1> } // CHECK: scf.for - // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> - // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 { - %166 = tt.dot %164, %165, %cst_2 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> - %172 = triton_gpu.convert_layout %166 : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - %174 = tt.dot %172, %173, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + %166 = tt.dot %164, %165, %cst_2 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> + %172 = ttg.convert_layout %166 : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %174 = tt.dot %172, %173, %arg16 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> scf.yield %174 : tensor<128x64xf16, #blocked1> } tt.store %153, %149 : tensor<128x64x!tt.ptr, #blocked1> @@ -47,79 +47,79 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK: #[[$MMA:.+]] = 
#triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: chained_dot tt.func public @chained_dot( - %arg0: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %arg1: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, - %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { + %arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> // CHECK: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> %d = tt.dot %arg0, %arg1, %cst_0 : - tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> - %c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + %c = ttg.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> // CHECK: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> %r = tt.dot %c, %arg2, %cst_1 : - tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> + tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> tt.return %r : tensor<64x128xf32, #blocked1> } } // ----- -// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = 
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:89", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: fp8_dot tt.func public @fp8_dot( - %arg0: tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %arg1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, - %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x64xf32, #blocked> { + %arg0: tensor<64x128xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x64xf32, #blocked> { %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> - // CHECK: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]> + // CHECK: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]> %d = tt.dot %arg0, %arg1, %cst_0 : - tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + tensor<64x128xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> tt.return %d : tensor<64x64xf32, #blocked> } } // ----- -// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -// CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> +// CHECK-DAG: #[[MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +// CHECK-DAG: #[[MMA1:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [0, 1, 2]}> -#blocked3 = 
#triton_gpu.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [1, 4, 8], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [0, 1, 2]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [1, 4, 8], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: kernel_ tt.func public @kernel_() attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<2x16x16xf32, #blocked> %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked1> - %0 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - %1 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> - %2 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #blocked1> + %0 = ttg.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %1 = ttg.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> + %2 = ttg.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #blocked1> // CHECK: tt.dot {{.*}} -> tensor<16x16xf32, #[[MMA]]> - %3 = tt.dot %0, %1, %2, inputPrecision = tf32 : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<16x16xf32, #blocked1> - %4 = triton_gpu.convert_layout %3 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16x16xf32, #blocked2> - %6 = triton_gpu.convert_layout %5 : tensor<1x16x16xf32, #blocked2> -> tensor<1x16x16xf32, #blocked> + %3 = tt.dot %0, %1, %2, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<16x16xf32, #blocked1> + %4 = ttg.convert_layout %3 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16x16xf32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16x16xf32, #blocked2> + %6 = ttg.convert_layout %5 : tensor<1x16x16xf32, #blocked2> -> tensor<1x16x16xf32, #blocked> %7 = tt.broadcast %6 : tensor<1x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked> - %8 = triton_gpu.convert_layout %7 : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> - %9 = triton_gpu.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> - %10 = triton_gpu.convert_layout %cst : 
tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked3> + %8 = ttg.convert_layout %7 : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> + %9 = ttg.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> + %10 = ttg.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked3> // CHECK: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]> - %11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3> - %12 = triton_gpu.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked> + %11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3> + %12 = ttg.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked> tt.print ": " {hex = false, isSigned = array} : %12 : tensor<2x16x16xf32, #blocked> tt.return } @@ -127,17 +127,17 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, {{.*}}, instrShape = [16, 32, 16]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 3, {{.*}}, instrShape = [16, 32, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: check_instrShape_per_warps tt.func @check_instrShape_per_warps(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { %mask = arith.constant dense : tensor<128x128xi1, #blocked> %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %a = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %b = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %a = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %b = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %result = tt.dot %a, %b, %zero_f32 : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + %result = tt.dot %a, %b, %zero_f32 : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> %result_ptr = tt.splat %arg0 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked> tt.store %result_ptr, %result, %mask : tensor<128x128x!tt.ptr, #blocked> tt.return @@ -148,16 +148,16 @@ module attributes {"triton_gpu.target" = 
"cuda:90", "triton_gpu.num-ctas" = 1 : // ----- // Verify that we use mmav2 when the k dim is too small for mmav3. -// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: small_k_size tt.func @small_k_size( - %a: tensor<128x16xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %b: tensor<16x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) + %a: tensor<128x16xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %b: tensor<16x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<128x128xf32, #blocked> { %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %result = tt.dot %a, %b, %zero_f32 : tensor<128x16xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + %result = tt.dot %a, %b, %zero_f32 : tensor<128x16xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> } } @@ -165,20 +165,20 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- // Verify that dot_scaled (mxfp4 x {bf16,fp8}) decomposes to mmav3 if it's bf16, otherwise it fallsback to mmav2 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -// CHECK: #[[LINEAR:.+]] = #triton_gpu.linear<{{.*}}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: #[[LINEAR:.+]] = #ttg.linear<{{.*}}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: dot_scaled tt.func @dot_scaled( %a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b_bf16: tensor<64x128xbf16, #blocked> ) -> tensor<128x128xf32, #blocked> { - // CHECK: triton_gpu.convert_layout {{.*}} : 
tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, {{.*}}> - // CHECK: triton_gpu.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{{.*}}>>, tensor<128x2xi8, {{.*}}> -> tensor<128x64xbf16, #triton_gpu.dot_op<{{.*}}>> - // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK: ttg.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, {{.*}}> + // CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, {{.*}}> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>> + // CHECK: ttng.warp_group_dot %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %result = tt.dot_scaled %a scale %scale, %b_bf16, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> @@ -192,8 +192,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %b_fp8: tensor<64x128xf8E4M3FN, #blocked> ) -> tensor<128x128xf32, #blocked> { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]> - // CHECK: triton_gpu.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{{.*}}>> + // CHECK: ttg.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]> + // CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>> // CHECK: tt.dot %result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> @@ -202,11 +202,11 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: dot_scale_transpose tt.func public @dot_scale_transpose(%arg0: tensor<128x64xf8E4M3FN, #blocked>, %arg1: tensor<32x32xi8, #blocked1>, %arg2: tensor<32x2xi8, 
#blocked2>, %arg3: tensor<128x32x!tt.ptr, #blocked3>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked1> @@ -226,10 +226,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : scf.yield %3 : tensor<128x32xf32, #blocked1> } // CHECK: arith.truncf - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.trans %1 = arith.truncf %0 : tensor<128x32xf32, #blocked1> to tensor<128x32xbf16, #blocked1> - %2 = triton_gpu.convert_layout %1 : tensor<128x32xbf16, #blocked1> -> tensor<128x32xbf16, #blocked3> + %2 = ttg.convert_layout %1 : tensor<128x32xbf16, #blocked1> -> tensor<128x32xbf16, #blocked3> tt.store %arg3, %2 : tensor<128x32x!tt.ptr, #blocked3> tt.return } diff --git a/test/TritonGPU/accumulator-init.mlir b/test/TritonGPU/accumulator-init.mlir index e73934818a..c5302913c4 100644 --- a/test/TritonGPU/accumulator-init.mlir +++ b/test/TritonGPU/accumulator-init.mlir @@ -1,24 +1,24 @@ // RUN: triton-opt %s -split-input-file -tritongpu-optimize-accumulator-init | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @constant_init // CHECK-DAG: %[[FALSE:.+]] = arith.constant false -// CHECK: triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] - tt.func @constant_init(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK: ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %cst_2 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -26,15 +26,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: @constant_init_integer // CHECK-DAG: %[[FALSE:.+]] = arith.constant false -// CHECK: triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] - tt.func @constant_init_integer(%A: !triton_gpu.memdesc<128x64xi8, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xi8, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xi32, #mma1> { +// CHECK: ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init_integer(%A: !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xi8, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xi32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0> : tensor<128x16xi32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xi32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %cst_2 : !triton_gpu.memdesc<128x64xi8, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xi8, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xi32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xi8, #shared1, #ttg.shared_memory> -> tensor<128x16xi32, #mma1> scf.yield %acc: tensor<128x16xi32, #mma1> } tt.return %17 : tensor<128x16xi32, #mma1> @@ -46,21 +46,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[FALSE:.+]] = arith.constant false // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) // CHECK: %[[CND:.+]] = arith.cmpi -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] // CHECK: scf.if %[[CND]] // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: else // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @if_after_mma(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr 
{tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -77,21 +77,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[FALSE:.+]] = arith.constant false // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) // CHECK: %[[CND:.+]] = arith.cmpi -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[TRUE]], %[[FALSE]] // CHECK: scf.if %[[CND]] // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: else // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @if_after_mma_invert(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_after_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> 
(tensor<128x16xf32, #mma1>) { scf.yield %acc : tensor<128x16xf32, #mma1> } else { @@ -113,9 +113,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: scf.yield %[[ACC]] // CHECK: else // CHECK: scf.yield %[[ACC]] -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @if_before_mma(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -127,7 +127,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } else { scf.yield %arg4 : tensor<128x16xf32, #mma1> } - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -144,9 +144,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: scf.yield %[[ACC]] // CHECK: else // CHECK: scf.yield %[[ACC]] -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @if_before_mma_invert(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_before_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -158,7 +158,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } else { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !triton_gpu.memdesc<128x64xf16, #shared, 
#triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -170,17 +170,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[FALSE:.+]] = arith.constant false // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) // CHECK: %[[CND:.+]] = arith.cmpi -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @sel_after_mma(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @sel_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> scf.yield %acc_: tensor<128x16xf32, #mma1> } @@ -194,9 +194,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) // CHECK: %[[CND:.+]] = arith.cmpi // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC_NEXT]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @sel_before_mma(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @sel_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, 
#ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -204,7 +204,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xf32, #mma1> - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -224,13 +224,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: scf.yield %[[ACC]] // CHECK: else // CHECK: scf.yield %[[ACC]] -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] // CHECK: scf.if %[[CND]] // CHECK: scf.yield %[[C0_TENSOR]] // CHECK: else // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @if_before_and_after_mma(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_before_and_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -242,7 +242,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } else { scf.yield %arg4 : tensor<128x16xf32, #mma1> } - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_0 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_0 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -259,7 +259,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[FALSE:.+]] = arith.constant false // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = 
%[[FALSE]]) // CHECK: %[[CND:.+]] = arith.cmpi -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] // CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] // CHECK: scf.yield %[[C0_TENSOR]] // CHECK: else @@ -270,14 +270,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: else // CHECK: scf.yield %[[ACC_CND]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @two_ifs_after_mma(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @two_ifs_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -296,15 +296,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Check that we bail out in unsupported cases // CHECK-LABEL: @non_zero_init -// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !triton_gpu.memdesc - tt.func @non_zero_init(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @non_zero_init(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = 
triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> scf.yield %acc_: tensor<128x16xf32, #mma1> } @@ -312,15 +312,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // CHECK-LABEL: @zero_init_dist_2 -// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !triton_gpu.memdesc - tt.func @zero_init_dist_2(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @zero_init_dist_2(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %cst_2) -> (tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg5 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg5 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> scf.yield %acc_, %arg4: tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } @@ -328,8 +328,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // CHECK-LABEL: @if_defines_alternative -// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !triton_gpu.memdesc - tt.func @if_defines_alternative(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @if_defines_alternative(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> 
{tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> @@ -337,7 +337,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -350,15 +350,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // CHECK-LABEL: @non_cond_override -// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !triton_gpu.memdesc - tt.func @non_cond_override(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @non_cond_override(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_ = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> scf.yield %acc_: tensor<128x16xf32, #mma1> } @@ -367,15 +367,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // If the condition is a tensor skip the optimization. 
// CHECK-LABEL: @negative_sel_tensor -// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !triton_gpu.memdesc - tt.func @negative_sel_tensor(%A: !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> { +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @negative_sel_tensor(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir index 7854a4eed7..260dddb954 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir @@ -1,19 +1,19 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx940 matrix-instruction-size=0' | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}> // CHECK-LABEL: mfma_dot_fp8e5m2 -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @mfma_dot_fp8e5m2( - %arg0: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %arg1: tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg0: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<128x256x!tt.ptr, #blocked> ) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> - // CHECK: %[[A0:.+]] = triton_gpu.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - // CHECK: %[[B0:.+]] = triton_gpu.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : 
tensor<64x256xf8E5M2, {{.*}} -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: %[[A0:.+]] = ttg.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[B0:.+]] = ttg.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E5M2, {{.*}} -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> // CHECK: tt.dot %[[A1]], %[[B1]] - %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> tt.store %arg2, %1 : tensor<128x256x!tt.ptr, #blocked> tt.return } diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir index 7d3e8c23be..b68fe93493 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir @@ -1,27 +1,27 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx1100 matrix-instruction-size=0' | FileCheck %s -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf32( - // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<128x256x!tt.ptr, #blocked>) { // CHECK: %[[DOT0_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT0_OP_C:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_C]] + // CHECK: %[[DOT0_OP_C:.+]] = ttg.convert_layout %[[DOT0_ARG_C]] // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] %3 = arith.constant dense<0.000000e+00> 
: tensor<128x256xf32, #blocked> - // CHECK: %[[DOT0_OP_A:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_A]] - // CHECK-SAME: -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] - // CHECK: %[[DOT0_OP_B:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_B]] - // CHECK-SAME: -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_A:.+]] = ttg.convert_layout %[[DOT0_ARG_A]] + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_B:.+]] = ttg.convert_layout %[[DOT0_ARG_B]] + // CHECK-SAME: -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] // CHECK: %[[DOT0_WMMA_RES:.+]] = tt.dot %[[DOT0_OP_A]], %[[DOT0_OP_B]], %[[DOT0_OP_C]] // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] - %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT0_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + // CHECK: ttg.convert_layout %[[DOT0_WMMA_RES]] // CHECK-SAME: -> tensor<128x256xf32, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<128x256x!tt.ptr, #blocked> tt.return @@ -30,28 +30,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf16( - // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] %3 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> - // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] - // 
CHECK-SAME: -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] - // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]] - // CHECK-SAME: -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] - %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] // CHECK-SAME: -> tensor<32x32xf16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return @@ -60,32 +60,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_ab8_cf16( - // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x64x!tt.ptr, #blocked>) { // CHECK: %[[DOT2_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT2_OP_C:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_C]] + // CHECK: %[[DOT2_OP_C:.+]] = ttg.convert_layout %[[DOT2_ARG_C]] // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] %3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #blocked> - // CHECK: %[[DOT2_OP_A_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_A]] - // CHECK-SAME: -> tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // 
CHECK: %[[DOT2_OP_A_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_A]] + // CHECK-SAME: -> tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] // CHECK: %[[DOT2_OP_A_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_A_F8]] - // CHECK-SAME: -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 16}>> - // CHECK: %[[DOT2_OP_B_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_B]] - // CHECK-SAME: -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK-SAME: -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 16}>> + // CHECK: %[[DOT2_OP_B_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_B]] + // CHECK-SAME: -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] // CHECK: %[[DOT2_OP_B_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_B_F8]] - // CHECK-SAME: -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 16}>> + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 16}>> // CHECK: %[[DOT2_WMMA_RES:.+]] = tt.dot %[[DOT2_OP_A_F16]], %[[DOT2_OP_B_F16]], %[[DOT2_OP_C]] // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] - %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT2_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT2_WMMA_RES]] // CHECK-SAME: -> tensor<32x64xf16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x64x!tt.ptr, #blocked> tt.return @@ -94,28 +94,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_i8_i32( - // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0> : tensor<32x32xi32, 
#[[DOT_OP_PARENT]]> - // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] %3 = arith.constant dense<0> : tensor<32x32xi32, #blocked> - // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] - // CHECK-SAME: -> tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] - // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]] - // CHECK-SAME: -> tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] - %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] // CHECK-SAME: -> tensor<32x32xi32, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return @@ -124,26 +124,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fma_dot_i16_i16( - // CHECK: %[[DOT3_ARG_A:.+]]: tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT3_ARG_B:.+]]: tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT3_ARG_A:.+]]: tensor<128x64xi16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xi16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT3_ARG_B:.+]]: tensor<64x32xi16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<128x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT3_ARG_C:.+]] = arith.constant dense<0> : tensor<128x32xi16, #[[DOT_OP_PARENT]]> %3 = arith.constant dense<0> : tensor<128x32xi16, #blocked> // CHECK: %[[DOT3_OP_A:.+]] = arith.sitofp %[[DOT3_ARG_A]] - // CHECK-SAME: to tensor<128x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]] + // CHECK-SAME: to tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = 
#[[DOT_OP_PARENT]] // CHECK: %[[DOT3_OP_B:.+]] = arith.sitofp %[[DOT3_ARG_B]] - // CHECK-SAME: to tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]] + // CHECK-SAME: to tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]] // CHECK: %[[DOT3_OP_C:.+]] = arith.sitofp %[[DOT3_ARG_C]] // CHECK-SAME: to tensor<128x32xf32, #[[DOT_OP_PARENT]] // CHECK: %[[DOT3_FMA_RES:.+]] = tt.dot %[[DOT3_OP_A]], %[[DOT3_OP_B]], %[[DOT3_OP_C]] // CHECK-SAME: -> tensor<128x32xf32, #[[DOT_OP_PARENT]]> - %4 = tt.dot %0, %1, %3 : tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xi16, #blocked> + %4 = tt.dot %0, %1, %3 : tensor<128x64xi16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xi16, #blocked> // CHECK: arith.fptosi %[[DOT3_FMA_RES]] // CHECK-SAME: to tensor<128x32xi16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<128x32x!tt.ptr, #blocked> diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir index a8683a5d39..a5bf857dfb 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir @@ -1,27 +1,27 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx1200 matrix-instruction-size=0' | FileCheck %s -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf32( - // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<128x256x!tt.ptr, #blocked>) { // CHECK: %[[DOT0_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT0_OP_C:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_C]] + // CHECK: %[[DOT0_OP_C:.+]] = ttg.convert_layout %[[DOT0_ARG_C]] // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] %3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> - // CHECK: %[[DOT0_OP_A:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_A]] - // CHECK-SAME: 
-> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] - // CHECK: %[[DOT0_OP_B:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_B]] - // CHECK-SAME: -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_A:.+]] = ttg.convert_layout %[[DOT0_ARG_A]] + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_B:.+]] = ttg.convert_layout %[[DOT0_ARG_B]] + // CHECK-SAME: -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] // CHECK: %[[DOT0_WMMA_RES:.+]] = tt.dot %[[DOT0_OP_A]], %[[DOT0_OP_B]], %[[DOT0_OP_C]] // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] - %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT0_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + // CHECK: ttg.convert_layout %[[DOT0_WMMA_RES]] // CHECK-SAME: -> tensor<128x256xf32, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<128x256x!tt.ptr, #blocked> tt.return @@ -30,28 +30,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf16( - // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] %3 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> - // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] - // CHECK-SAME: -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] - // CHECK: %[[DOT1_OP_B:.+]] = 
triton_gpu.convert_layout %[[DOT1_ARG_B]] - // CHECK-SAME: -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] - %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] // CHECK-SAME: -> tensor<32x32xf16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return @@ -60,32 +60,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_ab8_cf16( - // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x64x!tt.ptr, #blocked>) { // CHECK: %[[DOT2_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT2_OP_C:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_C]] + // CHECK: %[[DOT2_OP_C:.+]] = ttg.convert_layout %[[DOT2_ARG_C]] // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] %3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #blocked> - // CHECK: %[[DOT2_OP_A_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_A]] - // CHECK-SAME: -> tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT2_OP_A_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_A]] + // CHECK-SAME: -> tensor<32x128xf8E5M2, 
#ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] // CHECK: %[[DOT2_OP_A_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_A_F8]] - // CHECK-SAME: -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 8}>> - // CHECK: %[[DOT2_OP_B_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_B]] - // CHECK-SAME: -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK-SAME: -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 8}>> + // CHECK: %[[DOT2_OP_B_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_B]] + // CHECK-SAME: -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] // CHECK: %[[DOT2_OP_B_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_B_F8]] - // CHECK-SAME: -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 8}>> + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 8}>> // CHECK: %[[DOT2_WMMA_RES:.+]] = tt.dot %[[DOT2_OP_A_F16]], %[[DOT2_OP_B_F16]], %[[DOT2_OP_C]] // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] - %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT2_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT2_WMMA_RES]] // CHECK-SAME: -> tensor<32x64xf16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x64x!tt.ptr, #blocked> tt.return @@ -94,28 +94,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_i8_i32( - // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0> : tensor<32x32xi32, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK: %[[DOT1_OP_C:.+]] = 
ttg.convert_layout %[[DOT1_ARG_C]] // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] %3 = arith.constant dense<0> : tensor<32x32xi32, #blocked> - // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] - // CHECK-SAME: -> tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] - // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]] - // CHECK-SAME: -> tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] - %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] // CHECK-SAME: -> tensor<32x32xi32, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index 6c3e2ac42f..ed47e1512d 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @conversion1 tt.func @conversion1(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -22,8 +22,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @conversion2 tt.func @conversion2(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -50,8 +50,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 
: i32} { // CHECK-LABEL: tt.func @conversion3 tt.func @conversion3(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -89,8 +89,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // // This is the same as conversion3, but now the `arith.extsi` operations // disappeared and all the offsets are 32 bits. @@ -129,8 +129,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @forOp tt.func @forOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -175,8 +175,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @forOp2 tt.func @forOp2(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -221,8 +221,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @forNested tt.func @forNested(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -267,8 +267,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @ifOp tt.func @ifOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -309,8 +309,8 @@ module attributes {"triton_gpu.num-warps" = 4 : 
i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @whileOp tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -345,8 +345,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @condBranch tt.func @condBranch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -389,8 +389,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @branch tt.func @branch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -428,30 +428,30 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // The following is a simple case of a tile offset like: (A*B + C + D) where B,C are Uniform and A,D are not. 
So // we expect that the Uniform offset (which can be added to the scalar pointer) will be simply C and the NonUniform // offset will be A*B+D -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func @tile_offset tt.func @tile_offset(%arg1: !tt.ptr, %arg5: i32 , %arg7: i32 ) { %c128_i32 = arith.constant 128 : i32 %c256_i32 = arith.constant 256 : i32 %1 = tt.get_program_id x : i32 %20 = arith.muli %1, %c256_i32 : i32 - %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %24 = tt.splat %20 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %26 = arith.addi %24, %22 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %24 = tt.splat %20 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> %38 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked> %39 = arith.muli %37, %38 : tensor<16x1xi32, #blocked> - %41 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %41 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> %42 = tt.broadcast %39 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> %43 = tt.broadcast %41 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> %44 = arith.addi %42, %43 : tensor<16x256xi32, #blocked> %45 = tt.splat %arg1 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> %46 = tt.addptr %45, %44 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> // CHECK: %[[uniformOffset1:.*]] = arith.muli %c0_i32_0, %arg2 : i32 - // CHECK: {{.*}} = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + // CHECK: {{.*}} = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> // CHECK: {{.*}} = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> 
tensor<16x256xi32, #blocked> // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %{{.*}} : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> @@ -483,21 +483,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // = (U + N)*U + N // Where U means uniform (e.g., a splat) and N means NonUniform (e.g., a make_range) // The scalar offset we want is (%1*%arg1), while the variable offset should be (%2*%arg1 + %8) -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func public @matmul_kernel tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { %c128_i32 = arith.constant 128 : i32 %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c128_i32 : i32 - %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %3 = tt.splat %1 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %4 = arith.addi %3, %2 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = arith.addi %3, %2 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked> - %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> %10 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> %11 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> %12 = arith.addi %10, %11 : tensor<128x16xi32, #blocked> @@ -509,8 +509,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[uniformOffset2:.*]] = arith.addi %[[uniformOffset3]], %{{.*}} : i32 // CHECK: %[[uniformOffset1:.*]] = arith.muli %[[uniformOffset2]], %arg1 : i32 // CHECK: %[[makerange:.*]] = tt.make_range - // CHECK: %{{.*}} = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> - // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %[[makerange]] 
{axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + // CHECK: %{{.*}} = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> @@ -530,8 +530,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @select tt.func @select(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -563,8 +563,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1100", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func @where_kernel tt.func @where_kernel(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}){ %c0_i8 = arith.constant 0 : i8 @@ -589,8 +589,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @forOpWithHints tt.func @forOpWithHints(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ %c0 = arith.constant 0: index @@ -620,8 +620,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", 
"ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: scalar_pointers tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %0 = tt.get_program_id x : i32 @@ -648,8 +648,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: @scalar_if tt.func @scalar_if(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)->f32{ %0 = tt.get_program_id x : i32 @@ -678,8 +678,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @scalar_while tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1)->f32{ %c1024_i32 = arith.constant 1024 : i32 @@ -707,8 +707,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @scalar_cond_branch tt.func @scalar_cond_branch(%arg0 : !tt.ptr, %i1 : i1) -> f32{ %c1024_i32 = arith.constant 1024 : i32 diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 25897f2a93..18922b15aa 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops | FileCheck %s -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: simple tt.func @simple(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 :i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -31,8 +31,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: assume_positive_offset tt.func @assume_positive_offset(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -59,8 +59,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: offset_64_bits tt.func @offset_64_bits(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> { %c1024_i32 = arith.constant 1024 : i32 @@ -84,8 +84,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: offset_64_bits_narrow tt.func public @offset_64_bits_narrow(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> { %c1024_i32 = arith.constant 1024 : i32 @@ -111,8 +111,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: non_canonical_ptr tt.func @non_canonical_ptr(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: tensor<1024xi32, #blocked>) -> tensor<1024xf32, #blocked>{ %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> diff --git a/test/TritonGPU/amd/amd-extractslice-op.mlir b/test/TritonGPU/amd/amd-extractslice-op.mlir index ef47a9f9b4..bde77b475e 100644 --- a/test/TritonGPU/amd/amd-extractslice-op.mlir +++ b/test/TritonGPU/amd/amd-extractslice-op.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -module attributes 
{"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func @basic_insert_slice(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { // CHECK: llvm.func @basic_insert_slice // CHECK-COUNT-64: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index b9f40dc291..8cc3ae64f4 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -74,15 +74,15 @@ module { // LABELING_PS_1: scf.for // LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} // LABELING_PS_1: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} - // LABELING_PS_1: %[[REG1_OP0:.+]] = triton_gpu.convert_layout %[[REG0_OP0]] - // LABELING_PS_1: %[[REG1_OP1:.+]] = triton_gpu.convert_layout %[[REG0_OP1]] + // LABELING_PS_1: %[[REG1_OP0:.+]] = ttg.convert_layout %[[REG0_OP0]] + // LABELING_PS_1: %[[REG1_OP1:.+]] = ttg.convert_layout %[[REG0_OP1]] // LABELING_PS_1: tt.dot %[[REG1_OP0]], %[[REG1_OP1]], {{.*}} // LABELING_PS_2: scf.for // LABELING_PS_2: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} // LABELING_PS_2: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} - // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP0]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} - // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_2: ttg.local_store %[[REG0_OP0]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: ttg.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<1>} %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr>, tensor<128x32xi32> diff --git a/test/TritonGPU/amd/amd-optimize-epilogue.mlir b/test/TritonGPU/amd/amd-optimize-epilogue.mlir index 8939562d0c..8cc467e773 100644 --- a/test/TritonGPU/amd/amd-optimize-epilogue.mlir +++ b/test/TritonGPU/amd/amd-optimize-epilogue.mlir @@ -1,17 +1,17 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-optimize-epilogue | FileCheck %s // CHECK-LABEL: one_op_in_chain -// CHECK-NOT: triton_gpu.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> +// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> // CHECK: tt.store %{{.*}}, %{{.*}} : tensor<32x32x!tt.ptr, #mma> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = 
#triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}>
-module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}>
+module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func public @one_op_in_chain(%arg0: !tt.ptr) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
- %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
- %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
- %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
+ %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+ %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+ %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
+ %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
%2 = arith.truncf %1 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked>
%3 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked>
tt.store %3, %2 : tensor<32x32x!tt.ptr, #blocked>
@@ -22,17 +22,17 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
// -----
// CHECK-LABEL: two_ops_in_chain
-// CHECK-NOT: triton_gpu.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
+// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
// CHECK: tt.store %{{.*}}, %{{.*}} : tensor<32x32x!tt.ptr, #mma>
-#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}>
-#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}>
-module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}>
+module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func public @two_ops_in_chain(%arg0: !tt.ptr) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
- %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
- %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
- %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
+ %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+ %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+ %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
+ %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
%2 = math.exp2 %1 : tensor<32x32xf32, #blocked>
%3 = arith.truncf %2 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked>
%4 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked>
diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir
index 708d75a232..8fa6d6fe12 100644
--- a/test/TritonGPU/amd/amd-reorder-instructions.mlir
+++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -7,14 +7,14 @@
// CHECK-LABEL: hoist_q_out_of_the_loop
// CHECK: %[[TRUNCF:.+]] = arith.truncf
-// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]]
-// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
+// CHECK-NEXT: %[[ALLOC:.+]] = ttg.local_alloc %[[TRUNCF]]
+// CHECK-NEXT: ttg.local_load %[[ALLOC]]
// CHECK: scf.for
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
-#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
-#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
-#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
@@ -34,11 +34,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
%54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
%73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2>
%74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2>
- %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !triton_gpu.memdesc<256x128xf16, #shared,
#triton_gpu.shared_memory> - %76 = triton_gpu.local_load %75 : !triton_gpu.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> - %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !triton_gpu.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> - %78 = triton_gpu.local_load %77 : !triton_gpu.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory> + %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #ttg.shared_memory> + %78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> %107 = arith.addi %arg26, %c128_i64 : i64 scf.yield %107 : i64 } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} @@ -54,11 +54,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: scf.for // CHECK: %[[TRUNCF:.+]] = arith.truncf // CHECK-NEXT: arith.constant -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr {tt.divisibility = 16 
: i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 %cst = arith.constant 1.44269502 : f32 @@ -78,11 +78,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> - %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !triton_gpu.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> - %76 = triton_gpu.local_load %75 : !triton_gpu.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> - %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !triton_gpu.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> - %78 = triton_gpu.local_load %77 : !triton_gpu.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory> + %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #ttg.shared_memory> + %78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> %107 = arith.addi %arg26, %c128_i64 : i64 scf.yield %107 : i64 } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} @@ -91,25 +91,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> // CHECK-LABEL: order_load_alloc_local_load_local_store // CHECK: %[[LOAD:.+]] = tt.load -// 
CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc -// CHECK: triton_gpu.local_store %[[LOAD]], %[[ALLOC]] -// CHECK: triton_gpu.local_load %[[ALLOC]] -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: %[[ALLOC:.+]] = ttg.local_alloc +// CHECK: ttg.local_store %[[LOAD]], %[[ALLOC]] +// CHECK: ttg.local_load %[[ALLOC]] +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @order_load_alloc_local_load_local_store(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %10 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<32x32xf32, #shared, mutable> - triton_gpu.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !triton_gpu.memdesc<32x32xf32, #shared, mutable> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %11 = triton_gpu.local_load %10 : !triton_gpu.memdesc<32x32xf32, #shared, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %10 = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #shared, mutable> + ttg.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !ttg.memdesc<32x32xf32, #shared, mutable> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared, mutable> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> tt.return } @@ -167,15 +167,15 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // yield // } -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> -#shared3 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared4 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32, triton_gpu.target = "hip:gfx942"} { +#blocked = #ttg.blocked<{sizePerThread = [1, 
4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared2 = #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +#shared3 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared4 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32, ttg.target = "hip:gfx942"} { // CHECK-LABEL: tt.func @matmul_loop // CHECK: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) @@ -189,57 +189,57 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]] // CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_25]], %[[SPLAT_26]] // Stage 1 -// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG10]] -// CHECK: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_28:.*]] = ttg.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_29:.*]] = ttg.local_load %[[ARG11]] // CHECK: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} // CHECK: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %[[ARG8]] // Stage 0.b // CHECK: %[[ADDI_32:.*]] = arith.addi %[[ARG9]], %{{.*}} // CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ADDI_32]], %{{.*}} // CHECK: %[[SELECT_34:.*]] = arith.select %[[CMPI_33]], %[[ADDI_32]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] -// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] +// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] // CHECK: scf.yield %[[ADDPTR_20]], %[[ADDPTR_25]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]] // CHECK: } tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> %cst_2 = arith.constant 
dense<0.000000e+00> : tensor<128x128xf32, #mma> %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> - %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %10 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %11 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #ttg.shared_memory, mutable> + %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #ttg.shared_memory, mutable> %12 = arith.cmpi slt, %arg0, %arg1 : index %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %17 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !triton_gpu.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %18 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !triton_gpu.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : 
!ttg.memdesc<1x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> + %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> + %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>) { %20 = arith.subi %arg1, %arg2 : index %21 = arith.cmpi slt, %arg5, %20 : index - %22 = triton_gpu.local_load %arg10 : !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %23 = triton_gpu.local_load %arg11 : !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %24 = arith.mulf %23, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %22 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %23 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = arith.mulf %23, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> %26 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %27 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %28 = tt.splat %21 : i1 -> tensor<128x32xi1, #blocked1> @@ -249,14 +249,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %32 = arith.addi %arg9, %c1_i32 : i32 %33 = arith.cmpi slt, %32, %c1_i32 : i32 %34 = arith.select %33, %32, %c0_i32 : i32 - %35 = triton_gpu.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !triton_gpu.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %36 = triton_gpu.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !triton_gpu.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %31, %36 : tensor<32x128xf16, #blocked> -> 
!triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %35 = ttg.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> + %36 = ttg.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> } - triton_gpu.local_dealloc %10 : !triton_gpu.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %11 : !triton_gpu.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #ttg.shared_memory, mutable> tt.return %19#2 : tensor<128x128xf32, #mma> } @@ -281,13 +281,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}} // CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} // CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]] -// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]] +// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]] // Stage 2 -// CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]] -// CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_42:.*]] = ttg.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_43:.*]] = ttg.local_load %[[ARG11]] // CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}} // CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]] // CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]] @@ -298,23 +298,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c2_i32 = arith.constant 2 : i32 %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<4.000000e+00> : 
tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> - %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %10 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %11 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %10 = ttg.local_alloc : () -> !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable> + %11 = ttg.local_alloc : () -> !ttg.memdesc<2x32x128xf16, #shared1, #ttg.shared_memory, mutable> %12 = arith.cmpi slt, %arg0, %arg1 : index %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> @@ -328,18 +328,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %22 = tt.load %19, %21 : tensor<128x32x!tt.ptr, #blocked1> %23 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked> %24 = tt.load %20, %23, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %25 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !triton_gpu.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %26 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !triton_gpu.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !triton_gpu.memdesc<32x128xf16, #shared1, 
#triton_gpu.shared_memory, mutable> - %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) { + %25 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> + %26 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> + %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) { %28 = arith.muli %arg2, %c2 : index %29 = arith.subi %arg1, %28 : index %30 = arith.cmpi slt, %arg5, %29 : index - %31 = triton_gpu.local_load %arg10 : !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %32 = triton_gpu.local_load %arg11 : !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %33 = arith.mulf %32, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %31 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %32 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %33 = arith.mulf %32, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> %35 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %36 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %37 = tt.splat %30 : i1 -> tensor<128x32xi1, #blocked1> @@ -349,14 +349,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %41 = arith.addi %arg9, %c1_i32 : i32 %42 = arith.cmpi slt, 
%41, %c2_i32 : i32 %43 = arith.select %42, %41, %c0_i32 : i32 - %44 = triton_gpu.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !triton_gpu.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %45 = triton_gpu.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !triton_gpu.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked> + %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> + %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked> } - triton_gpu.local_dealloc %10 : !triton_gpu.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %11 : !triton_gpu.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %10 : !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %11 : !ttg.memdesc<2x32x128xf16, #shared1, #ttg.shared_memory, mutable> tt.return %27#2 : tensor<128x128xf32, #mma> } @@ -382,80 +382,80 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_33]] // CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_34]] // Stage 2 -// CHECK: %[[LOCAL_LOAD_36:.*]] = triton_gpu.local_load %[[ARG11]] -// CHECK: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[LOCAL_LOAD_36:.*]] = ttg.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_37:.*]] = ttg.local_load %[[ARG12]] // CHECK: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_36]], %[[LOCAL_LOAD_37]], %[[ARG7]] // Stage 1.b // CHECK: %[[ADDI_39:.*]] = arith.addi %[[ARG10]], %{{.*}} // CHECK: %[[CMPI_40:.*]] = arith.cmpi slt, %[[ADDI_39]], %{{.*}} // CHECK: %[[SELECT_41:.*]] = arith.select %[[CMPI_40]], %[[ADDI_39]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store 
%[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] -// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] +// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] +// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] // CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_20]], %[[ADDPTR_31]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_35]] // CHECK: } - tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { %c2 = arith.constant 2 : index %c0_i32 = arith.constant 0 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c1_i32 = arith.constant 1 : i32 - %cst_0 = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %cst_0 = arith.constant dense<1> : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> %2 = arith.cmpi sgt, %arg1, %c0 : index - %3 = tt.splat %2 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %3 = tt.splat %2 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> %5 = arith.cmpi sgt, %arg1, %c1 : index - %6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %7 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked1> %8 = tt.load %arg2, %7 : tensor<16x16x!tt.ptr, #blocked1> - %9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = 
#blocked}>> -> tensor<16x1xi64, #blocked> + %9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> %10 = tt.broadcast %9 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> %11 = arith.muli %arg0, %10 : tensor<16x16xi64, #blocked> %12 = tt.addptr %arg5, %11 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> %13 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked> %14 = tt.load %12, %13 : tensor<16x16x!tt.ptr, #blocked> - %15 = tt.splat %5 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %16 = tt.load %6, %15 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %17 = triton_gpu.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !triton_gpu.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %18 = triton_gpu.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !triton_gpu.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + %15 = tt.splat %5 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.load %6, %15 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = ttg.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> + ttg.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> + %18 = ttg.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> + ttg.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> + %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>) { %20 = arith.subi %arg1, %c2 : index %21 = arith.cmpi slt, %arg6, %20 : index %22 = arith.subi %arg1, %c1 : index %23 = arith.cmpi slt, %arg6, %22 : index - %24 = triton_gpu.local_load %arg11 : !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, 
parent = #mma, kWidth = 2}>> - %25 = triton_gpu.local_load %arg12 : !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %24 = ttg.local_load %arg11 : !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %25 = ttg.local_load %arg12 : !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> %27 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> - %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %29 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked1> %30 = tt.load %27, %29 : tensor<16x16x!tt.ptr, #blocked1> - %31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> %32 = tt.broadcast %31 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> %33 = arith.muli %arg0, %32 : tensor<16x16xi64, #blocked> %34 = tt.addptr %arg5, %33 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> %35 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked> %36 = tt.load %34, %35 : tensor<16x16x!tt.ptr, #blocked> - %37 = tt.splat %21 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %38 = tt.load %28, %37 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %37 = tt.splat %21 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>> + %38 = tt.load %28, %37 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> %39 = arith.addi %arg10, %c1_i32 : i32 %40 = arith.cmpi slt, %39, %c1_i32 : i32 %41 = arith.select %40, %39, %c0_i32 : i32 - %42 = triton_gpu.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !triton_gpu.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %43 = triton_gpu.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !triton_gpu.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, 
!triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !triton_gpu.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %42 = ttg.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> + ttg.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> + %43 = ttg.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> + ttg.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> + scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>> } - triton_gpu.local_dealloc %0 : !triton_gpu.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %1 : !triton_gpu.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %0 : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> tt.return %19#0 : tensor<16x16xf32, #mma> } } @@ -463,18 +463,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: sink_convert_dealloc -// CHECK-COUNT-2: triton_gpu.local_dealloc %{{.+}} : !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> -// CHECK: triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-COUNT-2: ttg.local_dealloc %{{.+}} : !ttg.memdesc<4x128x64xf16, #shared, mutable> +// CHECK: ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> - %1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> - %2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> - triton_gpu.local_dealloc %0 : !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> - triton_gpu.local_dealloc %1 : 
!triton_gpu.memdesc<4x128x64xf16, #shared, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable> + %2 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, mutable> %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> tt.return } @@ -485,17 +485,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // CHECK-LABEL: anchor_barrier // CHECK: gpu.barrier // CHECK: tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @anchor_barrier(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable> gpu.barrier %2 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %1 = triton_gpu.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> - triton_gpu.local_dealloc %0 : !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> - triton_gpu.local_dealloc %1 : !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> + %1 = ttg.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<4x128x64xf16, #shared, mutable> + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, mutable> tt.return } } @@ -503,13 +503,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: dont_hoist_scf_ops // Make sure we don't hoist scf ops above its dependencies. 
tt.func public @dont_hoist_scf_ops(%init: tensor<256x128xf32, #mfma>, - %base: tensor<256x128x!tt.ptr, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, - %p1: tensor<128x128x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>, %i1: i1) -> (tensor<256x128xf32, #mfma>) { + %base: tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, + %p1: tensor<128x128x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>, %i1: i1) -> (tensor<256x128xf32, #mfma>) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c4_i32 = arith.constant 4 : i32 @@ -521,16 +521,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %f = arith.addi %arg21, %c128_i32 : i32 // CHECK: scf.if // CHECK: tt.load - %p0 = scf.if %i1 -> tensor<256x128x!tt.ptr, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>{ + %p0 = scf.if %i1 -> tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>{ %t = tt.splat %f : i32 -> tensor<256x128xi32> - %padd = tt.addptr %base, %t : tensor<256x128x!tt.ptr, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, tensor<256x128xi32> - scf.yield %padd : tensor<256x128x!tt.ptr, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %padd = tt.addptr %base, %t : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, tensor<256x128xi32> + scf.yield %padd : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> } else { - scf.yield %base : tensor<256x128x!tt.ptr, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + scf.yield %base : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> } - %l = tt.load %p0 : tensor<256x128x!tt.ptr, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> - %r = tt.load %p1 : tensor<128x128x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %acc = tt.dot %l, %r, %arg : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %l = tt.load %p0 : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %r = tt.load %p1 : tensor<128x128x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %acc = tt.dot %l, %r, %arg : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> scf.yield %acc : tensor<256x128xf32, #mfma> } tt.return %54 : tensor<256x128xf32, #mfma> diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir index 09c71215f9..248a04a3c0 100644 --- a/test/TritonGPU/amd/amd-sched-2nd-load.mlir +++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -3,13 +3,13 @@ // Check the logic of sched-2nd-load optimizations // -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], 
hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough // The following tile sizes should apply the optimization @@ -27,21 +27,21 @@ // CHECK-NEXT: local_load // CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !triton_gpu.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !triton_gpu.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<128x256xf16, #shared1, #ttg.shared_memory, mutable>) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %1 = triton_gpu.local_load %A_LDS : !triton_gpu.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> - %2 = triton_gpu.local_load %B_LDS : !triton_gpu.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !triton_gpu.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !triton_gpu.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> 
!ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !ttg.memdesc<128x256xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> } tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> @@ -51,13 +51,13 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // Should apply: tile size 256x256x64 with single dot // CHECK-LABEL: sink_2nd_load_256x256x64 @@ -66,21 +66,21 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // CHECK-NEXT: local_load // CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !triton_gpu.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> - %1 = triton_gpu.local_load %A_LDS : !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> 
tensor<256x64xf16, #dotOp0> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> - %2 = triton_gpu.local_load %B_LDS : !triton_gpu.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !triton_gpu.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> } tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> @@ -90,13 +90,13 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // Should NOT apply: tile size 256x64x128 with single dot // CHECK-LABEL: sink_2nd_load_256x64x128 @@ -105,21 +105,21 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // CHECK-NEXT: local_load // CHECK-NEXT: local_load // CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !triton_gpu.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !triton_gpu.memdesc<128x64xf16, 
#shared1, #triton_gpu.shared_memory, mutable>) { +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory, mutable>) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %1 = triton_gpu.local_load %A_LDS : !triton_gpu.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> - %2 = triton_gpu.local_load %B_LDS : !triton_gpu.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> - triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !triton_gpu.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !triton_gpu.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> + ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %3 : tensor<256x64xf32, #mma> } tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> @@ -129,13 +129,13 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, 
maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // Should NOT apply: tile size 256x256x32 with single dot // CHECK-LABEL: sink_2nd_load_256x256x32 @@ -144,21 +144,21 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // CHECK-NEXT: local_load // CHECK-NEXT: local_load // CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -// CHECK-NEXT: triton_gpu.local_store %[[tileB]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !triton_gpu.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !triton_gpu.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<32x256xf16, #shared1, #ttg.shared_memory, mutable>) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> - %1 = triton_gpu.local_load %A_LDS : !triton_gpu.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> - %2 = triton_gpu.local_load %B_LDS : !triton_gpu.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<32x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !triton_gpu.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !triton_gpu.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + ttg.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !ttg.memdesc<32x256xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> } tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> @@ -168,13 +168,13 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, 
warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // Category 2: single dot with two loads and tile size is large enough (128x128x128). // We make sure the move is legal. @@ -186,20 +186,20 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // CHECK-NEXT: local_load // CHECK-NEXT: tt.store // CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store %[[tileA]] -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr, #blocked>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !triton_gpu.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !triton_gpu.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { +// CHECK-NEXT: ttg.local_store %[[tileA]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr, #blocked>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory, mutable>) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> - %1 = triton_gpu.local_load %A_LDS : !triton_gpu.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> - %2 = triton_gpu.local_load %B_LDS : !triton_gpu.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> - 
triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !triton_gpu.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> + ttg.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> scf.yield %3 : tensor<128x128xf32, #mma> } tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr, #mma> @@ -215,33 +215,33 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // CHECK-LABEL: sink_2nd_load_256x256x64_two_dot // CHECK: tt.load // CHECK-NEXT: tt.load -// CHECK-NEXT: triton_gpu.local_load -// CHECK-NEXT: triton_gpu.local_load +// CHECK-NEXT: ttg.local_load +// CHECK-NEXT: ttg.local_load // CHECK-NEXT: tt.dot // CHECK-NEXT: tt.dot -// CHECK-NEXT: triton_gpu.local_store -// CHECK-NEXT: triton_gpu.local_store -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> -#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !triton_gpu.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { +// CHECK-NEXT: ttg.local_store +// CHECK-NEXT: ttg.local_store +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> - %1 = triton_gpu.local_load %A_LDS : 
!triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> - %2 = triton_gpu.local_load %B_LDS : !triton_gpu.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> %6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !triton_gpu.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !triton_gpu.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> } tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> diff --git a/test/TritonGPU/amd/optimize-lds-usage.mlir b/test/TritonGPU/amd/optimize-lds-usage.mlir index 61d1861b29..e9d71ed908 100644 --- a/test/TritonGPU/amd/optimize-lds-usage.mlir +++ b/test/TritonGPU/amd/optimize-lds-usage.mlir @@ -4,18 +4,18 @@ // Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS // CHECK-LABEL: alloc_convert_load // CHECK-32KLIMIT-LABEL: alloc_convert_load -// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 -// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma -// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} { - %1 = triton_gpu.local_alloc 
%arg0 : (tensor<128x128xf16, #blocked>) -> !triton_gpu.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> - %2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> - %3 = triton_gpu.local_load %1 : !triton_gpu.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %2 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } @@ -26,18 +26,18 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // in case of relatively small scratch buffer // CHECK-LABEL: alloc_convert_small_load // CHECK-32KLIMIT-LABEL: alloc_convert_small_load -// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 -// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma -// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} { - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !triton_gpu.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> - %2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma> - %3 = triton_gpu.local_load %1 : !triton_gpu.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %2 = ttg.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } @@ 
-48,18 +48,18 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // in case of relatively small scratch buffer // CHECK-LABEL: alloc_convert_3d_load // CHECK-32KLIMIT-LABEL: alloc_convert_3d_load -// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma -// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#mma{{.*}}#mma1 -// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma +// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#mma{{.*}}#mma1 +// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2], hasLeadingOffset = false}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x128x128xf16, #blocked>) attributes {noinline = false} { - %1 = triton_gpu.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !triton_gpu.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory> - %2 = triton_gpu.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma> - %3 = triton_gpu.local_load %1 : !triton_gpu.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<1x128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %1 = ttg.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory> + %2 = ttg.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory> -> tensor<1x128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } @@ -68,22 +68,22 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // Check that optimization triggers with custom LDS limit and do not triggers with default one // CHECK-LABEL: alloc_convert_32k_limit -// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma -// CHECK: %2 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma +// CHECK: %2 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> // 
CHECK-32KLIMIT-LABEL: alloc_convert_32k_limit -// CHECK-32KLIMIT: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK-32KLIMIT: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 -// CHECK-32KLIMIT: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma -// CHECK-32KLIMIT: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK-32KLIMIT: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK-32KLIMIT: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK-32KLIMIT: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK-32KLIMIT: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<64x128xf16, #blocked>) attributes {noinline = false} { - %1 = triton_gpu.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !triton_gpu.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> - %2 = triton_gpu.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma> - %3 = triton_gpu.local_load %1 : !triton_gpu.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>> + %1 = ttg.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #ttg.shared_memory> + %2 = ttg.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<64x128xf16, #shared, #ttg.shared_memory> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>> tt.return } } @@ -91,30 +91,30 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // ----- // Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion) -// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}> -// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 
16, order = [0, 1], hasLeadingOffset = false}> +// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}> +// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> // CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}}) -// CHECK: [[ALLOC:%[0-9]+]] = triton_gpu.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !triton_gpu.memdesc<128x128xf16, [[SHARED]], #triton_gpu.shared_memory> -// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = triton_gpu.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]> -// CHECK: [[CONVERT_1:%[0-9]+]] = triton_gpu.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]> -// CHECK: [[CONVERT_2:%[0-9]+]] = triton_gpu.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>> -// CHECK: [[LOAD:%[0-9]+]] = triton_gpu.local_load [[ALLOC]] : !triton_gpu.memdesc<128x128xf16, [[SHARED]], #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma1 = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -#mma2 = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -#dotop1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma1, kWidth=4}> -#dotop2 = #triton_gpu.dot_op<{opIdx=0, parent=#mma2, kWidth=4}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #ttg.shared_memory> +// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]> +// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]> +// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>> +// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, 
warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#mma2 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#dotop1 = #ttg.dot_op<{opIdx=0, parent=#mma1, kWidth=4}> +#dotop2 = #ttg.dot_op<{opIdx=0, parent=#mma2, kWidth=4}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) attributes {noinline = false} { - %alloc = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !triton_gpu.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> - %convert_1 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma1> - %convert_2 = triton_gpu.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2> - %load = triton_gpu.local_load %alloc : !triton_gpu.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #dotop1> + %alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %convert_1 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma1> + %convert_2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2> + %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf16, #dotop1> tt.return } } @@ -123,17 +123,17 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // Checks that optimization do not crash on 1d tensor // CHECK-LABEL: convert_1d -// CHECK: triton_gpu.local_alloc -// CHECK-NEXT: triton_gpu.convert_layout -// CHECK-NEXT: triton_gpu.local_load -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @convert_1d(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} { - %alloc = triton_gpu.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !triton_gpu.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> - %1 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked> - %load = triton_gpu.local_load %alloc : !triton_gpu.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma> +// CHECK: ttg.local_alloc +// CHECK-NEXT: ttg.convert_layout +// CHECK-NEXT: ttg.local_load +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @convert_1d(%arg0: tensor<128xf32, #ttg.slice<{dim = 
0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} { + %alloc = ttg.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory> + %1 = ttg.convert_layout %arg0 : tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked> + %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma> tt.return } } diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index 64385d9297..70147ddfdf 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -3,16 +3,16 @@ // CHECK-LABEL: @test_canonicalize_convert_view // CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32 -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder // CHECK: tt.return %[[V]] -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { - %c = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2> + %c = ttg.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2> %r = tt.reshape %c allow_reorder : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } @@ -24,15 +24,15 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> // is an expensive view which would require moving data across threads. 
// CHECK-LABEL: @test_canonicalize_convert_expensive_view // CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32 -// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]] +// CHECK: %[[C:.+]] = ttg.convert_layout %[[ARG]] // CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder // CHECK: tt.return %[[V]] -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { - %c = triton_gpu.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2> + %c = ttg.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2> %r = tt.reshape %c allow_reorder : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } @@ -42,18 +42,18 @@ tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blo // CHECK-LABEL: @test_canonicalize_convert_histogram // CHECK-SAME: (%[[ARG:.+]]: tensor<256xi32 -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: %[[V:.+]] = tt.histogram %[[ARG]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %[[V]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes 
{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>) -> tensor<512xi32, #blocked2> { - %0 = triton_gpu.convert_layout %arg0 : tensor<256xi32, #blocked1> -> tensor<256xi32, #blocked> + %0 = ttg.convert_layout %arg0 : tensor<256xi32, #blocked1> -> tensor<256xi32, #blocked> %1 = tt.histogram %0 : tensor<256xi32, #blocked> -> tensor<512xi32, #blocked> - %2 = triton_gpu.convert_layout %1 : tensor<512xi32, #blocked> -> tensor<512xi32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<512xi32, #blocked> -> tensor<512xi32, #blocked2> tt.return %2 : tensor<512xi32, #blocked2> } } // end module @@ -62,74 +62,74 @@ tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>) // CHECK-LABEL: @test_canonicalize_convert_local_load // CHECK-NOT: gpu.barrier -// CHECK: %[[V:.+]] = triton_gpu.local_load +// CHECK: %[[V:.+]] = ttg.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: tt.return %[[V]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.compute-capability" = 80} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.compute-capability" = 80} { tt.func @test_canonicalize_convert_local_load() -> tensor<256xi32, #blocked1> { - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<256xi32, #shared, mutable> - %1 = triton_gpu.local_load %0 : !triton_gpu.memdesc<256xi32, #shared, mutable> -> tensor<256xi32, #blocked> + %0 = ttg.local_alloc : () -> !ttg.memdesc<256xi32, #shared, mutable> + %1 = ttg.local_load %0 : !ttg.memdesc<256xi32, #shared, mutable> -> tensor<256xi32, #blocked> gpu.barrier - %2 = triton_gpu.convert_layout %1 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> + %2 = ttg.convert_layout %1 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> tt.return %2 : tensor<256xi32, #blocked1> } } // end module // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module 
attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: local_alloc_nofold1 - tt.func @local_alloc_nofold1(%arg0: tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> { - // CHECK: %[[ARG:.+]] = triton_gpu.local_alloc - // CHECK-NEXT: %[[ARG2:.+]] = triton_gpu.local_load %[[ARG]] - // CHECK-NEXT: %[[ARG3:.+]] = triton_gpu.local_alloc %[[ARG2]] + tt.func @local_alloc_nofold1(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> { + // CHECK: %[[ARG:.+]] = ttg.local_alloc + // CHECK-NEXT: %[[ARG2:.+]] = ttg.local_load %[[ARG]] + // CHECK-NEXT: %[[ARG3:.+]] = ttg.local_alloc %[[ARG2]] // CHECK-NEXT: tt.return %[[ARG3]] - %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> - tt.return %2 : !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> + tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> } } // end module // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared1 = #ttg.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: local_alloc_nofold2 - tt.func @local_alloc_nofold2(%arg0: tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, #shared1, #triton_gpu.shared_memory> { - // CHECK: %[[ARG:.+]] = triton_gpu.local_alloc - // CHECK-NEXT: %[[ARG2:.+]] = triton_gpu.local_load %[[ARG]] - // CHECK-NEXT: %[[ARG3:.+]] = triton_gpu.local_alloc %[[ARG2]] + tt.func @local_alloc_nofold2(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #ttg.shared_memory> { + // CHECK: %[[ARG:.+]] = ttg.local_alloc + // CHECK-NEXT: %[[ARG2:.+]] = ttg.local_load %[[ARG]] + // CHECK-NEXT: %[[ARG3:.+]] = ttg.local_alloc %[[ARG2]] // CHECK-NEXT: tt.return %[[ARG3]] - %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, 
#shared, #triton_gpu.shared_memory> - %1 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, #shared1, #triton_gpu.shared_memory> - tt.return %2 : !triton_gpu.memdesc<16x16xf16, #shared1, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #ttg.shared_memory> + tt.return %2 : !ttg.memdesc<16x16xf16, #shared1, #ttg.shared_memory> } } // end module // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - tt.func @local_alloc_fold(%arg0: tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @local_alloc_fold(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> { // CHECK-LABEL: local_alloc_fold - // CHECK-NEXT: %[[ARG:.+]] = triton_gpu.local_alloc + // CHECK-NEXT: %[[ARG:.+]] = ttg.local_alloc // CHECK-NEXT: tt.return %[[ARG]] - %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> - %1 = triton_gpu.local_load %0 : !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> - tt.return %2 : !triton_gpu.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> + tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> } } // end module diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 5d35f43e9e..25e136514b 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -1,22 +1,22 @@ // RUN: triton-opt %s -split-input-file -tritongpu-coalesce | FileCheck %s -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], 
threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> -#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - -// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[row_layout]]> -// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]> -// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> +#slice2dim0 = #ttg.slice<{dim = 0, parent = #blocked2}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: [[row_layout:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: [[col_layout:#.*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: [[load_ptr:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[row_layout]]> +// CHECK: [[load_mask:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]> +// CHECK: [[load_other:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]> // CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] : tensor<64x64x!tt.ptr, [[row_layout]]> -// CHECK: [[store_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[col_layout]]> -// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]> -// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]> +// CHECK: [[store_ptr:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[col_layout]]> +// CHECK: [[store_val:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]> +// CHECK: [[store_mask:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]> // CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]] tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, @@ -34,7 +34,7 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %9 = triton_gpu.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> %12 = 
tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> @@ -42,7 +42,7 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> %15 = tt.broadcast %12 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %16 = tt.broadcast %14 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %17 = triton_gpu.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %17 = ttg.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %19 = tt.load %10, %cst, %cst_0 : tensor<64x64x!tt.ptr, #blocked1> tt.store %18, %19, %cst : tensor<64x64x!tt.ptr, #blocked1> @@ -53,12 +53,12 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { -// CHECK: [[NARROW_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -// CHECK: [[WIDE_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: [[NARROW_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: [[WIDE_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %0 = tt.get_program_id x : i32 @@ -87,11 +87,11 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-NOT: sizePerThread = [4] -// CHECK: #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // CHECK-NOT: sizePerThread = [4] tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 @@ -124,7 +124,7 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 // COM: Reproducer for issue #3866 // CHECK-LABEL: @test_3866 // CHECK: tt.load {{.*}} : !tt.ptr -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { +module 
attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} { tt.func public @test_3866(%arg0: !tt.ptr, %arg1: i32, %arg2: i64) { %0 = tt.make_tensor_ptr %arg0, [%arg2, %arg2], [%arg2, %arg2], [%arg1, %arg1] {order = array} : > %1 = tt.load %0 : !tt.ptr> diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 129eb8c101..c045da042f 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1,20 +1,20 @@ // RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s -#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#layout0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#layout1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#layout2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#layout3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#layout2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#layout3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { -// CHECK: [[$target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: [[$target_layout:#.*]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // CHECK-LABEL: cst tt.func @cst() -> tensor<1024xi32, #layout1> { %cst = arith.constant dense<0> : tensor<1024xi32, #layout0> - %1 = triton_gpu.convert_layout %cst : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> - // CHECK-NOT: triton_gpu.convert_layout + %1 = ttg.convert_layout %cst : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %cst : tensor<1024xi32, [[$target_layout]]> tt.return %1: tensor<1024xi32, #layout1> } @@ -22,8 +22,8 @@ tt.func @cst() -> tensor<1024xi32, #layout1> { // CHECK-LABEL: range tt.func @range() -> tensor<1024xi32, #layout1> { %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> - %1 = triton_gpu.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> - // CHECK-NOT: triton_gpu.convert_layout + %1 = ttg.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]> tt.return %1: tensor<1024xi32, #layout1> } @@ -31,8 +31,8 @@ tt.func @range() -> tensor<1024xi32, #layout1> { // CHECK-LABEL: splat tt.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { %0 = tt.splat %arg0 : i32 -> tensor<1024xi32, #layout0> - %1 = triton_gpu.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> - // CHECK-NOT: triton_gpu.convert_layout + %1 = ttg.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]> tt.return %1: tensor<1024xi32, #layout1> } @@ -42,9 +42,9 @@ 
tt.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> %2 = arith.muli %0, %1 : tensor<1024xi32, #layout0> - %3 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %3 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> %4 = tt.splat %arg0 : i32 -> tensor<1024xi32, #layout0> - %5 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %5 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> %6 = arith.addi %3, %5 : tensor<1024xi32, #layout1> tt.return %6: tensor<1024xi32, #layout1> // CHECK: %[[A:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> @@ -59,9 +59,9 @@ tt.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { tt.func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { %0 = tt.splat %arg : !tt.ptr -> tensor<1x!tt.ptr, #layout1> %1 = tt.load %0 : tensor<1x!tt.ptr, #layout1> - // CHECK-NOT: triton_gpu.convert_layout - %2 = triton_gpu.convert_layout %1 : tensor<1xi32, #layout1> -> tensor<1xi32, #layout0> - %3 = triton_gpu.convert_layout %0 : tensor<1x!tt.ptr, #layout1> -> tensor<1x!tt.ptr, #layout0> + // CHECK-NOT: ttg.convert_layout + %2 = ttg.convert_layout %1 : tensor<1xi32, #layout1> -> tensor<1xi32, #layout0> + %3 = ttg.convert_layout %0 : tensor<1x!tt.ptr, #layout1> -> tensor<1x!tt.ptr, #layout0> tt.store %3, %2 : tensor<1x!tt.ptr, #layout0> tt.return } @@ -72,9 +72,9 @@ tt.func @remat_fast_load(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1> %2 = tt.addptr %0, %1 : tensor<16x!tt.ptr, #layout1>, tensor<16xi32, #layout1> %3 = tt.load %2 : tensor<16x!tt.ptr, #layout1> - // CHECK-NOT: triton_gpu.convert_layout - %4 = triton_gpu.convert_layout %3 : tensor<16xi32, #layout1> -> tensor<16xi32, #layout0> - %5 = triton_gpu.convert_layout %2 : tensor<16x!tt.ptr, #layout1> -> tensor<16x!tt.ptr, #layout0> + // CHECK-NOT: ttg.convert_layout + %4 = ttg.convert_layout %3 : tensor<16xi32, #layout1> -> tensor<16xi32, #layout0> + %5 = ttg.convert_layout %2 : tensor<16x!tt.ptr, #layout1> -> tensor<16x!tt.ptr, #layout0> tt.store %5, %4 : tensor<16x!tt.ptr, #layout0> tt.return } @@ -82,71 +82,71 @@ tt.func @remat_fast_load(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { // Hoist the convert on top of ext to make it cheaper. 
// CHECK-LABEL: hoist_above_ext tt.func @hoist_above_ext(%arg0: tensor<1024xf16, #layout0>, %arg1: f32) -> tensor<1024xf32, #layout1> { -// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: %[[CVT:.+]] = ttg.convert_layout // CHECK: arith.extf %[[CVT]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> %1 = tt.splat %arg1 : f32 -> tensor<1024xf32, #layout0> %2 = arith.addf %0, %1 : tensor<1024xf32, #layout0> - %3 = triton_gpu.convert_layout %2 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> + %3 = ttg.convert_layout %2 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> tt.return %3 : tensor<1024xf32, #layout1> } // CHECK-LABEL: hoist_above_ext2 tt.func @hoist_above_ext2(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tensor<1024xf32, #layout1> { -// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: %[[CVT:.+]] = ttg.convert_layout // CHECK: arith.extf %[[CVT]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> %1 = tt.splat %arg1 : f16 -> tensor<1024xf16, #layout0> %2 = arith.extf %1 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> %3 = arith.addf %0, %2 : tensor<1024xf32, #layout0> - %4 = triton_gpu.convert_layout %3 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> + %4 = ttg.convert_layout %3 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> tt.return %4 : tensor<1024xf32, #layout1> } /// CHECK-LABEL: hoist_above_fptofp tt.func @hoist_above_fptofp(%arg0: tensor<1024xf8E4M3FNUZ, #layout0>) -> tensor<1024xf32, #layout1> { -// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: %[[CVT:.+]] = ttg.convert_layout // CHECK: tt.fp_to_fp %[[CVT]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf8E4M3FNUZ, #layout0> -> tensor<1024xf32, #layout0> - %1 = triton_gpu.convert_layout %0 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> + %1 = ttg.convert_layout %0 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> tt.return %1 : tensor<1024xf32, #layout1> } /// CHECK-LABEL: dont_hoist_above_trunc_fptofp tt.func @dont_hoist_above_trunc_fptofp(%arg0: tensor<1024xf32, #layout0>) -> tensor<1024xf8E4M3FNUZ, #layout1> { -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: %[[FP8:.+]] = tt.fp_to_fp -// CHECK: triton_gpu.convert_layout %[[FP8]] +// CHECK: ttg.convert_layout %[[FP8]] // CHECK: tt.return %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf32, #layout0> -> tensor<1024xf8E4M3FNUZ, #layout0> - %1 = triton_gpu.convert_layout %0 : tensor<1024xf8E4M3FNUZ, #layout0> -> tensor<1024xf8E4M3FNUZ, #layout1> + %1 = ttg.convert_layout %0 : tensor<1024xf8E4M3FNUZ, #layout0> -> tensor<1024xf8E4M3FNUZ, #layout1> tt.return %1 : tensor<1024xf8E4M3FNUZ, #layout1> } // Hoist the convert on top of broadcast to make it cheaper. 
// CHECK-LABEL: hoist_above_broadcast tt.func @hoist_above_broadcast(%arg0: tensor<1024x1xf32, #layout2>, %arg1: f32) -> tensor<1024x128xf32, #layout3> { -// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: %[[CVT:.+]] = ttg.convert_layout // CHECK: tt.broadcast %[[CVT]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 = tt.broadcast %arg0 : tensor<1024x1xf32, #layout2> -> tensor<1024x128xf32, #layout2> %1 = tt.splat %arg1 : f32 -> tensor<1024x128xf32, #layout2> %2 = arith.addf %0, %1 : tensor<1024x128xf32, #layout2> - %3 = triton_gpu.convert_layout %2 : tensor<1024x128xf32, #layout2> -> tensor<1024x128xf32, #layout3> + %3 = ttg.convert_layout %2 : tensor<1024x128xf32, #layout2> -> tensor<1024x128xf32, #layout3> tt.return %3 : tensor<1024x128xf32, #layout3> } // CHECK-LABEL: if tt.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1> %0 = tt.get_program_id x : i32 %1 = tt.splat %0 : i32 -> tensor<1024xi32, #layout1> @@ -155,7 +155,7 @@ tt.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { %4 = arith.cmpi sgt, %0, %arg0 : i32 %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout0> scf.if %4 { - %6 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout1> -> tensor<1024xi32, #layout0> + %6 = ttg.convert_layout %2 : tensor<1024xi32, #layout1> -> tensor<1024xi32, #layout0> tt.store %5, %6 : tensor<1024x!tt.ptr, #layout0> } tt.return @@ -172,12 +172,12 @@ tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = %4 = arith.cmpi sgt, %0, %arg0 : i32 %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout1> %8 = scf.if %4 -> tensor<1024xi32, #layout1> { - %6 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %6 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> scf.yield %6 : tensor<1024xi32, #layout1> } else { scf.yield %9 : tensor<1024xi32, #layout1> } - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout tt.store %5, %8 : tensor<1024x!tt.ptr, #layout1> tt.return } @@ -195,10 +195,10 @@ tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = %8 = scf.if %4 -> tensor<1024xi32, #layout1> { scf.yield %9 : tensor<1024xi32, #layout1> } else { - %7 = triton_gpu.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %7 = ttg.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> scf.yield %7 : tensor<1024xi32, #layout1> } - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout tt.store %5, %8 : tensor<1024x!tt.ptr, #layout1> tt.return } @@ -213,15 +213,15 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = %4 = arith.cmpi sgt, %0, %arg0 : i32 %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout1> %8 = scf.if %4 -> tensor<1024xi32, #layout1> { - %6 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %6 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> scf.yield %6 : tensor<1024xi32, #layout1> } else { - %7 = triton_gpu.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %7 = ttg.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> scf.yield %7 : tensor<1024xi32, #layout1> } // TODO(csigg): seems like the 
whole function is converted to layout1. - // disabledCHECK: triton_gpu.convert_layout - // CHECK-NOT: triton_gpu.convert_layout + // disabledCHECK: ttg.convert_layout + // CHECK-NOT: ttg.convert_layout tt.store %5, %8 : tensor<1024x!tt.ptr, #layout1> tt.return } @@ -230,27 +230,27 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked0a = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked2a = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> - -// CHECK-DAG: [[$row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -// CHECK-DAG: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> -// CHECK-DAG: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked0a = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2a = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#slice2dim0 = #ttg.slice<{dim = 0, parent = #blocked2}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked5 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +// CHECK-DAG: [[$row_layout:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +// CHECK-DAG: [[$col_layout:#.*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +// CHECK-DAG: [[$col_layout_novec:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> // CHECK-LABEL: @transpose -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 
16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]> - // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout]]> + // CHECK: [[cvt_val:%.*]] = ttg.convert_layout [[loaded_val]] : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout]]> // CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64x!tt.ptr, [[$col_layout]]> // CHECK: tt.return %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> @@ -265,7 +265,7 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %9 = triton_gpu.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> @@ -273,34 +273,34 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> %15 = tt.broadcast %12 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %16 = tt.broadcast %14 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %17 = triton_gpu.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %17 = ttg.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> - %19 = triton_gpu.convert_layout %10 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> - %20 = triton_gpu.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> - %21 = triton_gpu.convert_layout %cst : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> + %19 = ttg.convert_layout %10 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> + %20 = ttg.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> + %21 = ttg.convert_layout %cst : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> %22 = tt.load %19, %20, %21 : tensor<64x64x!tt.ptr, #blocked3> - %23 = triton_gpu.convert_layout %22 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> - %24 = triton_gpu.convert_layout %18 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked4> - %25 = triton_gpu.convert_layout %23 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked4> - %26 = triton_gpu.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked4> + %23 = ttg.convert_layout %22 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> + %24 = ttg.convert_layout %18 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked4> + %25 = ttg.convert_layout %23 : tensor<64x64xf32, #blocked1> 
-> tensor<64x64xf32, #blocked4> + %26 = ttg.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked4> tt.store %24, %25, %26 : tensor<64x64x!tt.ptr, #blocked4> tt.return } } // CHECK-LABEL: loop -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]>) // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]> // CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[$row_layout]]> // CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]> // CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]> // CHECK-NEXT: } - // CHECK-NOT: triton_gpu.convert_layout - // CHECK: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout_novec]]> - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout + // CHECK: {{.*}} = ttg.convert_layout [[loop_ret]]#0 : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout_novec]]> + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %cst = arith.constant dense : tensor<64x64xi1, #blocked1> %cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1> @@ -318,14 +318,14 @@ tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %9 = triton_gpu.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1>) { - %23 = triton_gpu.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> - %24 = triton_gpu.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> - %25 = triton_gpu.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> + %23 = ttg.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> + %24 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> + %25 = ttg.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> %26 = tt.load %23, %24, %25 : tensor<64x64x!tt.ptr, #blocked3> - %27 = triton_gpu.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> + %27 = ttg.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> %28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1> %29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, 
#blocked1> @@ -336,31 +336,31 @@ tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2> %16 = tt.broadcast %13 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %17 = tt.broadcast %15 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %18 = triton_gpu.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %18 = ttg.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> - %20 = triton_gpu.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> - %21 = triton_gpu.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> - %22 = triton_gpu.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> + %20 = ttg.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %21 = ttg.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> + %22 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> tt.store %20, %21, %22 : tensor<64x64x!tt.ptr, #blocked1> tt.return } } // CHECK-LABEL: loop_if -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.for -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.if -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.yield // CHECK: else // CHECK: scf.yield -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.yield -// CHECK: triton_gpu.convert_layout -// CHECK-NOT: triton_gpu.convert_layout +// CHECK: ttg.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.store -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { %cst = arith.constant dense : tensor<64x64xi1, #blocked1> %cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1> @@ -379,16 +379,16 @@ tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i3 %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %9 = triton_gpu.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1>) { %33 = arith.cmpi "sgt", %arg5, %c0 : index %34 = scf.if %33 -> (tensor<64x64xf32, #blocked1>) { - %23 = triton_gpu.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> - %24 = triton_gpu.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> - %25 = triton_gpu.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> + %23 = ttg.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> 
tensor<64x64x!tt.ptr, #blocked3> + %24 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> + %25 = ttg.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> %26 = tt.load %23, %24, %25 : tensor<64x64x!tt.ptr, #blocked3> - %27 = triton_gpu.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> + %27 = ttg.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> scf.yield %27 : tensor<64x64xf32, #blocked1> } else { scf.yield %arg6 : tensor<64x64xf32, #blocked1> @@ -403,20 +403,20 @@ tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i3 %15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2> %16 = tt.broadcast %13 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %17 = tt.broadcast %15 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %18 = triton_gpu.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %18 = ttg.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> - %20 = triton_gpu.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> - %21 = triton_gpu.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> - %22 = triton_gpu.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> + %20 = ttg.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %21 = ttg.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> + %22 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> tt.store %20, %21, %22 : tensor<64x64x!tt.ptr, #blocked1> tt.return } } // CHECK-LABEL: vecadd -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %c256_i32 = arith.constant 256 : i32 %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c256_i32 : i32 @@ -432,15 +432,15 @@ tt.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %11 = arith.addi %4, %5 : tensor<256xi32, #blocked5> %12 = tt.addptr %8, %9 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> %13 = tt.load %12 : tensor<256x!tt.ptr, #blocked5> - %14 = triton_gpu.convert_layout %13 : tensor<256xf32, #blocked5> -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %14 = ttg.convert_layout %13 : tensor<256xf32, #blocked5> -> tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %15 = tt.addptr %10, %11 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> %16 = tt.load %15 : tensor<256x!tt.ptr, #blocked5> - %17 = triton_gpu.convert_layout %16 : tensor<256xf32, #blocked5> -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %17 = ttg.convert_layout %16 : 
tensor<256xf32, #blocked5> -> tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %18 = arith.addf %14, %17 : tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %19 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked5> %20 = arith.addi %2, %3 : tensor<256xi32, #blocked5> %21 = tt.addptr %19, %20 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> - %22 = triton_gpu.convert_layout %18 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> -> tensor<256xf32, #blocked5> + %22 = ttg.convert_layout %18 : tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> -> tensor<256xf32, #blocked5> tt.store %21, %22 : tensor<256x!tt.ptr, #blocked5> tt.return } @@ -448,9 +448,9 @@ tt.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr // Select has args with different element types // CHECK-LABEL: select -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2> %cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2> %c512 = arith.constant 512 : index @@ -460,15 +460,15 @@ tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2> %0 = tt.get_program_id x : i32 %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked0> - %2 = triton_gpu.convert_layout %1 : tensor<1xi32, #blocked0> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1xi32, #blocked1> - %4 = triton_gpu.convert_layout %3 : tensor<1x1xi32, #blocked1> -> tensor<1x1xi32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<1xi32, #blocked0> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1xi32, #blocked1> + %4 = ttg.convert_layout %3 : tensor<1x1xi32, #blocked1> -> tensor<1x1xi32, #blocked2> %5 = tt.splat %0 : i32 -> tensor<1x1xi32, #blocked2> %6 = arith.addi %5, %4 : tensor<1x1xi32, #blocked2> %7 = arith.cmpi "slt", %6, %cst_1 : tensor<1x1xi32, #blocked2> %8 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked0> - %9 = triton_gpu.convert_layout %8 : tensor<512xi32, #blocked0> -> tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x512xi32, #blocked2> + %9 = ttg.convert_layout %8 : tensor<512xi32, #blocked0> -> tensor<512xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<512xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x512xi32, #blocked2> %11 = arith.muli %6, %cst : tensor<1x1xi32, #blocked2> %12 = tt.broadcast %11 : tensor<1x1xi32, #blocked2> -> tensor<1x512xi32, #blocked2> %13 = tt.splat %arg0 : 
!tt.ptr -> tensor<1x512x!tt.ptr, #blocked2> @@ -481,17 +481,17 @@ tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %20 = arith.addi %18, %12 : tensor<1x512xi32, #blocked2> %21 = tt.addptr %13, %20 : tensor<1x512x!tt.ptr, #blocked2>, tensor<1x512xi32, #blocked2> %22 = arith.andi %19, %14 : tensor<1x512xi1, #blocked2> - %23 = triton_gpu.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> - %24 = triton_gpu.convert_layout %22 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> + %23 = ttg.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> + %24 = ttg.convert_layout %22 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> %25 = tt.load %23, %24 : tensor<1x512x!tt.ptr, #blocked3> - %26 = triton_gpu.convert_layout %25 : tensor<1x512xf64, #blocked3> -> tensor<1x512xf64, #blocked2> + %26 = ttg.convert_layout %25 : tensor<1x512xf64, #blocked3> -> tensor<1x512xf64, #blocked2> %27 = arith.andi %14, %19 : tensor<1x512xi1, #blocked2> %28 = arith.cmpf "olt", %arg4, %26 : tensor<1x512xf64, #blocked2> %29 = arith.andi %27, %28 : tensor<1x512xi1, #blocked2> %30 = arith.select %29, %26, %arg4 : tensor<1x512xi1, #blocked2>, tensor<1x512xf64, #blocked2> - %31 = triton_gpu.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> - %32 = triton_gpu.convert_layout %30 : tensor<1x512xf64, #blocked2> -> tensor<1x512xf64, #blocked3> - %33 = triton_gpu.convert_layout %27 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> + %31 = ttg.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> + %32 = ttg.convert_layout %30 : tensor<1x512xf64, #blocked2> -> tensor<1x512xf64, #blocked3> + %33 = ttg.convert_layout %27 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> tt.store %31, %32, %33 : tensor<1x512x!tt.ptr, #blocked3> scf.yield %30 : tensor<1x512xf64, #blocked2> } @@ -501,7 +501,7 @@ tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr // Make sure the following IR doesn't hang the compiler. 
// CHECK-LABEL: long_func -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: !tt.ptr {tt.divisibility = 16 : i32}, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: !tt.ptr {tt.divisibility = 16 : i32}, %arg15: !tt.ptr {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) { %cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0> %cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0> @@ -529,22 +529,22 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %5 = arith.cmpi "slt", %4, %cst_11 : tensor<1024xi32, #blocked0> %6 = tt.splat %arg5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %8 = triton_gpu.convert_layout %7 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %9 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %8 = ttg.convert_layout %7 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %9 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> %10 = tt.load %8, %9 : tensor<1024x!tt.ptr, #blocked0a> - %11 = triton_gpu.convert_layout %10 : tensor<1024xf32, #blocked0a> -> tensor<1024xf32, #blocked0> + %11 = ttg.convert_layout %10 : tensor<1024xf32, #blocked0a> -> tensor<1024xf32, #blocked0> %12 = tt.splat %arg7 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %13 = tt.addptr %12, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %14 = triton_gpu.convert_layout %13 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked2a> - %15 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked2a> + %14 = ttg.convert_layout %13 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked2a> + %15 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked2a> %16 = tt.load %14, %15 : tensor<1024x!tt.ptr, #blocked2a> - %17 = triton_gpu.convert_layout %16 : tensor<1024xi64, #blocked2a> -> tensor<1024xi64, #blocked0> + %17 = ttg.convert_layout %16 : tensor<1024xi64, #blocked2a> -> tensor<1024xi64, #blocked0> %18 = tt.splat %arg8 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %19 = tt.addptr %18, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %20 = triton_gpu.convert_layout %19 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %21 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %20 = ttg.convert_layout %19 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %21 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> %22 = tt.load %20, %21 : tensor<1024x!tt.ptr, #blocked0a> - %23 = triton_gpu.convert_layout %22 : tensor<1024xf32, 
#blocked0a> -> tensor<1024xf32, #blocked0> + %23 = ttg.convert_layout %22 : tensor<1024xf32, #blocked0a> -> tensor<1024xf32, #blocked0> %24 = arith.subf %cst_13, %11 : tensor<1024xf32, #blocked0> %25 = math.exp %24 : tensor<1024xf32, #blocked0> %26 = arith.sitofp %cst_12 : tensor<1024xi32, #blocked0> to tensor<1024xf32, #blocked0> @@ -575,7 +575,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %51 = arith.select %50, %49, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %52 = tt.splat %arg6 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %53 = tt.addptr %52, %51 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %54 = triton_gpu.convert_layout %53 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %54 = ttg.convert_layout %53 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %55 = tt.load %54 : tensor<1024x!tt.ptr, #blocked0> %56 = arith.cmpf "oge", %55, %35 :tensor<1024xf32, #blocked0> %57 = arith.cmpi "eq", %56, %cst_5 : tensor<1024xi1, #blocked0> @@ -597,7 +597,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %73 = arith.addi %60, %72 : tensor<1024xi32, #blocked0> %74 = arith.select %63, %73, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %75 = tt.addptr %52, %74 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %76 = triton_gpu.convert_layout %75 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %76 = ttg.convert_layout %75 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %77 = tt.load %76 : tensor<1024x!tt.ptr, #blocked0> %78 = arith.cmpf "oge", %77, %35 :tensor<1024xf32, #blocked0> %79 = arith.cmpi "eq", %78, %cst_5 : tensor<1024xi1, #blocked0> @@ -619,7 +619,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %95 = arith.addi %82, %94 : tensor<1024xi32, #blocked0> %96 = arith.select %85, %95, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %97 = tt.addptr %52, %96 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %98 = triton_gpu.convert_layout %97 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %98 = ttg.convert_layout %97 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %99 = tt.load %98 : tensor<1024x!tt.ptr, #blocked0> %100 = arith.cmpf "oge", %99, %35 : tensor<1024xf32, #blocked0> %101 = arith.cmpi "eq", %100, %cst_5 : tensor<1024xi1, #blocked0> @@ -641,7 +641,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %117 = arith.addi %104, %116 : tensor<1024xi32, #blocked0> %118 = arith.select %107, %117, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %119 = tt.addptr %52, %118 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %120 = triton_gpu.convert_layout %119 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %120 = ttg.convert_layout %119 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %121 = tt.load %120 : tensor<1024x!tt.ptr, #blocked0> %122 = arith.cmpf "oge", %121, %35 : tensor<1024xf32, #blocked0> %123 = arith.cmpi "eq", %122, %cst_5 : tensor<1024xi1, #blocked0> @@ -663,7 +663,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %139 = arith.addi %126, %138 : tensor<1024xi32, #blocked0> %140 = arith.select %129, %139, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %141 = tt.addptr %52, %140 : tensor<1024x!tt.ptr, 
#blocked0>, tensor<1024xi32, #blocked0> - %142 = triton_gpu.convert_layout %141 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %142 = ttg.convert_layout %141 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %143 = tt.load %142 : tensor<1024x!tt.ptr, #blocked0> %144 = arith.cmpf "oge", %143, %35 : tensor<1024xf32, #blocked0> %145 = arith.cmpi "eq", %144, %cst_5 : tensor<1024xi1, #blocked0> @@ -685,7 +685,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %161 = arith.addi %148, %160 : tensor<1024xi32, #blocked0> %162 = arith.select %151, %161, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %163 = tt.addptr %52, %162 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %164 = triton_gpu.convert_layout %163 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %164 = ttg.convert_layout %163 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %165 = tt.load %164 : tensor<1024x!tt.ptr, #blocked0> %166 = arith.cmpf "oge", %165, %35 : tensor<1024xf32, #blocked0> %167 = arith.cmpi "eq", %166, %cst_5 : tensor<1024xi1, #blocked0> @@ -707,7 +707,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %183 = arith.addi %170, %182 : tensor<1024xi32, #blocked0> %184 = arith.select %173, %183, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %185 = tt.addptr %52, %184 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %186 = triton_gpu.convert_layout %185 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %186 = ttg.convert_layout %185 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %187 = tt.load %186 : tensor<1024x!tt.ptr, #blocked0> %188 = arith.cmpf "oge", %187, %35 : tensor<1024xf32, #blocked0> %189 = arith.cmpi "eq", %188, %cst_5 : tensor<1024xi1, #blocked0> @@ -729,7 +729,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %205 = arith.addi %192, %204 : tensor<1024xi32, #blocked0> %206 = arith.select %195, %205, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %207 = tt.addptr %52, %206 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %208 = triton_gpu.convert_layout %207 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %208 = ttg.convert_layout %207 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %209 = tt.load %208 : tensor<1024x!tt.ptr, #blocked0> %210 = arith.cmpf "oge", %209, %35 :tensor<1024xf32, #blocked0> %211 = arith.cmpi "eq", %210, %cst_5 : tensor<1024xi1, #blocked0> @@ -751,7 +751,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %227 = arith.addi %214, %226 : tensor<1024xi32, #blocked0> %228 = arith.select %217, %227, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %229 = tt.addptr %52, %228 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %230 = triton_gpu.convert_layout %229 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %230 = ttg.convert_layout %229 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %231 = tt.load %230 : tensor<1024x!tt.ptr, #blocked0> %232 = arith.cmpf "oge", %231, %35 : tensor<1024xf32, #blocked0> %233 = arith.cmpi "eq", %232, %cst_5 : tensor<1024xi1, #blocked0> @@ -773,7 +773,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %249 = arith.addi %236, %248 : tensor<1024xi32, #blocked0> %250 = 
arith.select %239, %249, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %251 = tt.addptr %52, %250 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %252 = triton_gpu.convert_layout %251 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %252 = ttg.convert_layout %251 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %253 = tt.load %252 : tensor<1024x!tt.ptr, #blocked0> %254 = arith.cmpf "oge", %253, %35 : tensor<1024xf32, #blocked0> %255 = arith.cmpi "eq", %254, %cst_5 : tensor<1024xi1, #blocked0> @@ -795,7 +795,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %271 = arith.addi %258, %270 : tensor<1024xi32, #blocked0> %272 = arith.select %261, %271, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %273 = tt.addptr %52, %272 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %274 = triton_gpu.convert_layout %273 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %274 = ttg.convert_layout %273 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %275 = tt.load %274 : tensor<1024x!tt.ptr, #blocked0> %276 = arith.cmpf "oge", %275, %35 : tensor<1024xf32, #blocked0> %277 = arith.cmpi "eq", %276, %cst_5 : tensor<1024xi1, #blocked0> @@ -817,7 +817,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %293 = arith.addi %280, %292 : tensor<1024xi32, #blocked0> %294 = arith.select %283, %293, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %295 = tt.addptr %52, %294 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %296 = triton_gpu.convert_layout %295 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %296 = ttg.convert_layout %295 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %297 = tt.load %296 : tensor<1024x!tt.ptr, #blocked0> %298 = arith.cmpf "oge", %297, %35 :tensor<1024xf32, #blocked0> %299 = arith.cmpi "eq", %298, %cst_5 : tensor<1024xi1, #blocked0> @@ -842,13 +842,13 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %318 = arith.extsi %317 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> %319 = tt.splat %arg9 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %320 = tt.addptr %319, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> - %321 = triton_gpu.convert_layout %320 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %321 = ttg.convert_layout %320 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %322 = tt.load %321 : tensor<1024x!tt.ptr, #blocked0> %323 = arith.extf %cst_2 : tensor<1024xf32, #blocked0> to tensor<1024xf64, #blocked0> %324 = arith.cmpf "ogt", %322, %323 : tensor<1024xf64, #blocked0> %325 = tt.splat %arg10 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %326 = tt.addptr %325, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> - %327 = triton_gpu.convert_layout %326 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %327 = ttg.convert_layout %326 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %328 = tt.load %327 : tensor<1024x!tt.ptr, #blocked0> %329 = arith.divf %328, %322 : tensor<1024xf64, #blocked0> %330 = arith.truncf %329 : tensor<1024xf64, #blocked0> to tensor<1024xf32, #blocked0> @@ -857,41 +857,41 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %333 = arith.addf %331, %332 : tensor<1024xf32, #blocked0> %334 = 
arith.select %324, %333, %35 : tensor<1024xi1, #blocked0>, tensor<1024xf32, #blocked0> %335 = tt.addptr %319, %317 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %336 = triton_gpu.convert_layout %335 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %336 = ttg.convert_layout %335 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %337 = tt.load %336 : tensor<1024x!tt.ptr, #blocked0> %338 = arith.extf %cst : tensor<1024xf32, #blocked0> to tensor<1024xf64, #blocked0> %339 = arith.mulf %337, %338 : tensor<1024xf64, #blocked0> %340 = tt.addptr %325, %317 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %341 = triton_gpu.convert_layout %340 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %341 = ttg.convert_layout %340 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %342 = tt.load %341 : tensor<1024x!tt.ptr, #blocked0> %343 = arith.mulf %342, %338 : tensor<1024xf64, #blocked0> %344 = tt.splat %arg11 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %345 = tt.addptr %344, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %346 = triton_gpu.convert_layout %345 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %347 = triton_gpu.convert_layout %28 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> - %348 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %346 = ttg.convert_layout %345 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %347 = ttg.convert_layout %28 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> + %348 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> tt.store %346, %347, %348 : tensor<1024x!tt.ptr, #blocked0a> %349 = tt.splat %arg12 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %350 = tt.addptr %349, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %351 = triton_gpu.convert_layout %350 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %352 = triton_gpu.convert_layout %317 : tensor<1024xi32, #blocked0> -> tensor<1024xi32, #blocked0a> - %353 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %351 = ttg.convert_layout %350 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %352 = ttg.convert_layout %317 : tensor<1024xi32, #blocked0> -> tensor<1024xi32, #blocked0a> + %353 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> tt.store %351, %352, %353 : tensor<1024x!tt.ptr, #blocked0a> %354 = tt.splat %arg13 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %355 = tt.addptr %354, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %356 = triton_gpu.convert_layout %355 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %357 = triton_gpu.convert_layout %334 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> - %358 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %356 = ttg.convert_layout %355 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %357 = ttg.convert_layout %334 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> + %358 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> tt.store %356, %357, %358 : tensor<1024x!tt.ptr, #blocked0a> %359 = tt.splat %arg14 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %360 = tt.addptr %359, %318 : 
tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> - %361 = triton_gpu.convert_layout %360 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> - %362 = triton_gpu.convert_layout %339 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> + %361 = ttg.convert_layout %360 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %362 = ttg.convert_layout %339 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> tt.store %361, %362 : tensor<1024x!tt.ptr, #blocked0> %363 = tt.splat %arg15 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %364 = tt.addptr %363, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> - %365 = triton_gpu.convert_layout %364 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> - %366 = triton_gpu.convert_layout %343 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> + %365 = ttg.convert_layout %364 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %366 = ttg.convert_layout %343 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> tt.store %365, %366 : tensor<1024x!tt.ptr, #blocked0> tt.return } @@ -900,9 +900,9 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg // A mnist model from torch inductor. // Check if topological sort is working correct and there's no unnecessary convert // CHECK-LABEL: mnist -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2> %cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3> %c16_i32 = arith.constant 16 : i32 @@ -913,30 +913,30 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! 
%0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c16_i32 : i32 %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked0> - %3 = triton_gpu.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<16x1xi32, #blocked1> -> tensor<16x1xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<16x1xi32, #blocked1> -> tensor<16x1xi32, #blocked2> %6 = tt.splat %1 : i32 -> tensor<16x1xi32, #blocked2> %7 = arith.addi %6, %5 : tensor<16x1xi32, #blocked2> %8 = arith.cmpi "slt", %7, %cst_1 : tensor<16x1xi32, #blocked2> - %9 = triton_gpu.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> + %9 = ttg.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> %11 = arith.cmpi "slt", %10, %cst_0 : tensor<1x16xi32, #blocked3> %12 = arith.muli %7, %cst : tensor<16x1xi32, #blocked2> %13 = tt.broadcast %10 : tensor<1x16xi32, #blocked3> -> tensor<16x16xi32, #blocked3> - %14 = triton_gpu.convert_layout %13 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked2> + %14 = ttg.convert_layout %13 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked2> %15 = tt.broadcast %12 : tensor<16x1xi32, #blocked2> -> tensor<16x16xi32, #blocked2> %16 = arith.addi %14, %15 : tensor<16x16xi32, #blocked2> %17 = tt.splat %arg0 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked2> %18 = tt.addptr %17, %16 : tensor<16x16x!tt.ptr, #blocked2>, tensor<16x16xi32, #blocked2> %19 = tt.broadcast %11 : tensor<1x16xi1, #blocked3> -> tensor<16x16xi1, #blocked3> - %20 = triton_gpu.convert_layout %19 : tensor<16x16xi1, #blocked3> -> tensor<16x16xi1, #blocked2> + %20 = ttg.convert_layout %19 : tensor<16x16xi1, #blocked3> -> tensor<16x16xi1, #blocked2> %21 = tt.broadcast %8 : tensor<16x1xi1, #blocked2> -> tensor<16x16xi1, #blocked2> %22 = arith.andi %20, %21 : tensor<16x16xi1, #blocked2> - %23 = triton_gpu.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> - %24 = triton_gpu.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + %23 = ttg.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %24 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> %25 = tt.load %23, %24 : tensor<16x16x!tt.ptr, #blocked4> - %26 = triton_gpu.convert_layout %25 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> + %26 = ttg.convert_layout %25 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> %27 = arith.cmpf "olt", %cst_2, %26 : tensor<16x16xf32, #blocked2> %28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2> %29 = arith.select %28, %26, %cst_2 : tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2> @@ -944,17 
+944,17 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! ^bb0(%arg4: f32, %arg5: f32): %max = arith.maximumf %arg4, %arg5 : f32 tt.reduce.return %max : f32 - }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %31 = triton_gpu.convert_layout %30 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> - %32 = triton_gpu.convert_layout %31 : tensor<16xf32, #blocked0> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> - %34 = triton_gpu.convert_layout %33 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> + }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %31 = ttg.convert_layout %30 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> + %32 = ttg.convert_layout %31 : tensor<16xf32, #blocked0> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> + %34 = ttg.convert_layout %33 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> %35 = arith.sitofp %cst_4 : tensor<16x16xi32, #blocked2> to tensor<16x16xf32, #blocked2> %36 = arith.addf %35, %cst_3 : tensor<16x16xf32, #blocked2> - %37 = triton_gpu.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> - %38 = triton_gpu.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + %37 = ttg.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %38 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> %39 = tt.load %37, %38 : tensor<16x16x!tt.ptr, #blocked4> - %40 = triton_gpu.convert_layout %39 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> + %40 = ttg.convert_layout %39 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> %41 = tt.broadcast %34 : tensor<16x1xf32, #blocked2> -> tensor<16x16xf32, #blocked2> %42 = arith.subf %40, %41 : tensor<16x16xf32, #blocked2> %43 = math.exp %42 : tensor<16x16xf32, #blocked2> @@ -964,24 +964,24 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! 
^bb0(%arg4: f32, %arg5: f32): %add = arith.addf %arg4, %arg5 : f32 tt.reduce.return %add : f32 - }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %47 = triton_gpu.convert_layout %46 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> - %48 = triton_gpu.convert_layout %47 : tensor<16xf32, #blocked0> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %49 = tt.expand_dims %48 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> - %50 = triton_gpu.convert_layout %49 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> - %51 = triton_gpu.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> - %52 = triton_gpu.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %47 = ttg.convert_layout %46 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> + %48 = ttg.convert_layout %47 : tensor<16xf32, #blocked0> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %49 = tt.expand_dims %48 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> + %50 = ttg.convert_layout %49 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> + %51 = ttg.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %52 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> %53 = tt.load %51, %52 : tensor<16x16x!tt.ptr, #blocked4> - %54 = triton_gpu.convert_layout %53 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> + %54 = ttg.convert_layout %53 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> %55 = arith.subf %54, %41 : tensor<16x16xf32, #blocked2> %56 = math.log %50 : tensor<16x1xf32, #blocked2> %57 = tt.broadcast %56 : tensor<16x1xf32, #blocked2> -> tensor<16x16xf32, #blocked2> %58 = arith.subf %55, %57 : tensor<16x16xf32, #blocked2> %59 = tt.splat %arg1 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked2> %60 = tt.addptr %59, %16 : tensor<16x16x!tt.ptr, #blocked2>, tensor<16x16xi32, #blocked2> - %61 = triton_gpu.convert_layout %60 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> - %62 = triton_gpu.convert_layout %58 : tensor<16x16xf32, #blocked2> -> tensor<16x16xf32, #blocked4> - %63 = triton_gpu.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + %61 = ttg.convert_layout %60 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %62 = ttg.convert_layout %58 : tensor<16x16xf32, #blocked2> -> tensor<16x16xf32, #blocked4> + %63 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> tt.store %61, %62, %63 : tensor<16x16x!tt.ptr, #blocked4> tt.return } @@ -989,15 +989,15 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! 
// ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> // cmpf and cmpi have different operands and result types // CHECK-LABEL: cmp -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { %c64 = arith.constant 64 : index %c2048 = arith.constant 2048 : index @@ -1014,14 +1014,14 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c64_i32 : i32 %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0> - %3 = triton_gpu.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<64x1xi32, #blocked1> -> tensor<64x1xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<64x1xi32, #blocked1> -> tensor<64x1xi32, #blocked2> %6 = tt.splat %1 : i32 -> tensor<64x1xi32, #blocked2> %7 = arith.addi %6, %5 : tensor<64x1xi32, #blocked2> %8 = arith.cmpi "slt", %7, %cst_5 : tensor<64x1xi32, #blocked2> - %9 = triton_gpu.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x64xi32, #blocked3> + %9 = ttg.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %10 = tt.expand_dims 
%9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x64xi32, #blocked3> %11 = arith.remsi %7, %cst_4 : tensor<64x1xi32, #blocked2> %12 = arith.divsi %7, %cst_4 : tensor<64x1xi32, #blocked2> %13 = arith.sitofp %cst_3 : tensor<64x64xi32, #blocked2> to tensor<64x64xf32, #blocked2> @@ -1042,24 +1042,24 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3> %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3> %48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3> - %49 = triton_gpu.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> + %49 = ttg.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> %50 = arith.addi %49, %16 : tensor<64x64xi32, #blocked2> %51 = tt.addptr %17, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> %52 = tt.broadcast %47 : tensor<1x64xi1, #blocked3> -> tensor<64x64xi1, #blocked3> - %53 = triton_gpu.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> + %53 = ttg.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> %54 = arith.andi %53, %18 : tensor<64x64xi1, #blocked2> - %55 = triton_gpu.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> - %56 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> + %55 = ttg.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> + %56 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> %57 = tt.load %55, %56 : tensor<64x64x!tt.ptr, #blocked4> - %58 = triton_gpu.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> + %58 = ttg.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> %59 = arith.extf %58 : tensor<64x64xf16, #blocked2> to tensor<64x64xf32, #blocked2> %60 = arith.addi %49, %20 : tensor<64x64xi32, #blocked2> %61 = arith.addi %60, %23 : tensor<64x64xi32, #blocked2> %62 = tt.addptr %24, %61 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> - %63 = triton_gpu.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> - %64 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> + %63 = ttg.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> + %64 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> %65 = tt.load %63, %64 : tensor<64x64x!tt.ptr, #blocked5> - %66 = triton_gpu.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> + %66 = ttg.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> %67 = arith.addf %59, %66 : tensor<64x64xf32, #blocked2> %68 = arith.cmpf "une", %67, %67 : tensor<64x64xf32, #blocked2> %69 = arith.cmpf "ogt", %67, %cst : tensor<64x64xf32, #blocked2> @@ -1074,11 +1074,11 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt ^bb0(%arg8: f32, %arg9: f32): %add = arith.addf %arg8, %arg9 : f32 tt.reduce.return %add : f32 - }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %27 = triton_gpu.convert_layout %26 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64xf32, #blocked0> - %28 = 
triton_gpu.convert_layout %27 : tensor<64xf32, #blocked0> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %29 = tt.expand_dims %28 {axis = 1 : i32} : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xf32, #blocked1> - %30 = triton_gpu.convert_layout %29 : tensor<64x1xf32, #blocked1> -> tensor<64x1xf32, #blocked2> + }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %27 = ttg.convert_layout %26 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64xf32, #blocked0> + %28 = ttg.convert_layout %27 : tensor<64xf32, #blocked0> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %29 = tt.expand_dims %28 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xf32, #blocked1> + %30 = ttg.convert_layout %29 : tensor<64x1xf32, #blocked1> -> tensor<64x1xf32, #blocked2> %31 = arith.muli %7, %cst_4 : tensor<64x1xi32, #blocked2> %32 = tt.broadcast %31 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2> %33 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> @@ -1098,24 +1098,24 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3> %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3> %48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3> - %49 = triton_gpu.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> + %49 = ttg.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> %50 = arith.addi %49, %32 : tensor<64x64xi32, #blocked2> %51 = tt.addptr %33, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> %52 = tt.broadcast %47 : tensor<1x64xi1, #blocked3> -> tensor<64x64xi1, #blocked3> - %53 = triton_gpu.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> + %53 = ttg.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> %54 = arith.andi %53, %34 : tensor<64x64xi1, #blocked2> - %55 = triton_gpu.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> - %56 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> + %55 = ttg.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> + %56 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> %57 = tt.load %55, %56 : tensor<64x64x!tt.ptr, #blocked4> - %58 = triton_gpu.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> + %58 = ttg.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> %59 = arith.extf %58 : tensor<64x64xf16, #blocked2> to tensor<64x64xf32, #blocked2> %60 = arith.addi %49, %36 : tensor<64x64xi32, #blocked2> %61 = arith.addi %60, %39 : tensor<64x64xi32, #blocked2> %62 = tt.addptr %40, %61 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> - %63 = triton_gpu.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> - %64 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> + %63 = ttg.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> + %64 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> %65 = tt.load %63, %64 : tensor<64x64x!tt.ptr, #blocked5> - %66 = 
triton_gpu.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> + %66 = ttg.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> %67 = arith.addf %59, %66 : tensor<64x64xf32, #blocked2> %68 = arith.cmpf "une", %67, %67 : tensor<64x64xf32, #blocked2> %69 = arith.cmpf "ogt", %67, %cst : tensor<64x64xf32, #blocked2> @@ -1124,15 +1124,15 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %72 = math.exp %71 : tensor<64x64xf32, #blocked2> %73 = arith.divf %72, %41 : tensor<64x64xf32, #blocked2> %74 = tt.addptr %42, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> - %75 = triton_gpu.convert_layout %74 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> - %76 = triton_gpu.convert_layout %73 : tensor<64x64xf32, #blocked2> -> tensor<64x64xf32, #blocked5> - %77 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> + %75 = ttg.convert_layout %74 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> + %76 = ttg.convert_layout %73 : tensor<64x64xf32, #blocked2> -> tensor<64x64xf32, #blocked5> + %77 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> tt.store %75, %76, %77 : tensor<64x64x!tt.ptr, #blocked5> %78 = tt.addptr %43, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> %79 = arith.truncf %73 : tensor<64x64xf32, #blocked2> to tensor<64x64xf16, #blocked2> - %80 = triton_gpu.convert_layout %78 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> - %81 = triton_gpu.convert_layout %79 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #blocked4> - %82 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> + %80 = ttg.convert_layout %78 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> + %81 = ttg.convert_layout %79 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #blocked4> + %82 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> tt.store %80, %81, %82 : tensor<64x64x!tt.ptr, #blocked4> } tt.return @@ -1143,9 +1143,9 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt // Just make sure it doesn't crash on non-tensor types. // CHECK-LABEL: if_no_tensor -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @if_no_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %c-1_i64 = arith.constant -1 : i64 %cst = arith.constant 0.000000e+00 : f32 %c-1_i32 = arith.constant -1 : i32 @@ -1173,35 +1173,35 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, % // Check if the SimplifyReduceCvt rewriter pattern doesn't hang. 
// CHECK-LABEL: reduce_cvt -// CHECK-NOT: triton_gpu.convert_layout -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// CHECK-NOT: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 2 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @reduce_cvt1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) { %cst = arith.constant dense<0> : tensor<1x2xi32, #blocked> %cst_0 = arith.constant dense<2> : tensor<1x2xi32, #blocked> %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #blocked1> - %1 = triton_gpu.convert_layout %0 : tensor<2xi32, #blocked1> -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x2xi32, #blocked> + %1 = ttg.convert_layout %0 : tensor<2xi32, #blocked1> -> tensor<2xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x2xi32, #blocked> %3 = arith.cmpi "slt", %2, %cst_0 : tensor<1x2xi32, #blocked> %4 = "tt.reduce" (%cst) ({ ^bb0(%arg3: i32, %arg4: i32): %add = arith.addi %arg3, %arg4 : i32 tt.reduce.return %add : i32 - }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = triton_gpu.convert_layout %4 : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> - %6 = triton_gpu.convert_layout %5 : tensor<1xi32, #blocked1> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> - %8 = triton_gpu.convert_layout %7 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> + }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = ttg.convert_layout %4 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %6 = ttg.convert_layout %5 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> + %8 = ttg.convert_layout %7 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> %9 = tt.splat %arg0 : !tt.ptr -> tensor<1x2x!tt.ptr, #blocked> %10 = tt.addptr %9, %2 : tensor<1x2x!tt.ptr, #blocked>, tensor<1x2xi32, #blocked> %11 = tt.broadcast %8 : tensor<1x1xi32, #blocked> -> 
tensor<1x2xi32, #blocked> %12 = arith.extsi %11 : tensor<1x2xi32, #blocked> to tensor<1x2xi64, #blocked> - %13 = triton_gpu.convert_layout %10 : tensor<1x2x!tt.ptr, #blocked> -> tensor<1x2x!tt.ptr, #blocked3> - %14 = triton_gpu.convert_layout %12 : tensor<1x2xi64, #blocked> -> tensor<1x2xi64, #blocked3> - %15 = triton_gpu.convert_layout %3 : tensor<1x2xi1, #blocked> -> tensor<1x2xi1, #blocked3> + %13 = ttg.convert_layout %10 : tensor<1x2x!tt.ptr, #blocked> -> tensor<1x2x!tt.ptr, #blocked3> + %14 = ttg.convert_layout %12 : tensor<1x2xi64, #blocked> -> tensor<1x2xi64, #blocked3> + %15 = ttg.convert_layout %3 : tensor<1x2xi1, #blocked> -> tensor<1x2xi1, #blocked3> tt.store %13, %14, %15 : tensor<1x2x!tt.ptr, #blocked3> tt.return } @@ -1211,19 +1211,19 @@ module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.num-ctas" = 1 : // CHECK-LABEL: reduce_cvt2 // Match the reduction -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.reduce // CHECK-SAME: axis = 1 -// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>> -// CHECK: triton_gpu.convert_layout +// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #{{.*}}}>> +// CHECK: ttg.convert_layout // CHECK: tt.expand_dims -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @reduce_cvt2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) { %cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked> %c3136_i32 = arith.constant 3136 : index @@ -1237,15 +1237,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : %cst_5 = arith.constant dense<256> : tensor<1x1xi32, #blocked> %0 = tt.get_program_id x : i32 %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1> - %2 = triton_gpu.convert_layout %1 : tensor<1xi32, #blocked1> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> - %4 = triton_gpu.convert_layout %3 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> + %2 = ttg.convert_layout %1 : tensor<1xi32, #blocked1> 
-> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> + %4 = ttg.convert_layout %3 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> %5 = tt.splat %0 : i32 -> tensor<1x1xi32, #blocked> %6 = arith.addi %5, %4 : tensor<1x1xi32, #blocked> %7 = arith.cmpi "slt", %6, %cst_5 : tensor<1x1xi32, #blocked> %8 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> - %9 = triton_gpu.convert_layout %8 : tensor<256xi32, #blocked1> -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %9 = ttg.convert_layout %8 : tensor<256xi32, #blocked1> -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> %11 = arith.muli %6, %cst_2 : tensor<1x1xi32, #blocked> %12 = tt.broadcast %11 : tensor<1x1xi32, #blocked> -> tensor<1x256xi32, #blocked> %13 = tt.splat %arg1 : !tt.ptr -> tensor<1x256x!tt.ptr, #blocked> @@ -1262,11 +1262,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : %50 = arith.addi %48, %49 : tensor<1x256xi32, #blocked> %51 = tt.addptr %13, %50 : tensor<1x256x!tt.ptr, #blocked>, tensor<1x256xi32, #blocked> %52 = arith.andi %45, %14 : tensor<1x256xi1, #blocked> - %53 = triton_gpu.convert_layout %51 : tensor<1x256x!tt.ptr, #blocked> -> tensor<1x256x!tt.ptr, #blocked3> - %54 = triton_gpu.convert_layout %52 : tensor<1x256xi1, #blocked> -> tensor<1x256xi1, #blocked3> - %55 = triton_gpu.convert_layout %cst : tensor<1x256xf32, #blocked> -> tensor<1x256xf32, #blocked3> + %53 = ttg.convert_layout %51 : tensor<1x256x!tt.ptr, #blocked> -> tensor<1x256x!tt.ptr, #blocked3> + %54 = ttg.convert_layout %52 : tensor<1x256xi1, #blocked> -> tensor<1x256xi1, #blocked3> + %55 = ttg.convert_layout %cst : tensor<1x256xf32, #blocked> -> tensor<1x256xf32, #blocked3> %56 = tt.load %53, %54, %55 : tensor<1x256x!tt.ptr, #blocked3> - %57 = triton_gpu.convert_layout %56 : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #blocked> + %57 = ttg.convert_layout %56 : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #blocked> %58 = arith.addf %arg6, %57 : tensor<1x256xf32, #blocked> %59 = arith.select %52, %58, %arg6 : tensor<1x256xi1, #blocked>, tensor<1x256xf32, #blocked> scf.yield %59 : tensor<1x256xf32, #blocked> @@ -1276,17 +1276,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : %add = arith.addf %arg7, %arg8 : f32 tt.reduce.return %add : f32 - }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %17 = triton_gpu.convert_layout %16 : tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> - %18 = triton_gpu.convert_layout %17 : tensor<1xf32, #blocked1> -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %19 = tt.expand_dims %18 {axis = 1 : i32} : tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xf32, #blocked2> - %20 = triton_gpu.convert_layout %19 : tensor<1x1xf32, #blocked2> -> tensor<1x1xf32, #blocked> + }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = ttg.convert_layout %16 : tensor<1xf32, 
#ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %18 = ttg.convert_layout %17 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.expand_dims %18 {axis = 1 : i32} : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xf32, #blocked2> + %20 = ttg.convert_layout %19 : tensor<1x1xf32, #blocked2> -> tensor<1x1xf32, #blocked> %21 = arith.divf %20, %cst_0 : tensor<1x1xf32, #blocked> %22 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> %23 = tt.addptr %22, %6 : tensor<1x1x!tt.ptr, #blocked>, tensor<1x1xi32, #blocked> - %24 = triton_gpu.convert_layout %23 : tensor<1x1x!tt.ptr, #blocked> -> tensor<1x1x!tt.ptr, #blocked> - %25 = triton_gpu.convert_layout %21 : tensor<1x1xf32, #blocked> -> tensor<1x1xf32, #blocked> - %26 = triton_gpu.convert_layout %7 : tensor<1x1xi1, #blocked> -> tensor<1x1xi1, #blocked> + %24 = ttg.convert_layout %23 : tensor<1x1x!tt.ptr, #blocked> -> tensor<1x1x!tt.ptr, #blocked> + %25 = ttg.convert_layout %21 : tensor<1x1xf32, #blocked> -> tensor<1x1xf32, #blocked> + %26 = ttg.convert_layout %7 : tensor<1x1xi1, #blocked> -> tensor<1x1xi1, #blocked> tt.store %24, %25, %26 : tensor<1x1x!tt.ptr, #blocked> tt.return } @@ -1296,12 +1296,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : // Ensure that RematerializeForward doesn't apply when a convert has multiple uses // CHECK-LABEL: loop_convert_multi_uses -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @loop_convert_multi_uses(%arg0: i32 {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) 
attributes {noinline = false} { %cst = arith.constant dense<0xFF800000> : tensor<16xf32, #blocked> %c1_i32 = arith.constant 1 : i32 @@ -1322,16 +1322,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %8 = arith.muli %2, %arg3 : i32 %9 = arith.muli %3, %arg4 : i32 %10 = arith.addi %8, %9 : i32 - %11 = triton_gpu.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %12 = tt.expand_dims %11 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> - %13 = triton_gpu.convert_layout %12 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> + %11 = ttg.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %12 = tt.expand_dims %11 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> + %13 = ttg.convert_layout %12 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> %14 = tt.splat %arg6 : i32 -> tensor<16x1xi32, #blocked1> %15 = arith.muli %13, %14 : tensor<16x1xi32, #blocked1> - %16 = triton_gpu.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> + %16 = ttg.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> %18 = tt.broadcast %15 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> %19 = tt.broadcast %17 : tensor<1x16xi32, #blocked3> -> tensor<16x16xi32, #blocked3> - %20 = triton_gpu.convert_layout %19 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> + %20 = ttg.convert_layout %19 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> %21 = arith.addi %18, %20 : tensor<16x16xi32, #blocked1> %22 = tt.splat %arg2 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked1> %23 = arith.cmpi "slt", %13, %cst_3 : tensor<16x1xi32, #blocked1> @@ -1352,26 +1352,26 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %62 = tt.splat %61 : i32 -> tensor<16x16xi32, #blocked1> %63 = arith.addi %62, %21 : tensor<16x16xi32, #blocked1> %64 = tt.addptr %22, %63 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> - %65 = triton_gpu.convert_layout %64 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> - %66 = triton_gpu.convert_layout %24 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> - %67 = triton_gpu.convert_layout %25 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> + %65 = ttg.convert_layout %64 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> + %66 = ttg.convert_layout %24 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> + %67 = ttg.convert_layout %25 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> %68 = tt.load %65, %66, %67 : tensor<16x16x!tt.ptr, #blocked4> - %69 = triton_gpu.convert_layout %68 : tensor<16x16xf16, #blocked4> -> tensor<16x16xf16, #blocked1> + %69 = ttg.convert_layout %68 : tensor<16x16xf16, #blocked4> -> tensor<16x16xf16, #blocked1> %70 = arith.addi %28, %arg17 : i32 %71 = tt.splat %70 : i32 -> tensor<16xi32, #blocked> %72 = arith.addi %71, %7 : tensor<16xi32, #blocked> %73 = 
tt.addptr %29, %72 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> - %74 = triton_gpu.convert_layout %73 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> - %75 = triton_gpu.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> - %76 = triton_gpu.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> + %74 = ttg.convert_layout %73 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> + %75 = ttg.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> + %76 = ttg.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> %77 = tt.load %74, %75, %76 : tensor<16x!tt.ptr, #blocked> %78 = arith.addi %33, %arg17 : i32 %79 = tt.splat %78 : i32 -> tensor<16xi32, #blocked> %80 = arith.addi %79, %7 : tensor<16xi32, #blocked> %81 = tt.addptr %34, %80 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> - %82 = triton_gpu.convert_layout %81 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> - %83 = triton_gpu.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> - %84 = triton_gpu.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> + %82 = ttg.convert_layout %81 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> + %83 = ttg.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> + %84 = ttg.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> %85 = tt.load %82, %83, %84 : tensor<16x!tt.ptr, #blocked> %86 = arith.cmpf "ogt", %arg20, %85 : tensor<16xf32, #blocked> %87 = arith.select %86, %arg20, %85 : tensor<16xi1, #blocked>, tensor<16xf32, #blocked> @@ -1385,14 +1385,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %95 = arith.divf %91, %94 : tensor<16xf32, #blocked> %96 = arith.divf %arg19, %94 : tensor<16xf32, #blocked> %97 = arith.mulf %96, %89 : tensor<16xf32, #blocked> - %98 = triton_gpu.convert_layout %97 : tensor<16xf32, #blocked> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %99 = tt.expand_dims %98 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> - %100 = triton_gpu.convert_layout %99 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> + %98 = ttg.convert_layout %97 : tensor<16xf32, #blocked> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %99 = tt.expand_dims %98 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> + %100 = ttg.convert_layout %99 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> %101 = tt.broadcast %100 : tensor<16x1xf32, #blocked1> -> tensor<16x16xf32, #blocked1> %102 = arith.mulf %arg18, %101 : tensor<16x16xf32, #blocked1> - %103 = triton_gpu.convert_layout %95 : tensor<16xf32, #blocked> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> - %105 = triton_gpu.convert_layout %104 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> + %103 = ttg.convert_layout %95 : tensor<16xf32, #blocked> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> + %105 = ttg.convert_layout %104 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> %106 = 
tt.broadcast %105 : tensor<16x1xf32, #blocked1> -> tensor<16x16xf32, #blocked1> %107 = arith.extf %69 : tensor<16x16xf16, #blocked1> to tensor<16x16xf32, #blocked1> %108 = arith.mulf %107, %106 : tensor<16x16xf32, #blocked1> @@ -1402,16 +1402,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %36 = arith.muli %2, %arg14 : i32 %37 = arith.muli %3, %arg15 : i32 %38 = arith.addi %36, %37 : i32 - %39 = triton_gpu.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> - %41 = triton_gpu.convert_layout %40 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> + %39 = ttg.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> + %41 = ttg.convert_layout %40 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> %42 = tt.splat %arg16 : i32 -> tensor<16x1xi32, #blocked1> %43 = arith.muli %41, %42 : tensor<16x1xi32, #blocked1> - %44 = triton_gpu.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %45 = tt.expand_dims %44 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> + %44 = ttg.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %45 = tt.expand_dims %44 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> %46 = tt.broadcast %43 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> %47 = tt.broadcast %45 : tensor<1x16xi32, #blocked3> -> tensor<16x16xi32, #blocked3> - %48 = triton_gpu.convert_layout %47 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> + %48 = ttg.convert_layout %47 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> %49 = arith.addi %46, %48 : tensor<16x16xi32, #blocked1> %50 = tt.splat %38 : i32 -> tensor<16x16xi32, #blocked1> %51 = arith.addi %50, %49 : tensor<16x16xi32, #blocked1> @@ -1420,9 +1420,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %54 = arith.cmpi "slt", %41, %cst_3 : tensor<16x1xi32, #blocked1> %55 = tt.broadcast %54 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> %56 = arith.truncf %35#0 : tensor<16x16xf32, #blocked1> to tensor<16x16xf16, #blocked1> - %57 = triton_gpu.convert_layout %53 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> - %58 = triton_gpu.convert_layout %56 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> - %59 = triton_gpu.convert_layout %55 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> + %57 = ttg.convert_layout %53 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> + %58 = ttg.convert_layout %56 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> + %59 = ttg.convert_layout %55 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> tt.store %57, %58, %59 : tensor<16x16x!tt.ptr, #blocked4> tt.return } @@ -1432,15 +1432,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // Check if MoveConvertOutOfLoop hangs because of adding additional conversions // CHECK-LABEL: @loop_print -// 
CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @loop_print(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c32_i32 = arith.constant 32 : i32 %c31_i32 = arith.constant 31 : i32 @@ -1450,25 +1450,25 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %cst_0 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> %cst_1 = arith.constant 0.000000e+00 : f32 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked2> - %1 = triton_gpu.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> %3 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> %4 = arith.muli %2, %3 : tensor<128x1xi32, #blocked1> %5 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked2> - %6 = triton_gpu.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %6 = ttg.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #ttg.slice<{dim = 0, 
parent = #blocked3}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> %8 = tt.broadcast %4 : tensor<128x1xi32, #blocked1> -> tensor<128x32xi32, #blocked1> %9 = tt.broadcast %7 : tensor<1x32xi32, #blocked3> -> tensor<128x32xi32, #blocked3> - %10 = triton_gpu.convert_layout %9 : tensor<128x32xi32, #blocked3> -> tensor<128x32xi32, #blocked1> + %10 = ttg.convert_layout %9 : tensor<128x32xi32, #blocked3> -> tensor<128x32xi32, #blocked1> %11 = arith.addi %8, %10 : tensor<128x32xi32, #blocked1> - %12 = triton_gpu.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> - %14 = triton_gpu.convert_layout %13 : tensor<32x1xi32, #blocked1> -> tensor<32x1xi32, #blocked> - %15 = triton_gpu.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> + %12 = ttg.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %14 = ttg.convert_layout %13 : tensor<32x1xi32, #blocked1> -> tensor<32x1xi32, #blocked> + %15 = ttg.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> %17 = tt.broadcast %14 : tensor<32x1xi32, #blocked> -> tensor<32x128xi32, #blocked> %18 = tt.broadcast %16 : tensor<1x128xi32, #blocked3> -> tensor<32x128xi32, #blocked3> - %19 = triton_gpu.convert_layout %18 : tensor<32x128xi32, #blocked3> -> tensor<32x128xi32, #blocked> + %19 = ttg.convert_layout %18 : tensor<32x128xi32, #blocked3> -> tensor<32x128xi32, #blocked> %20 = arith.addi %17, %19 : tensor<32x128xi32, #blocked> %21 = arith.addi %arg5, %c31_i32 : i32 %22 = arith.divsi %21, %c32_i32 : i32 @@ -1477,19 +1477,19 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %25:3 = scf.for %arg7 = %c0_i32 to %22 step %c1_i32 iter_args(%arg8 = %cst_1, %arg9 = %11, %arg10 = %20) -> (f32, tensor<128x32xi32, #blocked1>, tensor<32x128xi32, #blocked>) : i32 { tt.print "a_offsets: " { hex = false, isSigned = array } : %arg9 : tensor<128x32xi32, #blocked1> %27 = tt.addptr %23, %arg9 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> - %28 = triton_gpu.convert_layout %27 : tensor<128x32x!tt.ptr, #blocked1> -> tensor<128x32x!tt.ptr, #blocked4> + %28 = ttg.convert_layout %27 : tensor<128x32x!tt.ptr, #blocked1> -> tensor<128x32x!tt.ptr, #blocked4> %29 = tt.load %28 : tensor<128x32x!tt.ptr, #blocked4> - %30 = triton_gpu.convert_layout %29 : tensor<128x32xf16, #blocked4> -> tensor<128x32xf16, #blocked1> + %30 = ttg.convert_layout %29 : tensor<128x32xf16, #blocked4> -> tensor<128x32xf16, #blocked1> %31 = tt.addptr %24, %arg10 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %32 = triton_gpu.convert_layout %31 : tensor<32x128x!tt.ptr, #blocked> -> tensor<32x128x!tt.ptr, #blocked5> + %32 = ttg.convert_layout %31 : tensor<32x128x!tt.ptr, #blocked> -> 
tensor<32x128x!tt.ptr, #blocked5> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked5> - %34 = triton_gpu.convert_layout %33 : tensor<32x128xf16, #blocked5> -> tensor<32x128xf16, #blocked> + %34 = ttg.convert_layout %33 : tensor<32x128xf16, #blocked5> -> tensor<32x128xf16, #blocked> %35 = "tt.reduce"(%30) <{axis = 0 : i32}> ({ ^bb0(%arg11: f16, %arg12: f16): %46 = arith.addf %arg11, %arg12 : f16 tt.reduce.return %46 : f16 - }) : (tensor<128x32xf16, #blocked1>) -> tensor<32xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %36 = triton_gpu.convert_layout %35 : tensor<32xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<32xf16, #blocked2> + }) : (tensor<128x32xf16, #blocked1>) -> tensor<32xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> + %36 = ttg.convert_layout %35 : tensor<32xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<32xf16, #blocked2> %37 = "tt.reduce"(%36) <{axis = 0 : i32}> ({ ^bb0(%arg11: f16, %arg12: f16): %46 = arith.addf %arg11, %arg12 : f16 @@ -1499,8 +1499,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war ^bb0(%arg11: f16, %arg12: f16): %46 = arith.addf %arg11, %arg12 : f16 tt.reduce.return %46 : f16 - }) : (tensor<32x128xf16, #blocked>) -> tensor<128xf16, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %39 = triton_gpu.convert_layout %38 : tensor<128xf16, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<128xf16, #blocked2> + }) : (tensor<32x128xf16, #blocked>) -> tensor<128xf16, #ttg.slice<{dim = 0, parent = #blocked}>> + %39 = ttg.convert_layout %38 : tensor<128xf16, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<128xf16, #blocked2> %40 = "tt.reduce"(%39) <{axis = 0 : i32}> ({ ^bb0(%arg11: f16, %arg12: f16): %46 = arith.addf %arg11, %arg12 : f16 @@ -1525,50 +1525,50 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // CHECK-LABEL: reduce_cvt3 // CHECK: tt.dot // CHECK-NEXT: tt.reduce -// CHECK: triton_gpu.convert_layout -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// CHECK: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], 
threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @reduce_cvt3(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked> %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked1> - %1 = triton_gpu.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xi32, #blocked2> - %3 = triton_gpu.convert_layout %2 : tensor<32x1xi32, #blocked2> -> tensor<32x1xi32, #blocked> + %1 = ttg.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<32x1xi32, #blocked2> -> tensor<32x1xi32, #blocked> %4 = arith.muli %3, %cst_0 : tensor<32x1xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %6 = tt.addptr %5, %4 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> - %7 = triton_gpu.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %7 = ttg.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> %9 = tt.broadcast %6 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> %10 = tt.broadcast %8 : tensor<1x32xi32, #blocked3> -> tensor<32x32xi32, #blocked3> - %11 = triton_gpu.convert_layout %10 : tensor<32x32xi32, #blocked3> -> tensor<32x32xi32, #blocked> + %11 = ttg.convert_layout %10 : tensor<32x32xi32, #blocked3> -> tensor<32x32xi32, #blocked> %12 = tt.addptr %9, %11 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> %13 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %14 = tt.addptr %13, %4 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> %15 = tt.broadcast %14 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> %16 = tt.addptr %15, %11 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - %17 = triton_gpu.convert_layout %12 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> + %17 = ttg.convert_layout %12 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> %18 = tt.load %17 : tensor<32x32x!tt.ptr, #blocked4> - %19 = triton_gpu.convert_layout %18 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> - %20 = triton_gpu.convert_layout %16 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> + %19 = 
ttg.convert_layout %18 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> + %20 = ttg.convert_layout %16 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> %21 = tt.load %20 : tensor<32x32x!tt.ptr, #blocked4> - %22 = triton_gpu.convert_layout %21 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> - %23 = triton_gpu.local_alloc %22 : (tensor<32x32xf16, #blocked>) -> !triton_gpu.memdesc<32x32xf16, #shared> - %24 = triton_gpu.memdesc_trans %23 {order=array} : !triton_gpu.memdesc<32x32xf16, #shared> -> !triton_gpu.memdesc<32x32xf16, #shared1> - %25 = triton_gpu.local_load %24 : !triton_gpu.memdesc<32x32xf16, #shared1> -> tensor<32x32xf16, #blocked> - %26 = triton_gpu.convert_layout %19 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> - %27 = triton_gpu.convert_layout %25 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> - %28 = triton_gpu.convert_layout %cst : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked5> - %29 = tt.dot %26, %27, %28 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> - %30 = triton_gpu.convert_layout %29 : tensor<32x32xf32, #blocked5> -> tensor<32x32xf32, #blocked> + %22 = ttg.convert_layout %21 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> + %23 = ttg.local_alloc %22 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<32x32xf16, #shared> + %24 = ttg.memdesc_trans %23 {order=array} : !ttg.memdesc<32x32xf16, #shared> -> !ttg.memdesc<32x32xf16, #shared1> + %25 = ttg.local_load %24 : !ttg.memdesc<32x32xf16, #shared1> -> tensor<32x32xf16, #blocked> + %26 = ttg.convert_layout %19 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked5}>> + %27 = ttg.convert_layout %25 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked5}>> + %28 = ttg.convert_layout %cst : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked5> + %29 = tt.dot %26, %27, %28 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> + %30 = ttg.convert_layout %29 : tensor<32x32xf32, #blocked5> -> tensor<32x32xf32, #blocked> %31:2 = "tt.reduce"(%30, %11) <{axis = 1 : i32}> ({ ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): %37 = arith.cmpf "oeq", %arg3, %arg5 : f32 @@ -1579,12 +1579,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %42 = arith.select %41, %arg3, %arg5 : f32 %43 = arith.select %41, %arg4, %arg6 : i32 tt.reduce.return %42, %43 : f32, i32 - }) : (tensor<32x32xf32, #blocked>, tensor<32x32xi32, #blocked>) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) - %32 = triton_gpu.convert_layout %31#1 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #blocked1> + }) : (tensor<32x32xf32, #blocked>, tensor<32x32xi32, #blocked>) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + %32 = ttg.convert_layout %31#1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #blocked1> %33 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #blocked1> %34 = tt.addptr %33, %0 : tensor<32x!tt.ptr, 
#blocked1>, tensor<32xi32, #blocked1> - %35 = triton_gpu.convert_layout %34 : tensor<32x!tt.ptr, #blocked1> -> tensor<32x!tt.ptr, #blocked1> - %36 = triton_gpu.convert_layout %32 : tensor<32xi32, #blocked1> -> tensor<32xi32, #blocked1> + %35 = ttg.convert_layout %34 : tensor<32x!tt.ptr, #blocked1> -> tensor<32x!tt.ptr, #blocked1> + %36 = ttg.convert_layout %32 : tensor<32xi32, #blocked1> -> tensor<32xi32, #blocked1> tt.store %35, %36 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -1594,20 +1594,20 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- // Check that we don't have extra convert for flash attention IR. -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked3a = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [4, 1, 8], warpsPerCTA = [4, 1, 1], order = [1, 2, 0]}> -#blocked4a = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [1, 4, 8], warpsPerCTA = [1, 4, 1], order = [0, 2, 1]}> -#blocked6a = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked6 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked7 = #triton_gpu.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [1, 1, 4], order = [1, 0, 2]}> -#blocked8 = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 1, 4], order = [0, 1, 2]}> -#blocked9 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3a = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [4, 1, 8], warpsPerCTA = [4, 1, 1], order = [1, 2, 0]}> +#blocked4a = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [1, 4, 8], warpsPerCTA = [1, 4, 1], order = [0, 2, 1]}> +#blocked6a = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked6 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = 
[1, 4], order = [0, 1]}> +#blocked7 = #ttg.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [1, 1, 4], order = [1, 0, 2]}> +#blocked8 = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 1, 4], order = [0, 1, 2]}> +#blocked9 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @attention_fw(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %c0_i64 = arith.constant 0 : i64 %c64_i64 = arith.constant 64 : i64 @@ -1641,58 +1641,58 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3a> %20 = arith.extsi %19 : tensor<128xi32, #blocked3a> to tensor<128xi64, #blocked3a> %21 = arith.addi %18, %20 : tensor<128xi64, #blocked3a> - %22 = triton_gpu.convert_layout %21 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4a}>> - %23 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> + %22 = ttg.convert_layout %21 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> + %23 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> %24 = tt.splat %6 : i64 -> tensor<128x1xi64, #blocked4a> %25 = arith.muli %23, %24 : tensor<128x1xi64, #blocked4a> %26 = tt.broadcast %25 : tensor<128x1xi64, #blocked4a> -> tensor<128x64xi64, #blocked4a> - %27 = triton_gpu.convert_layout %26 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> + %27 = ttg.convert_layout %26 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> %28 = tt.addptr %17, %27 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> %29 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> %30 = arith.extsi %29 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> - %31 = 
triton_gpu.convert_layout %30 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked4a}>> - %32 = tt.expand_dims %31 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> + %31 = ttg.convert_layout %30 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> + %32 = tt.expand_dims %31 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> %33 = tt.broadcast %32 : tensor<1x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked4a> - %34 = triton_gpu.convert_layout %33 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> + %34 = ttg.convert_layout %33 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> %35 = tt.addptr %28, %34 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> %36 = tt.load %35 : tensor<128x64x!tt.ptr, #blocked3> - %37 = triton_gpu.convert_layout %36 : tensor<128x64xf16, #blocked3> -> tensor<128x64xf16, #blocked2> + %37 = ttg.convert_layout %36 : tensor<128x64xf16, #blocked3> -> tensor<128x64xf16, #blocked2> %38 = tt.splat %16 : f32 -> tensor<128x64xf32, #blocked2> %39 = arith.extf %37 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> %40 = arith.mulf %39, %38 : tensor<128x64xf32, #blocked2> %41 = arith.truncf %40 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.for -// CHECK-NOT: triton_gpu.convert_layout -// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op -// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK-NOT: ttg.convert_layout // CHECK: tt.dot -// CHECK-NOT: triton_gpu.convert_layout -// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op -// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK-NOT: ttg.convert_layout // CHECK: tt.dot // CHECK: scf.yield %42:5 = scf.for %arg22 = %c0_i32 to %9 step %c64_i32 iter_args(%arg23 = %cst_2, %arg24 = %cst_1, %arg25 = %cst_0, %arg26 = %c0_i64, %arg27 = %c0_i64) -> (tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64) : i32 { %78 = tt.splat %8 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked6> %79 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked6a> %80 = arith.extsi %79 : tensor<64xi32, #blocked6a> to tensor<64xi64, #blocked6a> - %81 = triton_gpu.convert_layout %80 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %82 = tt.expand_dims %81 {axis = 1 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<64x1xi64, #blocked6> + %81 = ttg.convert_layout %80 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked6}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked6}>> -> tensor<64x1xi64, #blocked6> %83 = tt.broadcast %82 : tensor<64x1xi64, #blocked6> -> tensor<64x64xi64, #blocked6> - %84 = triton_gpu.convert_layout %83 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, 
#blocked6> + %84 = ttg.convert_layout %83 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> %85 = tt.addptr %78, %84 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> %86 = tt.splat %arg26 : i64 -> tensor<64xi64, #blocked6a> %87 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked6a> %88 = arith.extsi %87 : tensor<64xi32, #blocked6a> to tensor<64xi64, #blocked6a> %89 = arith.addi %86, %88 : tensor<64xi64, #blocked6a> - %90 = triton_gpu.convert_layout %89 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked6}>> - %91 = tt.expand_dims %90 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> + %90 = ttg.convert_layout %89 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> + %91 = tt.expand_dims %90 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> %92 = tt.splat %10 : i64 -> tensor<1x64xi64, #blocked6> %93 = arith.muli %91, %92 : tensor<1x64xi64, #blocked6> %94 = tt.broadcast %93 : tensor<1x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> - %95 = triton_gpu.convert_layout %94 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> + %95 = ttg.convert_layout %94 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> %96 = tt.addptr %85, %95 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> %97 = tt.load %96 : tensor<64x64x!tt.ptr, #blocked6> %98 = tt.splat %11 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked3> @@ -1700,69 +1700,69 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %100 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> %101 = arith.extsi %100 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> %102 = arith.addi %99, %101 : tensor<64xi64, #blocked3a> - %103 = triton_gpu.convert_layout %102 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi64, #blocked3> + %103 = ttg.convert_layout %102 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi64, #blocked3> %105 = tt.splat %12 : i64 -> tensor<64x1xi64, #blocked3> %106 = arith.muli %104, %105 : tensor<64x1xi64, #blocked3> %107 = tt.broadcast %106 : tensor<64x1xi64, #blocked3> -> tensor<64x64xi64, #blocked3> - %108 = triton_gpu.convert_layout %107 : tensor<64x64xi64, #blocked3> -> tensor<64x64xi64, #blocked3> + %108 = ttg.convert_layout %107 : tensor<64x64xi64, #blocked3> -> tensor<64x64xi64, #blocked3> %109 = tt.addptr %98, %108 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> %110 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> %111 = arith.extsi %110 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> - %112 = triton_gpu.convert_layout %111 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked4a}>> - %113 = tt.expand_dims %112 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> + %112 = ttg.convert_layout %111 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> + %113 = 
tt.expand_dims %112 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> %114 = tt.broadcast %113 : tensor<1x64xi64, #blocked4a> -> tensor<64x64xi64, #blocked4a> - %115 = triton_gpu.convert_layout %114 : tensor<64x64xi64, #blocked4a> -> tensor<64x64xi64, #blocked3> + %115 = ttg.convert_layout %114 : tensor<64x64xi64, #blocked4a> -> tensor<64x64xi64, #blocked3> %116 = tt.addptr %109, %115 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> %117 = tt.load %116 : tensor<64x64x!tt.ptr, #blocked3> - %118 = triton_gpu.convert_layout %41 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %119 = triton_gpu.convert_layout %97 : tensor<64x64xf16, #blocked6> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %120 = tt.dot %118, %119, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> - %121 = triton_gpu.convert_layout %120 : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #blocked2> + %118 = ttg.convert_layout %41 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %119 = ttg.convert_layout %97 : tensor<64x64xf16, #blocked6> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %120 = tt.dot %118, %119, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> + %121 = ttg.convert_layout %120 : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #blocked2> %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ ^bb0(%arg28: f32, %arg29: f32): %153 = arith.maximumf %arg28, %arg29 : f32 tt.reduce.return %153 : f32 - }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %124 = triton_gpu.convert_layout %123 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %124 = ttg.convert_layout %123 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> %125 = arith.maximumf %arg25, %124 : tensor<128xf32, #blocked1> %126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1> %127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> - %128 = triton_gpu.convert_layout %125 : tensor<128xf32, #blocked1> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %129 = tt.expand_dims %128 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> - %130 = triton_gpu.convert_layout %129 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> + %128 = ttg.convert_layout %125 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> + %129 = tt.expand_dims %128 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> + %130 = ttg.convert_layout %129 : tensor<128x1xf32, #blocked9> -> 
tensor<128x1xf32, #blocked2> %131 = tt.broadcast %130 : tensor<128x1xf32, #blocked2> -> tensor<128x64xf32, #blocked2> %132 = arith.subf %122, %131 : tensor<128x64xf32, #blocked2> %133 = tt.extern_elementwise %132 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> %134 = arith.mulf %arg24, %cst_1 : tensor<128xf32, #blocked1> %135 = arith.addf %134, %127 : tensor<128xf32, #blocked1> - %136 = triton_gpu.convert_layout %135 : tensor<128xf32, #blocked1> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %137 = tt.expand_dims %136 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> - %138 = triton_gpu.convert_layout %137 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> + %136 = ttg.convert_layout %135 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> + %137 = tt.expand_dims %136 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> + %138 = ttg.convert_layout %137 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> %139 = tt.broadcast %138 : tensor<128x1xf32, #blocked2> -> tensor<128x64xf32, #blocked2> %140 = arith.mulf %arg23, %139 : tensor<128x64xf32, #blocked2> %141 = arith.truncf %133 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> - %142 = triton_gpu.convert_layout %141 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %143 = triton_gpu.convert_layout %117 : tensor<64x64xf16, #blocked3> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %144 = triton_gpu.convert_layout %140 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked> - %145 = tt.dot %142, %143, %144 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> - %146 = triton_gpu.convert_layout %145 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked2> + %142 = ttg.convert_layout %141 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %143 = ttg.convert_layout %117 : tensor<64x64xf16, #blocked3> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %144 = ttg.convert_layout %140 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked> + %145 = tt.dot %142, %143, %144 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> + %146 = ttg.convert_layout %145 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked2> %147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1> %148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({ ^bb0(%arg28: f32, %arg29: f32): %153 = arith.addf %arg28, %arg29 : f32 tt.reduce.return %153 : f32 - }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %149 = triton_gpu.convert_layout %148 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %149 = ttg.convert_layout %148 : tensor<128xf32, #ttg.slice<{dim = 
1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> %150 = arith.addf %147, %149 : tensor<128xf32, #blocked1> %151 = arith.addi %arg26, %c64_i64 : i64 %152 = arith.addi %arg27, %c64_i64 : i64 scf.yield %146, %150, %125, %151, %152 : tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64 } - %43 = triton_gpu.convert_layout %42#1 : tensor<128xf32, #blocked1> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> - %45 = triton_gpu.convert_layout %44 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> + %43 = ttg.convert_layout %42#1 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> + %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> + %45 = ttg.convert_layout %44 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> %46 = tt.broadcast %45 : tensor<128x1xf32, #blocked2> -> tensor<128x64xf32, #blocked2> %47 = arith.divf %42#0, %46 : tensor<128x64xf32, #blocked2> %48 = arith.muli %1, %arg20 : i32 @@ -1776,25 +1776,25 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %55 = arith.extsi %arg17 : i32 to i64 %56 = arith.extsi %5 : i32 to i64 %57 = arith.truncf %47 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> - %58 = triton_gpu.convert_layout %57 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #blocked3> + %58 = ttg.convert_layout %57 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #blocked3> %59 = tt.splat %54 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked3> %60 = tt.splat %56 : i64 -> tensor<128xi64, #blocked3a> %61 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3a> %62 = arith.extsi %61 : tensor<128xi32, #blocked3a> to tensor<128xi64, #blocked3a> %63 = arith.addi %60, %62 : tensor<128xi64, #blocked3a> - %64 = triton_gpu.convert_layout %63 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4a}>> - %65 = tt.expand_dims %64 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> + %64 = ttg.convert_layout %63 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> + %65 = tt.expand_dims %64 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> %66 = tt.splat %55 : i64 -> tensor<128x1xi64, #blocked4a> %67 = arith.muli %65, %66 : tensor<128x1xi64, #blocked4a> %68 = tt.broadcast %67 : tensor<128x1xi64, #blocked4a> -> tensor<128x64xi64, #blocked4a> - %69 = triton_gpu.convert_layout %68 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> + %69 = ttg.convert_layout %68 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> %70 = tt.addptr %59, %69 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> %71 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> %72 = arith.extsi %71 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> - %73 = triton_gpu.convert_layout %72 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked6}>> - %74 = tt.expand_dims %73 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> + %73 
= ttg.convert_layout %72 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> + %74 = tt.expand_dims %73 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> %75 = tt.broadcast %74 : tensor<1x64xi64, #blocked6> -> tensor<128x64xi64, #blocked6> - %76 = triton_gpu.convert_layout %75 : tensor<128x64xi64, #blocked6> -> tensor<128x64xi64, #blocked3> + %76 = ttg.convert_layout %75 : tensor<128x64xi64, #blocked6> -> tensor<128x64xi64, #blocked3> %77 = tt.addptr %70, %76 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> tt.store %77, %58 : tensor<128x64x!tt.ptr, #blocked3> tt.return @@ -1803,37 +1803,37 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // CHECK-LABEL: axis_mismatch -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { -tt.func @axis_mismatch(%arg0: f32) -> tensor<1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @axis_mismatch(%arg0: f32) -> tensor<1xf32, #ttg.slice<{dim = 0, parent = #blocked}>> { // CHECK: %[[R:.+]] = "tt.reduce"(%0) <{axis = 1 : i32}> -// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[R]] +// CHECK: %[[C:.+]] = ttg.convert_layout %[[R]] // CHECK: tt.return %[[C]] %0 = tt.splat %arg0 : f32 -> tensor<1x16xf32, #blocked> %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ ^bb0(%arg9: f32, %arg10: f32): %60 = arith.addf %arg9, %arg10 : f32 tt.reduce.return %60 : f32 - }) : (tensor<1x16xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = triton_gpu.convert_layout %1 : tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> - %3 = triton_gpu.convert_layout %2 : tensor<1xf32, #blocked1> -> tensor<1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - tt.return %3: tensor<1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + }) : (tensor<1x16xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = ttg.convert_layout %1 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %3 = ttg.convert_layout %2 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 0, parent = #blocked}>> + tt.return %3: tensor<1xf32, #ttg.slice<{dim = 0, parent = #blocked}>> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: reduce_to_scalar -// CHECK-NOT: triton_gpu.convert_layout +// 
CHECK-NOT: ttg.convert_layout // CHECK: tt.return tt.func @reduce_to_scalar(%ptr: tensor<1024x!tt.ptr, #blocked>) -> (f32, i32) { %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> - %1 = triton_gpu.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked1> %3:2 = "tt.reduce"(%1, %2) <{axis = 0 : i32}> ({ ^bb0(%arg7: f32, %arg8: i32, %arg9: f32, %arg10: i32): @@ -1852,9 +1852,9 @@ tt.func @reduce_to_scalar(%ptr: tensor<1024x!tt.ptr, #blocked>) -> (f32, i3 // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: whileop // CHECK: %[[L:.+]] = tt.load %{{.*}} : tensor<1024x!tt.ptr, #blocked> // CHECK: %[[W:.+]] = scf.while (%[[I:.+]] = %[[L]], %{{.*}} = %{{.*}}) : (tensor<1024xf32, #blocked>, i1) -> tensor<1024xf32, #blocked> { @@ -1867,17 +1867,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : // CHECK: tt.store %{{.*}}, %[[W]] : tensor<1024x!tt.ptr, #blocked> tt.func @whileop(%ptr: tensor<1024x!tt.ptr, #blocked>, %cond: i1) { %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> - %1 = triton_gpu.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> %2 = scf.while (%arg0 = %1, %arg1 = %cond) : (tensor<1024xf32, #blocked1>, i1) -> (tensor<1024xf32, #blocked1>) { scf.condition(%arg1) %arg0 : tensor<1024xf32, #blocked1> } do { ^bb0(%arg0: tensor<1024xf32, #blocked1>): - %4 = triton_gpu.convert_layout %arg0 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> + %4 = ttg.convert_layout %arg0 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> %5 = arith.addf %4, %4 : tensor<1024xf32, #blocked> - %6 = triton_gpu.convert_layout %5 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> + %6 = ttg.convert_layout %5 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> scf.yield %6, %cond : tensor<1024xf32, #blocked1>, i1 } - %3 = triton_gpu.convert_layout %2 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> + %3 = ttg.convert_layout %2 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> tt.store %ptr, %3 : tensor<1024x!tt.ptr, #blocked> tt.return } @@ -1898,7 +1898,7 @@ tt.func @whileop(%ptr: tensor<1024x!tt.ptr, #blocked>, %cond: i1) { // Check that we don't transform this loop into `yield %x` on the incorrect // theory that the yield is dead unless %x = %y. 
-module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL @yield_outside_loop1 tt.func public @yield_outside_loop1(%arg0: i32, %arg1: i32) -> (i32) { @@ -1939,16 +1939,16 @@ tt.func public @yield_outside_loop2(%arg0: i32, %arg1: i32) -> (i32, i32) { // Check that we handle corner cases when hoisting conversions on top of extf because conversion operations on a smaller type are faster. // For complex slices we may hoist convert on top of extf while the source of extf has multiple uses in the slice. // In this case we want to make sure we don't replace other uses of extf source. -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { -// CHECK: [[$BLOCKED:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -// CHECK: [[$MMA:#.*]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK: [[$BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +// CHECK: [[$MMA:#.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> // CHECK-LABEL: @hoist_convert_above_extf_and_remat tt.func public @hoist_convert_above_extf_and_remat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr) attributes {noinline = false} 
{ @@ -1958,24 +1958,24 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %c64_i32 = arith.constant 64 : i32 %c256_i32 = arith.constant 256 : i32 %c0_i32 = arith.constant 0 : i32 - %cst_2 = arith.constant dense<1.000000e-03> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %cst_3 = arith.constant dense<2.560000e+02> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %cst_2 = arith.constant dense<1.000000e-03> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %cst_3 = arith.constant dense<2.560000e+02> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x256xf32, #blocked3> %c32_i32 = arith.constant 32 : i32 %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %4 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked> %5 = arith.addi %4, %3 : tensor<32x1xi32, #blocked> %6 = arith.muli %5, %cst : tensor<32x1xi32, #blocked> - %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %9 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %10 = tt.expand_dims %8 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %9 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %10 = tt.expand_dims %8 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> %11 = tt.broadcast %9 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> - %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> %14 = arith.muli %13, %cst_1 : tensor<256x1xi32, #blocked> %15 = tt.broadcast %10 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> %16 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> @@ -1993,29 +1993,29 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %67 = tt.load %66 : tensor<32x64x!tt.ptr, #blocked> %68 = tt.addptr %17, %65 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> 
%69 = tt.load %68 : tensor<256x64x!tt.ptr, #blocked> - %70 = triton_gpu.local_alloc %69 : (tensor<256x64xf16, #blocked>) -> !triton_gpu.memdesc<256x64xf16, #shared> - %71 = triton_gpu.memdesc_trans %70 {order=array} : !triton_gpu.memdesc<256x64xf16, #shared> -> !triton_gpu.memdesc<64x256xf16, #shared1> - %72 = triton_gpu.convert_layout %67 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> - %73 = triton_gpu.local_load %71 : !triton_gpu.memdesc<64x256xf16, #shared1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> - %74 = triton_gpu.convert_layout %arg8 : tensor<32x256xf32, #blocked3> -> tensor<32x256xf32, #mma> - %75 = triton_gpu.convert_layout %72 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %76 = triton_gpu.convert_layout %73 : tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %77 = tt.dot %75, %76, %74 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma> - %78 = triton_gpu.convert_layout %77 : tensor<32x256xf32, #mma> -> tensor<32x256xf32, #blocked3> + %70 = ttg.local_alloc %69 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared> + %71 = ttg.memdesc_trans %70 {order=array} : !ttg.memdesc<256x64xf16, #shared> -> !ttg.memdesc<64x256xf16, #shared1> + %72 = ttg.convert_layout %67 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> + %73 = ttg.local_load %71 : !ttg.memdesc<64x256xf16, #shared1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> + %74 = ttg.convert_layout %arg8 : tensor<32x256xf32, #blocked3> -> tensor<32x256xf32, #mma> + %75 = ttg.convert_layout %72 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %76 = ttg.convert_layout %73 : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %77 = tt.dot %75, %76, %74 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma> + %78 = ttg.convert_layout %77 : tensor<32x256xf32, #mma> -> tensor<32x256xf32, #blocked3> scf.yield %78 : tensor<32x256xf32, #blocked3> } %19 = arith.truncf %18 : tensor<32x256xf32, #blocked3> to tensor<32x256xf16, #blocked3> - %20 = triton_gpu.convert_layout %19 : tensor<32x256xf16, #blocked3> -> tensor<32x256xf16, #blocked2> - %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %24 = tt.expand_dims %22 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %20 = ttg.convert_layout %19 : tensor<32x256xf16, #blocked3> -> tensor<32x256xf16, #blocked2> + %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %24 = tt.expand_dims %22 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> %25 = tt.splat %arg2 : !tt.ptr -> tensor<1x256x!tt.ptr, #blocked2> %26 = tt.addptr %25, %23 : tensor<1x256x!tt.ptr, #blocked2>, tensor<1x256xi32, #blocked2> %27 = tt.load %26 : tensor<1x256x!tt.ptr, #blocked2> %28 = tt.broadcast %27 : tensor<1x256xf16, #blocked2> -> tensor<32x256xf16, #blocked2> %29 = arith.addf %20, %28 : tensor<32x256xf16, #blocked2> -// CHECK: %[[A:.+]] = triton_gpu.convert_layout {{.*}} : tensor<1x256xf16, [[$BLOCKED]]> -> tensor<1x256xf16, [[$MMA]]> +// CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<1x256xf16, [[$BLOCKED]]> -> tensor<1x256xf16, [[$MMA]]> // CHECK: %[[B:.+]] = tt.broadcast %[[A]] // CHECK: %[[C:.+]] = arith.addf %[[B:.+]], {{.*}} // CHECK: arith.extf %[[C]] : tensor<32x256xf16, [[$MMA]]> to tensor<32x256xf32, [[$MMA]]> @@ -2024,28 +2024,28 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg7: f32, %arg8: f32): %58 = arith.addf %arg7, %arg8 : f32 tt.reduce.return %58 : f32 - }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %32 = arith.divf %31, %cst_3 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %32 = arith.divf %31, %cst_3 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> %33 = arith.mulf %30, %30 : tensor<32x256xf32, #blocked2> %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ ^bb0(%arg7: f32, %arg8: f32): %58 = arith.addf %arg7, %arg8 : f32 tt.reduce.return %58 : f32 - }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %35 = arith.divf %34, %cst_3 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %36 = arith.mulf %32, %32 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %37 = arith.subf %35, %36 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %38 = math.sqrt %37 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %39 = arith.addf %38, %cst_2 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %40 = tt.expand_dims %32 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> - %41 = tt.expand_dims %39 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> + }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %35 = arith.divf %34, %cst_3 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %36 = arith.mulf %32, %32 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %37 = arith.subf %35, %36 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %38 = math.sqrt %37 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %39 = arith.addf %38, %cst_2 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %40 = tt.expand_dims %32 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> 
+ %41 = tt.expand_dims %39 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> %42 = tt.broadcast %40 : tensor<32x1xf32, #blocked2> -> tensor<32x256xf32, #blocked2> %43 = arith.subf %30, %42 : tensor<32x256xf32, #blocked2> %44 = tt.broadcast %41 : tensor<32x1xf32, #blocked2> -> tensor<32x256xf32, #blocked2> %45 = arith.divf %43, %44 : tensor<32x256xf32, #blocked2> %46 = arith.truncf %45 : tensor<32x256xf32, #blocked2> to tensor<32x256xf16, #blocked2> - %47 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %48 = tt.expand_dims %47 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %47 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %48 = tt.expand_dims %47 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> %49 = arith.muli %48, %cst_0 : tensor<32x1xi32, #blocked1> %50 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked1> %51 = arith.addi %50, %49 : tensor<32x1xi32, #blocked1> @@ -2054,7 +2054,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %54 = arith.addi %52, %53 : tensor<32x256xi32, #blocked1> %55 = tt.splat %arg5 : !tt.ptr -> tensor<32x256x!tt.ptr, #blocked1> %56 = tt.addptr %55, %54 : tensor<32x256x!tt.ptr, #blocked1>, tensor<32x256xi32, #blocked1> - %57 = triton_gpu.convert_layout %46 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #blocked1> + %57 = ttg.convert_layout %46 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #blocked1> tt.store %56, %57 : tensor<32x256x!tt.ptr, #blocked1> tt.return } @@ -2062,60 +2062,60 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @backward_reduce_multiple_results -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return - tt.func public @backward_reduce_multiple_results() -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> { + tt.func public @backward_reduce_multiple_results() -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> { %cst = arith.constant dense<0xFFF0000000000000> : tensor<1x32xf64, #blocked1> - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, 
#triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2> - %2 = triton_gpu.convert_layout %1 : tensor<1x32xi32, #blocked2> -> tensor<1x32xi32, #blocked1> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<1x32xi32, #blocked2> -> tensor<1x32xi32, #blocked1> %3:2 = "tt.reduce"(%cst, %2) <{axis = 1 : i32}> ({ ^bb0(%arg0: f64, %arg1: i32, %arg2: f64, %arg3: i32): %5 = arith.addi %arg1, %arg3 : i32 %6 = arith.addf %arg0, %arg2 : f64 tt.reduce.return %6, %5 : f64, i32 - }) : (tensor<1x32xf64, #blocked1>, tensor<1x32xi32, #blocked1>) -> (tensor<1xf64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>, tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) - %4 = triton_gpu.convert_layout %3#1 : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - tt.return %4 : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<1x32xf64, #blocked1>, tensor<1x32xi32, #blocked1>) -> (tensor<1xf64, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>) + %4 = ttg.convert_layout %3#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + tt.return %4 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> } } // end module // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @reshape_propagate tt.func public @reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked3> { - // CHECK-NOT: triton_gpu.convert_layout - %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + // CHECK-NOT: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> - %c = triton_gpu.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3> + %c = ttg.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3> tt.return %c : tensor<32xf32, #blocked3> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = 
[1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @reshape_sink_convert tt.func public @reshape_sink_convert(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked2> { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: tt.reshape - // CHECK: triton_gpu.convert_layout - %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + // CHECK: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> tt.return %b : tensor<32xf32, #blocked2> } @@ -2123,18 +2123,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @permuting_reshape_propagate tt.func public @permuting_reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf16, #blocked2> { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: arith.truncf - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout %a = tt.reshape %arg0 allow_reorder efficient_layout : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> - %b = triton_gpu.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = ttg.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = arith.truncf %b : tensor<32xf32, #blocked2> to tensor<32xf16, #blocked2> tt.return %c : tensor<32xf16, #blocked2> } @@ -2142,24 +2142,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], 
threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: scan_propagation tt.func @scan_propagation(%arg: tensor<1024xi32, #slice1dim1>) -> tensor<1024xi32, #slice1dim1> { - %1 = triton_gpu.convert_layout %arg : tensor<1024xi32, #slice1dim1> -> tensor<1024xi32, #blocked2> + %1 = ttg.convert_layout %arg : tensor<1024xi32, #slice1dim1> -> tensor<1024xi32, #blocked2> %2 = "tt.scan" (%1) ({ ^bb0(%arg3: i32, %arg4: i32): %add = arith.addi %arg3, %arg4 : i32 tt.scan.return %add : i32 }) {axis = 1 : i32, reverse = false} : (tensor<1024xi32, #blocked2>) -> tensor<1024xi32, #blocked2> - %3 = triton_gpu.convert_layout %2 : tensor<1024xi32, #blocked2> -> tensor<1024xi32, #slice1dim1> + %3 = ttg.convert_layout %2 : tensor<1024xi32, #blocked2> -> tensor<1024xi32, #slice1dim1> // don't allow non blocked layout to be propagated to scan - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.scan - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.return tt.return %3: tensor<1024xi32, #slice1dim1> } @@ -2167,22 +2167,22 @@ tt.func @scan_propagation(%arg: tensor<1024xi32, #slice1dim1>) -> tensor<1024xi3 // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: fw_propagate_for_op tt.func public @fw_propagate_for_op(%arg0: tensor<1024x4xi32, #blocked>, %arg1: tensor<1024x4x!tt.ptr, #blocked1>) { %c0_i32 = arith.constant 0 : i32 %c2_i32 = arith.constant 2 : i32 %c1_i32 = arith.constant 1 : i32 - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: arith.muli // CHECK: scf.for // CHECK: scf.yield - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.store - %0 = triton_gpu.convert_layout %arg0 : tensor<1024x4xi32, #blocked> -> tensor<1024x4xi32, #blocked1> + %0 = ttg.convert_layout %arg0 : tensor<1024x4xi32, #blocked> -> tensor<1024x4xi32, #blocked1> %1 = arith.muli %0, %0 : tensor<1024x4xi32, #blocked1> %2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %1) -> (tensor<1024x4xi32, #blocked1>) : i32 { %3 = arith.addi %arg3, %arg3 : tensor<1024x4xi32, #blocked1> @@ -2195,16 +2195,16 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // 
----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @rematerialize_through_if tt.func public @rematerialize_through_if(%arg0: i1, %arg1: f32) -> tensor<32xf32, #blocked> { // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> // CHECK: scf.if %arg0 -> (tensor<32xf32, #blocked>) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> %cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked1> %0 = tt.splat %arg1 : f32 -> tensor<32xf32, #blocked1> @@ -2215,30 +2215,30 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : %2 = arith.addf %cst_0, %0 : tensor<32xf32, #blocked1> scf.yield %2 : tensor<32xf32, #blocked1> } - %4 = triton_gpu.convert_layout %3 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %4 = ttg.convert_layout %3 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> tt.return %4 : tensor<32xf32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @rematerialize_if_inside_loop tt.func public @rematerialize_if_inside_loop() -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) { // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: %[[for:[0-9]*]]:2 = scf.for {{.*}} -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: scf.if %{{.*}} -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: scf.yield {{.*}} : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> // CHECK: scf.yield {{.*}} : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %[[for]]#1, %[[for]]#0 %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> %cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked> @@ -2251,25 +2251,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 1 : scf.yield %cst, %cst_0 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } else { %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> scf.yield %4, %6 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } scf.yield %3#0, %3#1 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } - %7 = triton_gpu.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %7 = ttg.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> tt.return %7, %1#1 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: rematerialize_loop_arg tt.func public @rematerialize_loop_arg(%arg0: !tt.ptr) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 %c128_i32 = arith.constant 128 : i32 @@ -2278,14 +2278,14 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %cst_2 = arith.constant dense<128> : tensor<128x64xi32, #blocked> %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> // CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %0) -> (tensor<128x64x!tt.ptr, #blocked>) - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: scf.yield %{{.*}} : tensor<128x64x!tt.ptr, #blocked> %1 = scf.for %arg1 = %c0_i32 to %c128_i32 step %c1_i32 iter_args(%arg2 = %0) -> (tensor<128x64x!tt.ptr, #blocked>) : i32 { %2 = tt.addptr %arg2, %cst_1 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %3 = triton_gpu.convert_layout %2 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> + %3 = ttg.convert_layout %2 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> tt.store %3, %cst_0 : tensor<128x64x!tt.ptr, #blocked1> %4 = tt.addptr %arg2, %cst_2 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %5 = triton_gpu.convert_layout %4 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> + %5 = ttg.convert_layout %4 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> tt.store %5, %cst_0 : tensor<128x64x!tt.ptr, #blocked1> scf.yield %2 : tensor<128x64x!tt.ptr, #blocked> } @@ -2296,50 +2296,50 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = 
[0]}> +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: assertop // CHECK: %[[L:.+]] = tt.load %{{.*}} : tensor<1024x!tt.ptr, #blocked> // CHECK: tt.assert %[[L]] tt.func @assertop(%ptr: tensor<1024x!tt.ptr, #blocked>) { %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> - %1 = triton_gpu.convert_layout %0 : tensor<1024xi1, #blocked> -> tensor<1024xi1, #blocked1> + %1 = ttg.convert_layout %0 : tensor<1024xi1, #blocked> -> tensor<1024xi1, #blocked1> tt.assert %1, "cond must be true " : tensor<1024xi1, #blocked1> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @warp_group_dot_wait_propagate tt.func public @warp_group_dot_wait_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<16x2xf32, #blocked> { - // CHECK-NOT: triton_gpu.convert_layout - %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = triton_nvidia_gpu.warp_group_dot_wait %a {pendings = 0 : i32} : tensor<16x2xf32, #blocked1> - %c = triton_gpu.convert_layout %b : tensor<16x2xf32, #blocked1> -> tensor<16x2xf32, #blocked> + // CHECK-NOT: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + %b = ttng.warp_group_dot_wait %a {pendings = 0 : i32} : tensor<16x2xf32, #blocked1> + %c = ttg.convert_layout %b : tensor<16x2xf32, #blocked1> -> tensor<16x2xf32, #blocked> tt.return %c : tensor<16x2xf32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2,4], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4,2], threadsPerWarp = [2,16], warpsPerCTA = [1,1], order = [0,1]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2,4], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4,2], threadsPerWarp = [2,16], warpsPerCTA = [1,1], order = [0,1]}> 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @trans_propagate tt.func public @trans_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<2x16xf32, #blocked2> { // CHECK: tt.trans - // CHECK: triton_gpu.convert_layout - %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + // CHECK: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> %b = tt.trans %a {order=array} : tensor<16x2xf32, #blocked1> -> tensor<2x16xf32, #blocked2> tt.return %b : tensor<2x16xf32, #blocked2> } @@ -2347,34 +2347,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // Verify that we don't hoist the convert on top of the broadcast. In general we should hoist the convert to reduce its cost // but because this would combine the 1st and 2nd convert and since the 1st convert is known to be a no-op this would // generate more expensive code. 
// CHECK-LABEL: @hoist_with_free_convert tt.func public @hoist_with_free_convert(%arg0: tensor<128x256xf32, #mma1>, %arg1: tensor<128x1xf32, #mma>) -> tensor<128x256xf32, #blocked> { - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.broadcast - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.return - %0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf32, #mma1> -> tensor<128x256xf32, #mma> + %0 = ttg.convert_layout %arg0 : tensor<128x256xf32, #mma1> -> tensor<128x256xf32, #mma> %1 = tt.broadcast %arg1 : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> %2 = arith.addf %0, %1 : tensor<128x256xf32, #mma> - %3 = triton_gpu.convert_layout %2 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> + %3 = ttg.convert_layout %2 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> tt.return %3 : tensor<128x256xf32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @rematerialize_loop_arg tt.func public @rematerialize_loop_arg() -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1>) { %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> @@ -2390,11 +2390,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: tt.return %[[F]]#3, %[[F]]#1, %[[F]]#2 %1:3 = scf.for %arg0 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg1 = %cst, %arg3 = %cst_0, %arg4 = %cst) -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1>) : i32 { %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> scf.yield %4, %6, %4 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1> } - %7 = triton_gpu.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %7 = ttg.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> tt.return %7, %1#1, %1#2 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1> } @@ -2402,22 +2402,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, 
"triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // Regression test: // Rematerialization of multiple loop-carried variables, where one is // rematerialized to the same layout by multiple users. // Previously this didn't interact correctly with the de-duplication mechanism. // CHECK-LABEL: @multi_rematerialize_loop_arg - tt.func public @multi_rematerialize_loop_arg(%arg0: !tt.ptr, %arg1: !tt.ptr) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + tt.func public @multi_rematerialize_loop_arg(%arg0: !tt.ptr, %arg1: !tt.ptr) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) { %c0_i32 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 %c2048_i32 = arith.constant 2048 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %cst_1 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> @@ -2425,59 +2425,59 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %2 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> - // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) - // CHECK: scf.yield {{.*}} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) + // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> // CHECK: } - // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, 
#mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %6 = tt.load %2 : tensor<64x64x!tt.ptr, #blocked2> - %7 = triton_gpu.convert_layout %1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %8 = triton_gpu.convert_layout %6 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %9 = tt.dot %7, %8, %cst_2, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %7 = ttg.convert_layout %1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %8 = ttg.convert_layout %6 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %9 = tt.dot %7, %8, %cst_2, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> %10 = tt.load %3 : tensor<128x64x!tt.ptr, #blocked> %11 = tt.load %4 : tensor<128x64x!tt.ptr, #blocked> %12 = arith.cmpi eq, %10, %11 : tensor<128x64xi8, #blocked> - %13 = triton_gpu.convert_layout %12 : tensor<128x64xi1, #blocked> -> tensor<128x64xi1, #mma> + %13 = ttg.convert_layout %12 : tensor<128x64xi1, #blocked> -> tensor<128x64xi1, #mma> %14 = arith.select %13, %9, %cst_1 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> - %15 = triton_gpu.convert_layout %14 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> + %15 = ttg.convert_layout %14 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> %16 = "tt.reduce"(%15) <{axis = 1 : i32}> ({ ^bb0(%arg6: f32, %arg7: f32): %34 = arith.maxnumf %arg6, %arg7 : f32 tt.reduce.return %34 : f32 - }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %17 = arith.maxnumf %arg5, %16 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %18 = arith.cmpf oeq, %17, %cst_0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %19 = triton_gpu.convert_layout %18 : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %20 = arith.select %18, %cst, %17 : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %21 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi1, #mma> + }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = arith.maxnumf %arg5, %16 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %18 = arith.cmpf oeq, %17, %cst_0 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %19 = ttg.convert_layout %18 : tensor<128xi1, 
#ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>> + %20 = arith.select %18, %cst, %17 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %21 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi1, #mma> %22 = tt.broadcast %21 : tensor<128x1xi1, #mma> -> tensor<128x64xi1, #mma> %23 = arith.select %22, %cst_2, %14 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> - %24 = triton_gpu.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> - %25 = arith.mulf %arg4, %cst : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %26 = triton_gpu.convert_layout %25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %27 = tt.expand_dims %26 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %24 = ttg.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> + %25 = arith.mulf %arg4, %cst : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %26 = ttg.convert_layout %25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %27 = tt.expand_dims %26 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> %28 = tt.broadcast %27 : tensor<128x1xf32, #mma> -> tensor<128x64xf32, #mma> %29 = arith.mulf %arg3, %28 : tensor<128x64xf32, #mma> - %30 = triton_gpu.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %31 = arith.mulf %arg4, %20 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %30 = ttg.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %31 = arith.mulf %arg4, %20 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %32 = "tt.reduce"(%24) <{axis = 1 : i32}> ({ ^bb0(%arg6: f32, %arg7: f32): %34 = arith.addf %arg6, %arg7 : f32 tt.reduce.return %34 : f32 - }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %33 = arith.addf %31, %32 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %29, %33, %17 : tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %33 = arith.addf %31, %32 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %29, %33, %17 : tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - tt.return %5#1, %5#2 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + tt.return %5#1, %5#2 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } } // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked7 = 
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked7 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // Regression test: // The while loop uses the result of the for loop as an argument. // When propagating the layout, we should only "forward" propagate the layout to the argument and the result of the while loop @@ -2495,25 +2495,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %74 = tt.load %1000 : tensor<256x64x!tt.ptr, #blocked2> %67:2 = scf.for %arg11 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg12 = %cst_0, %arg14 = %1001) -> (tensor<256x128xf32, #blocked1>, tensor<64x128x!tt.ptr, #blocked1>) : i32 { %76 = tt.load %arg14 : tensor<64x128x!tt.ptr, #blocked1> - %78 = triton_gpu.convert_layout %74 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> - %79 = triton_gpu.convert_layout %76 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> - %80 = triton_gpu.convert_layout %arg12 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> - %81 = tt.dot %78, %79, %80, inputPrecision = tf32 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> - %82 = triton_gpu.convert_layout %81 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> + %78 = ttg.convert_layout %74 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked7}>> + %79 = ttg.convert_layout %76 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked7}>> + %80 = ttg.convert_layout %arg12 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> + %81 = tt.dot %78, %79, %80, inputPrecision = tf32 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> + %82 = ttg.convert_layout %81 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> scf.yield %82, %arg14 : tensor<256x128xf32, #blocked1>, tensor<64x128x!tt.ptr, #blocked1> } %68:2 = scf.while (%arg11 = %67#0, %arg12 = %c1_i32) : (tensor<256x128xf32, #blocked1>, i32) -> (tensor<256x128xf32, #blocked1>, i32) { scf.condition(%c0_i1) %arg11, %arg12 : tensor<256x128xf32, #blocked1>, i32 } do { ^bb0(%arg11: tensor<256x128xf32, #blocked1>, %arg12: i32): - %80 = triton_gpu.convert_layout %1003 : tensor<256x128x!tt.ptr, #blocked1> -> tensor<256x128x!tt.ptr, #blocked1> + %80 = ttg.convert_layout %1003 : tensor<256x128x!tt.ptr, #blocked1> -> tensor<256x128x!tt.ptr, #blocked1> %81 = tt.load %80 : tensor<256x128x!tt.ptr, #blocked1> %82 = arith.addf %arg11, %81 : tensor<256x128xf32, #blocked1> %83 = arith.addi %arg12, %c1_i32 : i32 scf.yield %82, %83 : tensor<256x128xf32, #blocked1>, i32 } %69 = arith.truncf 
%68#0 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> - %71 = triton_gpu.convert_layout %69 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> + %71 = ttg.convert_layout %69 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> tt.store %1002, %71 : tensor<256x128x!tt.ptr, #blocked1> tt.return } @@ -2524,32 +2524,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Check that backward rematerialization bails out when the same tensor requires two different layouts // CHECK-LABEL: double_remat -// CHECK: %[[res:.*]] = triton_gpu.convert_layout +// CHECK: %[[res:.*]] = ttg.convert_layout // CHECK-NEXT: tt.return %[[res]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 2], order = [2, 1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:86", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 2], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:86", "ttg.threads-per-warp" = 32 : i32} { tt.func public @double_remat() -> tensor<1x256xi32, #blocked> attributes {noinline = false} { %cst = arith.constant dense<0> : tensor<1x256xi32, #blocked1> - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> - %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> - %2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> + %2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> %3 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<1x2x128xi32, #blocked2> %4 = tt.reshape %3 : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1> %5 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<2x2x64xi32, #blocked2> %6 = tt.reshape %5 : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1> %7 = arith.cmpi ne, %4, %cst : tensor<1x256xi32, #blocked1> %8 = arith.select %7, %6, %cst : tensor<1x256xi1, #blocked1>, tensor<1x256xi32, #blocked1> - %9 = triton_gpu.convert_layout %8 : tensor<1x256xi32, #blocked1> -> tensor<1x256xi32, #blocked> + %9 = ttg.convert_layout %8 : tensor<1x256xi32, #blocked1> -> 
tensor<1x256xi32, #blocked> tt.return %9 : tensor<1x256xi32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @if_condition_not_dead_inside_loop // CHECK: scf.if // CHECK-NOT: convert_layout @@ -2565,44 +2565,44 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : scf.yield %cst, %cst_0 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } else { %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> scf.yield %4, %6 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } %119 = arith.cmpi eq, %arg10, %arg0 : i32 scf.yield %3#0, %3#1, %119 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, i1 } - %7 = triton_gpu.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %7 = ttg.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> tt.return %7, %1#1 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @dot_wait tt.func public @dot_wait(%arg0: tensor<64x64xf32, #mma>, %arg1: tensor<64x128xf32, #mma1>) -> (tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1>) { - %0:2 = triton_nvidia_gpu.warp_group_dot_wait %arg0, %arg1 {pendings = 0 : i32} : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> + %0:2 = ttng.warp_group_dot_wait %arg0, %arg1 {pendings = 0 : i32} : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> tt.return %0#0, %0#1 : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> - // CHECK: %[[W:.+]]:2 = triton_nvidia_gpu.warp_group_dot_wait + // CHECK: %[[W:.+]]:2 = ttng.warp_group_dot_wait // CHECK: tt.return %[[W]]#0, %[[W]]#1 : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp 
= [1, 32, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @split_propagation // CHECK-SAME: (%[[ARG:.+]]: tensor<128x64x2xf32 // CHECK: %[[S:.+]], %{{.+}} = tt.split %[[ARG]] - // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[S]] + // CHECK: %[[C:.+]] = ttg.convert_layout %[[S]] // CHECK: tt.return %[[C]] tt.func public @split_propagation(%arg0: tensor<128x64x2xf32, #blocked>) -> tensor<128x64xf32, #blocked1> { - %0 = triton_gpu.convert_layout %arg0 : tensor<128x64x2xf32, #blocked> -> tensor<128x64x2xf32, #blocked2> + %0 = ttg.convert_layout %arg0 : tensor<128x64x2xf32, #blocked> -> tensor<128x64x2xf32, #blocked2> %outLHS, %outRHS = tt.split %0 : tensor<128x64x2xf32, #blocked2> -> tensor<128x64xf32, #blocked1> tt.return %outLHS : tensor<128x64xf32, #blocked1> } @@ -2610,14 +2610,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#CL = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#CL = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: matmul_add tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %C : !tt.ptr) { %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> @@ -2630,11 +2630,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : %100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL>) { %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> - %b = 
triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> - %t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL> + %t = ttg.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL> // CHECK: %[[T0:.*]] = tt.dot // CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma> %t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL> @@ -2644,7 +2644,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL> } - // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked + // CHECK: ttg.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr, #CL> tt.return } @@ -2658,29 +2658,29 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : // // CHECK-LABEL: small_tensor_mfma -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}> -#mma1 = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}> +#mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @small_tensor_mfma(%arg0: !tt.ptr) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %cst_2 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> %cst_3 = arith.constant dense<1.230000e+02> : tensor<32x16xf32, #mma1> - %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = 
#mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %2 = "tt.reduce" (%1) ({ ^bb0(%arg1: f32, %arg2: f32): %3 = arith.addf %arg1, %arg2 : f32 tt.reduce.return %3 : f32 - }) {axis = 1 : i32} : (tensor<32x32xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %4 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xf32, #blocked> + }) {axis = 1 : i32} : (tensor<32x32xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xf32, #blocked> %5 = tt.broadcast %4 : tensor<32x1xf32, #blocked> -> tensor<32x16xf32, #blocked> - %6 = triton_gpu.convert_layout %5 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %7 = tt.dot %cst_2, %6, %cst_3 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<32x16xf32, #mma1> + %6 = ttg.convert_layout %5 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %7 = tt.dot %cst_2, %6, %cst_3 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<32x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<32x16xf32, #mma1> %addr = tt.splat %arg0 : !tt.ptr -> tensor<32x16x!tt.ptr, #blocked> - %8 = triton_gpu.convert_layout %7 : tensor<32x16xf32, #mma1> -> tensor<32x16xf32, #blocked> + %8 = ttg.convert_layout %7 : tensor<32x16xf32, #mma1> -> tensor<32x16xf32, #blocked> tt.store %addr, %8 : tensor<32x16x!tt.ptr, #blocked> tt.return } @@ -2688,18 +2688,18 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 2, 1], order = [4, 0, 1, 2, 3]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 2, 2, 1, 1], order = [4, 0, 3, 2, 1]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 0, 1, 2, 3], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:100", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 2, 1], order = [4, 0, 1, 2, 3]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 2, 2, 1, 1], order = [4, 0, 3, 2, 1]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 0, 1, 2, 3], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: lift_convert_to_local_load 
// CHECK-NOT: convert_layout // CHECK: tt.return - tt.func public @lift_convert_to_local_load(%arg0 : !triton_gpu.memdesc<2x1x32x4x4xi8, #shared, #triton_gpu.shared_memory, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> { - %1 = triton_gpu.local_load %arg0 : !triton_gpu.memdesc<2x1x32x4x4xi8, #shared, #triton_gpu.shared_memory, mutable> -> tensor<2x1x32x4x4xi8, #blocked> + tt.func public @lift_convert_to_local_load(%arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #ttg.shared_memory, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> { + %1 = ttg.local_load %arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #ttg.shared_memory, mutable> -> tensor<2x1x32x4x4xi8, #blocked> %2 = tt.trans %1 {order = array} : tensor<2x1x32x4x4xi8, #blocked> -> tensor<2x4x32x1x4xi8, #blocked1> - %3 = triton_gpu.convert_layout %2 : tensor<2x4x32x1x4xi8, #blocked1> -> tensor<2x4x32x1x4xi8, #blocked2> + %3 = ttg.convert_layout %2 : tensor<2x4x32x1x4xi8, #blocked1> -> tensor<2x4x32x1x4xi8, #blocked2> tt.return %3 : tensor<2x4x32x1x4xi8, #blocked2> } } diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 5e244889fb..4346e1697a 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -1,27 +1,27 @@ // RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -canonicalize | FileCheck %s -#Cv2 = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#Av2k1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> -#Bv2k1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}> -#Av2k2 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> -#Bv2k2 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> -#Av2k4 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}> -#Bv2k4 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}> -#ALR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> -#BLR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#BLC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> +#Cv2 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#Av2k1 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> +#Bv2k1 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}> +#Av2k2 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> +#Bv2k2 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> +#Av2k4 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}> +#Bv2k4 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}> +#ALR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> +#BLR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#BLC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK: tt.func @push_elementwise // CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[AF8E5:.*]] = 
tt.bitcast %[[ACVT]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[BCVT:.*]] = ttg.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> // CHECK: %[[C:.*]] = tt.dot %[[AF16]], %[[BCVT]] -// CHECK-SAME: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma> +// CHECK-SAME: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma> // CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> tt.func @push_elementwise( %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -31,8 +31,8 @@ tt.func @push_elementwise( %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> %af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> - %dota = triton_gpu.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> - %dotb = triton_gpu.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %dota = ttg.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -40,7 +40,7 @@ tt.func @push_elementwise( // CHECK: tt.func @succeeds_if_arg_is_not_convert_layout // CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] // CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] // CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] // CHECK: %[[C:.*]] = tt.dot %[[AF16]] @@ -50,18 +50,18 @@ tt.func @succeeds_if_arg_is_not_convert_layout( %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> - %dotai8 = triton_gpu.convert_layout %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xi8, #Av2k4> + %dotai8 = ttg.convert_layout %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xi8, #Av2k4> %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> - %dotb = triton_gpu.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return 
%newc : tensor<16x16xf32, #Cv2> } // CHECK: tt.func @push_inline_asm_op // CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] // CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] // CHECK: %[[AF16:.*]] = tt.elementwise_inline_asm {{.*}} %[[AF8E5]] // CHECK: %[[C:.*]] = tt.dot %[[AF16]] @@ -73,7 +73,7 @@ tt.func @push_inline_asm_op( %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> %dotaf8 = tt.bitcast %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> %dota = tt.elementwise_inline_asm "{ cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; }" {constraints = "=r,r", packed_element = 2 : i32, pure = true} %dotaf8 : tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> - %dota_cvt = triton_gpu.convert_layout %dota : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %dota_cvt = ttg.convert_layout %dota : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> %newc = tt.dot %dota_cvt, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -82,23 +82,23 @@ tt.func @push_inline_asm_op( // ----- -#blockedA = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { -// CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> // CHECK: tt.func @push_convert_both_operands // CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> // CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BEXT:.*]] = 
arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @push_convert_both_operands( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -107,9 +107,9 @@ tt.func @push_convert_both_operands( %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> - %al = triton_gpu.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %bl = triton_gpu.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } @@ -117,25 +117,25 @@ tt.func @push_convert_both_operands( // ----- -#blockedA = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread 
= [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { -// CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> // CHECK: tt.func @update_kwidth_slice -// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> // CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : 
tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @update_kwidth_slice( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -146,9 +146,9 @@ tt.func @update_kwidth_slice( %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> - %al = triton_gpu.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %bl = triton_gpu.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } @@ -156,142 +156,142 @@ tt.func @update_kwidth_slice( // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> -tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !triton_gpu.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ - %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !triton_gpu.memdesc<128x64xf16, #shared1> - %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !triton_gpu.memdesc<128x64xf16, #shared1> * !triton_gpu.memdesc<64x64xf16, #shared> -> 
tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = ttg.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: ttng.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !ttg.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %A = ttg.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !ttg.memdesc<128x64xf16, #shared1> + %r = ttng.warp_group_dot %A, %arg1, %arg2 : !ttg.memdesc<128x64xf16, #shared1> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A_fp8 -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !triton_gpu.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> -tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !triton_gpu.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ - %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !triton_gpu.memdesc<128x64xf8E5M2, #shared1> - %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !triton_gpu.memdesc<128x64xf8E5M2, #shared1> * !triton_gpu.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = ttg.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: ttng.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !ttg.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> +tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !ttg.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %A = ttg.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !ttg.memdesc<128x64xf8E5M2, #shared1> + %r = ttng.warp_group_dot %A, %arg1, %arg2 : !ttg.memdesc<128x64xf8E5M2, #shared1> * !ttg.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } // ----- -#blocked = 
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @a_impl -// CHECK-NOT: %[[SELECT:.*]] = arith.select {{.*}} : tensor<128x128xi1, #triton_gpu.dot_op<{{.*}}>, tensor<128x128xf16, #triton_gpu.dot_op<{{.*}}> +// CHECK-NOT: %[[SELECT:.*]] = arith.select {{.*}} : tensor<128x128xi1, #ttg.dot_op<{{.*}}>, tensor<128x128xf16, #ttg.dot_op<{{.*}}> tt.func @a_impl(%pa: tensor<128x128x!tt.ptr, #blocked>) -> tensor<128x128xf32, #mma> { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %cst_3 = arith.constant dense<5> : tensor<128x1xi32, #blocked> %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked> %tl = tt.load %pa : tensor<128x128x!tt.ptr, #blocked> - %tr = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %te = tt.expand_dims %tr {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %tr = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %te = tt.expand_dims %tr {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> %tc = arith.cmpi slt, %te, %cst_3 : tensor<128x1xi32, #blocked> %tb = tt.broadcast %tc : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> %ts = arith.select %tb, %tl, %cst_4 : tensor<128x128xi1, #blocked>, tensor<128x128xf16, #blocked> - %conv = triton_gpu.convert_layout %ts : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %td = tt.dot %cst_0, %conv, %cst : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %conv = ttg.convert_layout %ts : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %td = tt.dot %cst_0, %conv, %cst : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> tt.return %td : tensor<128x128xf32, #mma> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = 
true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_push_elementwise // CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> -// CHECK: %[[A_DOTOP:.*]] = triton_gpu.convert_layout %[[A_BLOCK]] : tensor<128x64xbf16, #blocked> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_DOTOP]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> - tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !triton_gpu.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ +// CHECK: %[[A_DOTOP:.*]] = ttg.convert_layout %[[A_BLOCK]] : tensor<128x64xbf16, #blocked> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_DOTOP]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %a_bf16 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> %a = tt.fp_to_fp %a_bf16 : tensor<128x64xbf16, #blocked> -> tensor<128x64xf16, #blocked> - %dota = triton_gpu.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared1> - %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !triton_gpu.memdesc<128x64xf16, #shared1> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dota = ttg.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> 
-#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_push_elementwise_chained -// CHECK: %[[CST_DOTOP:.*]] = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[CST_DOTOP:.*]] = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> -// CHECK: %[[A_DOTOP:.*]] = triton_gpu.convert_layout %[[A_BLOCK]] : tensor<128x64xi8, #blocked> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: %[[A_CASTED:.*]] = arith.sitofp %[[A_DOTOP]] : tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: %[[A_SCALED:.*]] = arith.mulf %[[A_CASTED]], %[[CST_DOTOP]] : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: %[[A_NEGATED:.*]] = arith.negf %[[A_SCALED]] : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> - tt.func @mma_v3_reg_push_elementwise_chained(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !triton_gpu.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ +// CHECK: %[[A_DOTOP:.*]] = ttg.convert_layout %[[A_BLOCK]] : tensor<128x64xi8, #blocked> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = arith.sitofp %[[A_DOTOP]] : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_SCALED:.*]] = arith.mulf %[[A_CASTED]], %[[CST_DOTOP]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_NEGATED:.*]] = arith.negf %[[A_SCALED]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise_chained(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %cst = 
arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> %a_i8 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> %a_f16 = arith.sitofp %a_i8 : tensor<128x64xi8, #blocked> to tensor<128x64xf16, #blocked> %a_scaled = arith.mulf %a_f16, %cst : tensor<128x64xf16, #blocked> %a_negated = arith.negf %a_scaled : tensor<128x64xf16, #blocked> - %dota = triton_gpu.local_alloc %a_negated: (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared1> - %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !triton_gpu.memdesc<128x64xf16, #shared1> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dota = ttg.local_alloc %a_negated: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: mma_reorder_transpose -// CHECK: triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_trans -// CHECK: triton_nvidia_gpu.warp_group_dot - tt.func @mma_reorder_transpose(%t: tensor<64x128xf16, #blocked1>, %dotb: !triton_gpu.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ +// CHECK: ttg.local_alloc +// CHECK: ttg.memdesc_trans +// CHECK: ttng.warp_group_dot + tt.func @mma_reorder_transpose(%t: tensor<64x128xf16, #blocked1>, %dotb: !ttg.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %a = tt.trans %t {order = array} : tensor<64x128xf16, #blocked1> -> tensor<128x64xf16, #blocked> - %dota = triton_gpu.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared1> - %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !triton_gpu.memdesc<128x64xf16, #shared1> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dota = ttg.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1> * 
!ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: mmav2_reorder_transpose -// CHECK: triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_trans -// CHECK: triton_gpu.local_load +// CHECK: ttg.local_alloc +// CHECK: ttg.memdesc_trans +// CHECK: ttg.local_load // CHECK: tt.dot - tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %a = tt.trans %t {order = array} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked> - %cv = triton_gpu.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %cv = ttg.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index f83acb21f1..c85df2ff64 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -1,45 +1,45 @@ // RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = 
[0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: matmul_like_fence tt.func public @matmul_like_fence(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked2>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !triton_gpu.memdesc<128x128xf16, #shared> - %1 = triton_gpu.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !triton_gpu.memdesc<128x64xf16, #shared1> - // CHECK: triton_nvidia_gpu.fence_async_shared - %2 = triton_nvidia_gpu.warp_group_dot %0, %1, %cst : !triton_gpu.memdesc<128x128xf16, #shared> * !triton_gpu.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared> + %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1> + // CHECK: ttng.fence_async_shared + %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared> * !ttg.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: fence_outside_loop tt.func public @fence_outside_loop(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked>) { %cst = arith.constant 
dense<0.000000e+00> : tensor<128x64xf32, #mma> %c64_i32 = arith.constant 64 : i32 %c0_i32 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !triton_gpu.memdesc<128x128xf16, #shared> - %1 = triton_gpu.local_alloc %arg1 : (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared1> - // CHECK: triton_nvidia_gpu.fence_async_shared + %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared> + %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1> + // CHECK: ttng.fence_async_shared // CHECK: scf.for - // CHECK-NOT: triton_nvidia_gpu.fence_async_shared - // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NOT: ttng.fence_async_shared + // CHECK: ttng.warp_group_dot scf.for %iv0 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { scf.for %iv1 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { - %2 = triton_nvidia_gpu.warp_group_dot %0, %1, %cst : !triton_gpu.memdesc<128x128xf16, #shared> * !triton_gpu.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared> * !ttg.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> } } tt.return diff --git a/test/TritonGPU/global_scratch_alloc.mlir b/test/TritonGPU/global_scratch_alloc.mlir index a715b30d61..1c4d5bb2ef 100644 --- a/test/TritonGPU/global_scratch_alloc.mlir +++ b/test/TritonGPU/global_scratch_alloc.mlir @@ -1,33 +1,33 @@ // RUN: triton-opt %s -split-input-file --tritongpu-global-scratch-memory-allocation | FileCheck %s -// CHECK: module attributes {triton_gpu.global_scratch_memory_alignment = 128 : i32, triton_gpu.global_scratch_memory_size = 256 : i32{{.*}}} -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { -// CHECK: @test_alloc{{.*}}triton_gpu.global_scratch_memory_alignment = 128 : i32, triton_gpu.global_scratch_memory_size = 256 : i32 +// CHECK: module attributes {ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32{{.*}}} +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK: @test_alloc{{.*}}ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32 tt.func public @test_alloc() -> (!tt.ptr, !tt.ptr) { - // CHECK: triton_gpu.global_scratch_memory_offset = 0 - %0 = triton_gpu.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr - // CHECK: triton_gpu.global_scratch_memory_offset = 128 - %1 = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr + // CHECK: ttg.global_scratch_memory_offset = 0 + %0 = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr + // CHECK: ttg.global_scratch_memory_offset = 128 + %1 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr tt.return %0, %1 : !tt.ptr, !tt.ptr } } // ----- -// CHECK: module attributes {triton_gpu.global_scratch_memory_alignment = 128 : i32, triton_gpu.global_scratch_memory_size = 256 : i32{{.*}}} -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { -// CHECK: @helper1{{.*}}triton_gpu.global_scratch_memory_alignment = 128 : i32, triton_gpu.global_scratch_memory_size = 
128 : i32 +// CHECK: module attributes {ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32{{.*}}} +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK: @helper1{{.*}}ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 128 : i32 tt.func private @helper1() -> (!tt.ptr) { - // CHECK: triton_gpu.global_scratch_memory_offset = 0 - %0 = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr + // CHECK: ttg.global_scratch_memory_offset = 0 + %0 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr tt.return %0 : !tt.ptr } -// CHECK: @test_function{{.*}}triton_gpu.global_scratch_memory_alignment = 128 : i32, triton_gpu.global_scratch_memory_size = 256 : i32 +// CHECK: @test_function{{.*}}ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32 tt.func public @test_function() -> (!tt.ptr, !tt.ptr) { - // CHECK: triton_gpu.global_scratch_memory_offset = 0 - %0 = triton_gpu.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr - // CHECK: triton_gpu.global_scratch_memory_offset = 128 + // CHECK: ttg.global_scratch_memory_offset = 0 + %0 = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr + // CHECK: ttg.global_scratch_memory_offset = 128 %1 = tt.call @helper1() : () -> (!tt.ptr) tt.return %0, %1 : !tt.ptr, !tt.ptr } diff --git a/test/TritonGPU/invalid-attributes.mlir b/test/TritonGPU/invalid-attributes.mlir index 26a7c0773b..8c90b013cc 100644 --- a/test/TritonGPU/invalid-attributes.mlir +++ b/test/TritonGPU/invalid-attributes.mlir @@ -1,78 +1,78 @@ // RUN: triton-opt %s -split-input-file -verify-diagnostics -// expected-error@+2 {{triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: 2}} -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked, kWidth = 2}> +// expected-error@+2 {{ttg.dot_op opIdx paramenter can be 0 or 1, got: 2}} +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#dot_op = #ttg.dot_op<{opIdx = 2, parent = #blocked, kWidth = 2}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is not supported when the parent is a blocked layout}} -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #blocked, kWidth = 8}> +// expected-error@+2 {{ttg.dot_op kWidth parameter is not supported when the parent is a blocked layout}} +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #blocked, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere or Hopper MMA parent}} -#mma = #triton_gpu.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// expected-error@+2 {{ttg.dot_op kWidth parameter can only be non-zero for Ampere or Hopper MMA parent}} +#mma = #ttg.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], 
instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> +// expected-error@+2 {{ttg.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 0, parent = #mma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> +// expected-error@+2 {{ttg.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #ttg.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 0, parent = #mma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op opIdx parameter must be 0 for Hopper MMA parent, since Hopper WGMMA only allows first operand to be in registers}} -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// expected-error@+2 {{ttg.dot_op opIdx parameter must be 0 for Hopper MMA parent, since Hopper WGMMA only allows first operand to be in registers}} +#mma = #ttg.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for MFMA parent}} -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mfma}> +// expected-error@+2 {{ttg.dot_op kWidth parameter is mandatory for MFMA parent}} +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #mfma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} -#wmma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma}> +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} -#wmma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 8}> +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth 
= 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} -#wmma = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 16}> +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 16}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} -#wmma = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 4}> +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 4}> // ----- // expected-error@+1 {{major version must be in the [0, 3] range}} -#mfma = #triton_gpu.amd_mfma<{versionMajor = 10, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> +#mfma = #ttg.amd_mfma<{versionMajor = 10, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> // ----- // expected-error@+1 {{minor version must be 0}} -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 5, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 5, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> // ----- // expected-error@+1 {{(M, N) cases other than (32, 32) or (16, 16) unimplemented}} -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [16, 8], isTransposed = false}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [16, 8], isTransposed = false}> diff --git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index 5a91a3cc0c..0a494006af 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -1,54 +1,54 @@ // RUN: triton-opt --split-input-file %s --verify-diagnostics -tt.func public @subview_element_ty(%arg0: !triton_gpu.memdesc<8x16xf32>) { +tt.func public @subview_element_ty(%arg0: !ttg.memdesc<8x16xf32>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{element type}} - %a = triton_gpu.memdesc_subview %arg0[%zero, %zero] : !triton_gpu.memdesc<8x16xf32> -> !triton_gpu.memdesc<8x16xf16> + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<8x16xf16> tt.return } // ----- -tt.func public @too_many_offsets(%arg0: !triton_gpu.memdesc<8x16xf32>) { +tt.func public @too_many_offsets(%arg0: !ttg.memdesc<8x16xf32>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{offsets}} - %a = triton_gpu.memdesc_subview %arg0[%zero, %zero, %zero] : !triton_gpu.memdesc<8x16xf32> -> !triton_gpu.memdesc + %a = ttg.memdesc_subview %arg0[%zero, %zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc tt.return } // ----- -tt.func public @too_few_offsets(%arg0: !triton_gpu.memdesc<8x16xf32>) { +tt.func public @too_few_offsets(%arg0: !ttg.memdesc<8x16xf32>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{offsets}} - %a = triton_gpu.memdesc_subview %arg0[%zero] : !triton_gpu.memdesc<8x16xf32> -> !triton_gpu.memdesc + %a = ttg.memdesc_subview %arg0[%zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc tt.return } // ----- -tt.func public 
@result_rank_too_large(%arg0: !triton_gpu.memdesc<8x16xf32>) { +tt.func public @result_rank_too_large(%arg0: !ttg.memdesc<8x16xf32>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{result rank}} - %a = triton_gpu.memdesc_subview %arg0[%zero, %zero] : !triton_gpu.memdesc<8x16xf32> -> !triton_gpu.memdesc<3x8x16xf32> + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<3x8x16xf32> tt.return } // ----- -tt.func public @result_dim_too_large(%arg0: !triton_gpu.memdesc<8x16xf32>) { +tt.func public @result_dim_too_large(%arg0: !ttg.memdesc<8x16xf32>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{result shape}} - %a = triton_gpu.memdesc_subview %arg0[%zero, %zero] : !triton_gpu.memdesc<8x16xf32> -> !triton_gpu.memdesc<32xf32> + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<32xf32> tt.return } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{element types of operands A and B must have same bit width}} %D = tt.dot %A, %B, %C : tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> @@ -58,10 +58,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching encoding between A and B operands}} %D = tt.dot %A, %B, %C : tensor<16x16xf16> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> @@ -71,10 +71,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32>) { // expected-error@+1 {{miss encoding of C operand}} %D = tt.dot %A, %B, %C : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32> @@ -84,10 +84,10 @@ module 
attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching kWidth between A and B operands}} %D = tt.dot %A, %B, %C : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> diff --git a/test/TritonGPU/loop-pipeline-cuda.mlir b/test/TritonGPU/loop-pipeline-cuda.mlir index fe8f45e92f..4842f7ed28 100644 --- a/test/TritonGPU/loop-pipeline-cuda.mlir +++ b/test/TritonGPU/loop-pipeline-cuda.mlir @@ -1,11 +1,11 @@ // RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tt.func @load_two_users tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> @@ -20,39 +20,39 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : 
tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.for // CHECK: tt.dot // CHECK: tt.dot - // CHECK: triton_gpu.async_copy_global_to_local - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield - // CHECK: triton_gpu.async_wait {num = 0 : i32} + // CHECK: ttg.async_wait {num = 0 : i32} %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = ttg.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> - %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> - %25 = triton_gpu.memdesc_trans %24 {order=array} : !triton_gpu.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<16x64xf16, #shared1, 
#triton_gpu.shared_memory> - %26 = triton_gpu.local_load %25 : !triton_gpu.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -61,14 +61,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> +// CHECK-NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> %c64_i32 = arith.constant 64 : i32 @@ -78,10 +78,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %1 = 
arith.muli %0, %c64_i32 : i32 %2 = tt.get_program_id y : i32 %3 = tt.load %arg3 : !tt.ptr - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> @@ -92,10 +92,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %15 = arith.extsi %14 : i32 to i64 %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> - %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> @@ -105,8 +105,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> - %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> %35 
= arith.addi %33, %34 : tensor<32x1xi64, #blocked1> @@ -117,10 +117,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> - %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> - %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> @@ -139,15 +139,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> - %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !triton_gpu.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> - %73 = triton_gpu.memdesc_trans %72 {order=array} : !triton_gpu.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> - %74 = triton_gpu.local_load %73 : !triton_gpu.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %71 = ttg.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #ttg.shared_memory> + %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #ttg.shared_memory> -> !ttg.memdesc<64x32xf32, #shared1, #ttg.shared_memory> + %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #ttg.shared_memory> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> 
%76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> - %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %77 = ttg.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = ttg.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> scf.yield %79 : tensor<64x32xf32, #mma> } %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> @@ -155,7 +155,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> - %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + %69 = ttg.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> tt.return } @@ -163,21 +163,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @matmul_tma -// CHECK-DAG: triton_gpu.local_alloc : () -> !triton_gpu.memdesc<3x128x64xf16, #{{.+}}, #triton_gpu.shared_memory, mutable> -// CHECK-DAG: triton_gpu.local_alloc : () -> !triton_gpu.memdesc<3x64x256xf16, #{{.+}}, #triton_gpu.shared_memory, mutable> -// CHECK-DAG: triton_gpu.local_alloc : () -> !triton_gpu.memdesc<3xi64, #{{.+}}, #triton_gpu.shared_memory, mutable> -// CHECK-COUNT-3: triton_nvidia_gpu.init_barrier -// CHECK-COUNT-4: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #{{.+}}, #ttg.shared_memory, mutable> +// CHECK-DAG: 
ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #{{.+}}, #ttg.shared_memory, mutable> +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3xi64, #{{.+}}, #ttg.shared_memory, mutable> +// CHECK-COUNT-3: ttng.init_barrier +// CHECK-COUNT-4: ttng.async_tma_copy_global_to_local // CHECK: scf.for -// CHECK: triton_nvidia_gpu.wait_barrier -// CHECK-NOT: triton_nvidia_gpu.wait_barrier -// CHECK-COUNT-2: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK: ttng.wait_barrier +// CHECK-NOT: ttng.wait_barrier +// CHECK-COUNT-2: ttng.async_tma_copy_global_to_local // CHECK: scf.yield tt.func public @matmul_tma(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) -> tensor<128x256xf32, #mma> { %c256_i32 = arith.constant 256 : i32 @@ -187,10 +187,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> - %4 = triton_gpu.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !triton_gpu.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> - %5 = triton_nvidia_gpu.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + %4 = ttg.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> + %5 = ttng.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> -> tensor<128x256xf32, #mma> %6 = arith.addi %arg5, %c64_i32 : i32 scf.yield %5, %6 : tensor<128x256xf32, #mma>, i32 } diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 641ff165d3..69868cf50c 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -1,11 +1,11 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = 
#ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @load_two_users tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> @@ -20,37 +20,37 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK: triton_gpu.local_store + // CHECK: ttg.local_store // CHECK: scf.for // CHECK: tt.load // CHECK: tt.dot // CHECK: tt.dot - // CHECK: triton_gpu.local_store + // CHECK: ttg.local_store // CHECK: scf.yield %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, 
#triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = ttg.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> - %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> - %25 = triton_gpu.memdesc_trans %24 {order=array} : !triton_gpu.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> - %26 = triton_gpu.local_load %25 : !triton_gpu.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory, mutable> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory, mutable> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -60,14 +60,14 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 // ----- // CHECK-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de -// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> +// CHECK-NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset 
= false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> %c64_i32 = arith.constant 64 : i32 @@ -77,10 +77,10 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %1 = arith.muli %0, %c64_i32 : i32 %2 = tt.get_program_id y : i32 %3 = tt.load %arg3 : !tt.ptr - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> @@ -91,10 +91,10 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %15 = arith.extsi %14 : i32 to i64 %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> - %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, 
#ttg.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> @@ -104,8 +104,8 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> - %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> @@ -116,10 +116,10 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> - %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> - %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> @@ -138,15 +138,15 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> - %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> 
tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !triton_gpu.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> - %73 = triton_gpu.memdesc_trans %72 {order=array} : !triton_gpu.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> - %74 = triton_gpu.local_load %73 : !triton_gpu.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %71 = ttg.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #ttg.shared_memory, mutable> + %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x32xf32, #shared1, #ttg.shared_memory, mutable> + %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #ttg.shared_memory, mutable> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> - %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %77 = ttg.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = ttg.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> scf.yield %79 : tensor<64x32xf32, #mma> } %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> @@ -154,7 +154,7 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> - %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + %69 = ttg.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> tt.return } @@ -172,8 +172,8 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 // CHECK: gpu.barrier // CHECK: tt.store -#blocked = 
#triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @add_barrier_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -201,16 +201,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1] -// CHECK: #triton_gpu.shared<{{.*}} order = [2, 1, 0] -// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1] +// CHECK-NOT: #ttg.shared<{{.*}} order = [2, 0, 1] +// CHECK: #ttg.shared<{{.*}} order = [2, 1, 0] +// CHECK-NOT: #ttg.shared<{{.*}} order = [2, 0, 1] // CHECK-LABEL: tt.func public @slowest_dim_is_batch -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { tt.func public @slowest_dim_is_batch(%arg0: tensor<1x512x!tt.ptr, #blocked2>, %arg1: tensor<64x8x32x!tt.ptr, #blocked1>, %arg2: tensor<64x1x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x1x32xf32, #blocked> %cst_0 = arith.constant dense<512> : tensor<1x512xi32, #blocked2> @@ -222,9 +222,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %39 = tt.load %arg9 : tensor<1x512x!tt.ptr, #blocked2> %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr, #blocked1> %41 = tt.reshape %39 allow_reorder : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5> - %43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, 
#triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked> + %43 = ttg.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %44 = ttg.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked> %46 = tt.addptr %arg9, %cst_0 : tensor<1x512x!tt.ptr, #blocked2>, tensor<1x512xi32, #blocked2> %47 = tt.addptr %arg10, %cst_1 : tensor<64x8x32x!tt.ptr, #blocked1>, tensor<64x8x32xi32, #blocked1> scf.yield %45, %46, %47 : tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr, #blocked2>, tensor<64x8x32x!tt.ptr, #blocked1> @@ -238,25 +238,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Check that the stream pipeliner updates the resulting memory layout of transpose ops to mutable if immutable local buffers are replaced // CHECK-LABEL: loop_with_dot_and_transpose -// CHECK: triton_gpu.local_alloc {{.*}}, mutable> -// CHECK: triton_gpu.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable> +// CHECK: ttg.local_alloc {{.*}}, mutable> +// CHECK: ttg.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable> -#blocked = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1201", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1201", "ttg.threads-per-warp" = 32 : i32} { tt.func public @loop_with_dot_and_transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i32, %arg4: tensor<32x32x!tt.ptr, #blocked1>, %arg5: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> %0 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %cst) -> (tensor<32x32xf32, #blocked>) : i32 { %2 = tt.load %arg4 : tensor<32x32x!tt.ptr, #blocked1> - %3 = triton_gpu.local_alloc %2 : (tensor<32x32xf32, #blocked1>) -> !triton_gpu.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory> - %4 = triton_gpu.memdesc_trans %3 {order = array} : !triton_gpu.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory> - %5 = triton_gpu.local_load %4 : !triton_gpu.memdesc<32x32xf32, #shared1, 
#triton_gpu.shared_memory> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %6 = triton_gpu.convert_layout %2 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %7 = tt.dot %6, %5, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf32, #blocked> + %3 = ttg.local_alloc %2 : (tensor<32x32xf32, #blocked1>) -> !ttg.memdesc<32x32xf32, #shared, #ttg.shared_memory> + %4 = ttg.memdesc_trans %3 {order = array} : !ttg.memdesc<32x32xf32, #shared, #ttg.shared_memory> -> !ttg.memdesc<32x32xf32, #shared1, #ttg.shared_memory> + %5 = ttg.local_load %4 : !ttg.memdesc<32x32xf32, #shared1, #ttg.shared_memory> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %6 = ttg.convert_layout %2 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %7 = tt.dot %6, %5, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf32, #blocked> scf.yield %7 : tensor<32x32xf32, #blocked> } tt.store %arg5, %0 : tensor<32x32x!tt.ptr, #blocked> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index d358be4d97..776caf099b 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -3,57 +3,57 @@ // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> // CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !triton_gpu.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> 
<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> +// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #ttg.shared_memory, mutable> // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[BSUB:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> +// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #ttg.shared_memory, mutable> // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[ASUB1:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_A1:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] +// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_A1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[BSUB1:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B1:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] -// CHECK-DAG: %[[A0:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.local_load %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi 
%[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: %[[ASUB3:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.async_copy_global_to_local {{.*}}, %[[ASUB3]] -// CHECK: %[[BSUB3:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.async_copy_global_to_local {{.*}}, %[[BSUB3]] +// CHECK: %[[ASUB3:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[ASUB3]] +// CHECK: %[[BSUB3:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[BSUB3]] // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_A:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[NEXT_A:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) { @@ -82,9 +82,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -98,61 +98,61 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // ----- -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, 
kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> // CHECK-LABEL: tt.func @matmul_loop_nested // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK: scf.for -// CHECK-DAG: %[[A0:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.local_load %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: 
triton_gpu.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_A:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]] -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[NEXT_A:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]] +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] -// CHECK: triton_gpu.async_wait {num = 0 : i32} -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK scf.yield -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) { @@ -181,9 +181,9 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> 
@@ -198,39 +198,39 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // ----- -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> // CHECK-LABEL: tt.func @matmul_loop_single_pipeline // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_b0_dot_op:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK-DAG: 
%[[NEXT_B:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) { @@ -251,7 +251,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_mask = arith.constant dense : tensor<32x128xi1, #BL> %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> @@ -261,7 +261,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> @@ -275,48 +275,48 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // TODO: MCast is not supported yet //// 4 warps, TMA Load //// matmul: 128x32 @ 32x128 -> 128x128 -//#C = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [4, 1]}> -//#SA = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset=true}> -//#SB = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset=true}> -//#BA = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -//#BB = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1]}> +//#C = #ttg.nvidia_mma<{versionMajor = 3, warpsPerCTA = [4, 1]}> +//#SA = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset=true}> +//#SB = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset=true}> +//#BA = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +//#BB = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1]}> //// C-HECK: func @matmul_loop //// C-HECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 //// C-HECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 //// C-HECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 //// C-HECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 -//// C-HECK: %[[MBARRIER_AB:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} -//// C-HECK: %[[EMPTY_BARRIER_B:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 2 : i32} -//// C-HECK: %[[ABUFFER:.*]] = triton_gpu.alloc -//// C-HECK: 
%[[MBARRIER_AB0:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][%c0_i32] -//// C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[MBARRIER_AB0]] -//// C-HECK: %[[A0BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] -//// C-HECK: %[[BBUFFER:.*]] = triton_gpu.alloc -//// C-HECK: %[[EMPTY_BARRIER_B0:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][%c0_i32] -//// C-HECK: triton_nvidia_gpu.mbarrier_wait %[[EMPTY_BARRIER_B0]], %true -//// C-HECK: %[[B0BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] -//// C-HECK: %[[MBARRIER_AB1:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][%c1_i32] -//// C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[MBARRIER_AB1]] -//// C-HECK: %[[A1BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] -//// C-HECK: %[[B1BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] -//// C-HECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -//// C-HECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] +//// C-HECK: %[[MBARRIER_AB:.*]] = ttng.alloc_mbarrier {count = 1 : i32} +//// C-HECK: %[[EMPTY_BARRIER_B:.*]] = ttng.alloc_mbarrier {count = 2 : i32} +//// C-HECK: %[[ABUFFER:.*]] = ttg.alloc +//// C-HECK: %[[MBARRIER_AB0:.*]] = ttng.extract_mbarrier %[[MBARRIER_AB]][%c0_i32] +//// C-HECK: ttng.mbarrier_arrive %[[MBARRIER_AB0]] +//// C-HECK: %[[A0BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] +//// C-HECK: %[[BBUFFER:.*]] = ttg.alloc +//// C-HECK: %[[EMPTY_BARRIER_B0:.*]] = ttng.extract_mbarrier %[[EMPTY_BARRIER_B]][%c0_i32] +//// C-HECK: ttng.mbarrier_wait %[[EMPTY_BARRIER_B0]], %true +//// C-HECK: %[[B0BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] +//// C-HECK: %[[MBARRIER_AB1:.*]] = ttng.extract_mbarrier %[[MBARRIER_AB]][%c1_i32] +//// C-HECK: ttng.mbarrier_arrive %[[MBARRIER_AB1]] +//// C-HECK: %[[A1BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] +//// C-HECK: %[[B1BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] +//// C-HECK: %[[A0:.*]] = ttg.extract_slice %[[A1BUFFER]][0, 0, 0] +//// C-HECK: %[[B0:.*]] = ttg.extract_slice %[[B1BUFFER]][0, 0, 0] //// C-HECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] -// // C-HECK: %[[MBARRIER_AB_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] -// // C-HECK: triton_nvidia_gpu.mbarrier_wait %[[MBARRIER_AB_ITER]], {{.*}} -// // C-HECK: triton_nvidia_gpu.warp_group_dot %[[arg_a0]], %[[arg_b0]], {{.*}} -// // C-HECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} -// // C-HECK: %[[EMPTY_BARRIER_B_ITER_ARRIVE:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] -// // C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[EMPTY_BARRIER_B_ITER_ARRIVE]] -// // C-HECK: %[[MBARRIER_AB_NEXT_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] -// // C-HECK: %[[NEXT_A_BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] -// // C-HECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] -// // C-HECK: %[[EMPTY_BARRIER_B_ITER_WAIT:.*]] = triton_nvidia_gpu.extract_mbarrier 
%[[EMPTY_BARRIER_B]][{{.*}}] -// // C-HECK: triton_nvidia_gpu.mbarrier_wait %[[EMPTY_BARRIER_B_ITER_WAIT]], {{.*}} -// // C-HECK: %[[NEXT_B_BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] -// // C-HECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] +// // C-HECK: %[[MBARRIER_AB_ITER:.*]] = ttng.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] +// // C-HECK: ttng.mbarrier_wait %[[MBARRIER_AB_ITER]], {{.*}} +// // C-HECK: ttng.warp_group_dot %[[arg_a0]], %[[arg_b0]], {{.*}} +// // C-HECK: ttng.warp_group_dot_wait {{.*}} +// // C-HECK: %[[EMPTY_BARRIER_B_ITER_ARRIVE:.*]] = ttng.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] +// // C-HECK: ttng.mbarrier_arrive %[[EMPTY_BARRIER_B_ITER_ARRIVE]] +// // C-HECK: %[[MBARRIER_AB_NEXT_ITER:.*]] = ttng.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] +// // C-HECK: %[[NEXT_A_BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] +// // C-HECK: %[[NEXT_A:.*]] = ttg.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] +// // C-HECK: %[[EMPTY_BARRIER_B_ITER_WAIT:.*]] = ttng.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] +// // C-HECK: ttng.mbarrier_wait %[[EMPTY_BARRIER_B_ITER_WAIT]], {{.*}} +// // C-HECK: %[[NEXT_B_BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] +// // C-HECK: %[[NEXT_B:.*]] = ttg.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] // // C-HECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} -//module attributes {"triton_gpu.num-ctas" = 2 : i32, "triton_gpu.num-warps" = 4 : i32} { +//module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { // tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // %A : !tt.ptr {tt.divisibility = 16 : i32}, // %B : !tt.ptr {tt.divisibility = 16 : i32}) -> (!tt.ptr, 1>, !tt.ptr, 1>, tensor<128x128xf32, #C>) { @@ -333,9 +333,9 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // %a = tt.load %a_tileptr : !tt.ptr, 1> // %b = tt.load %b_tileptr : !tt.ptr, 1> // -// %sa = triton_gpu.local_alloc %a : (tensor<128x32xf16, #BA>) -> !triton_gpu.memdesc<128x32xf16, #SA, #triton_gpu.shared_memory> -// %sb = triton_gpu.local_alloc %b : (tensor<32x128xf16, #BB>) -> !triton_gpu.memdesc<32x128xf16, #SB, #triton_gpu.shared_memory> -// %c = triton_nvidia_gpu.warp_group_dot %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> +// %sa = ttg.local_alloc %a : (tensor<128x32xf16, #BA>) -> !ttg.memdesc<128x32xf16, #SA, #ttg.shared_memory> +// %sb = ttg.local_alloc %b : (tensor<32x128xf16, #BB>) -> !ttg.memdesc<32x128xf16, #SB, #ttg.shared_memory> +// %c = ttng.warp_group_dot %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> // // %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> // %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr, 1> @@ -348,13 +348,13 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = 
#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: dot_chained_single_load tt.func @dot_chained_single_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x64xf32, #mma> { %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> @@ -370,36 +370,36 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : 
tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: scf.for - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK: triton_gpu.async_copy_global_to_local - // CHECK: triton_gpu.async_commit_group + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group // CHECK: scf.yield %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>) : i32 { %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %cst_2 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + %21 = ttng.warp_group_dot %19, %20, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %23 = triton_gpu.memdesc_trans %20 {order=array} : !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> - %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !triton_gpu.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %23 = ttg.memdesc_trans %20 {order=array} : !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> + %24 = ttg.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %25 = ttng.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -419,35 +419,35 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : 
tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: scf.for - // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 1 : i32} - // CHECK: triton_gpu.async_copy_global_to_local - // CHECK: triton_gpu.async_commit_group + // CHECK: ttg.async_wait {{.*}} {num = 2 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32} + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group // CHECK: scf.if - // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} // CHECK: arith.mulf // CHECK: scf.yield // CHECK: scf.yield - // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> 
-> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -465,13 +465,13 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: two_accumulator_escape tt.func @two_accumulator_escape(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) { %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> @@ -481,45 +481,45 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> 
-> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - // CHECK: %[[ALLOC1:.+]] = triton_gpu.local_alloc - // CHECK: %[[ALLOC2:.+]] = triton_gpu.local_alloc + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + // CHECK: %[[ALLOC1:.+]] = ttg.local_alloc + // CHECK: %[[ALLOC2:.+]] = ttg.local_alloc // CHECK: %[[R:.+]]:{{.+}} = scf.for - // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} - // CHECK: %[[TRANS:.+]] = triton_gpu.memdesc_trans{{.*}} : !triton_gpu.memdesc - // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} %[[TRANS]] - // CHECK: triton_nvidia_gpu.warp_group_dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} + // CHECK: %[[DOT1:.+]] = ttng.warp_group_dot{{.*}} + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} + // CHECK: %[[TRANS:.+]] = ttg.memdesc_trans{{.*}} : !ttg.memdesc + // CHECK: %[[DOT2:.+]] = ttng.warp_group_dot{{.*}} %[[TRANS]] + // CHECK: ttng.warp_group_dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} // CHECK: scf.yield - // CHECK: %{{.*}}:2 = triton_nvidia_gpu.warp_group_dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> + // CHECK: %{{.*}}:2 = ttng.warp_group_dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>) : i32 { - %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %arg6 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, 
#shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %21 = ttng.warp_group_dot %19, %20, %arg6 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %23 = triton_gpu.memdesc_trans %c {order=array} : !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !triton_gpu.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> + %25 = ttng.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -529,46 +529,46 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}> // Make sure that if one of the load dot operands is not pipelined (and therefore not double buffered) we won't use // async dot.
-module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: no_wgmma_pipeline tt.func public @no_wgmma_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 - %cst_0 = arith.constant dense<512> : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %cst_1 = arith.constant dense<512> : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %cst_0 = arith.constant dense<512> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %cst_1 = arith.constant dense<512> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> %cst_2 = arith.constant dense<512> : tensor<128x1xi32, #blocked> %cst_3 = arith.constant dense<512> : tensor<128x1xi32, #blocked1> %cst_4 = arith.constant dense<512> : tensor<64x1xi32, #blocked1> %cst_5 = arith.constant dense<32768> : tensor<64x256xi32, #blocked1> %cst_6 = arith.constant dense<64> : tensor<128x64xi32, #blocked> - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %1 = arith.remsi %0, %cst_0 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %3 = arith.remsi %2, %cst_1 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %4 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = arith.remsi %0, %cst_0 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %3 = arith.remsi %2, %cst_1 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %4 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> %5 = arith.muli %4, %cst_2 : tensor<128x1xi32, #blocked> - %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> %8 = tt.broadcast %5 : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked> %9 = tt.broadcast %7 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> %10 = arith.addi %8, %9 : tensor<128x64xi32, #blocked> %11 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> %12 = tt.addptr %11, %10 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %13 = tt.make_range {end = 64 : i32, start = 0 
: i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %14 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %14 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> %15 = arith.muli %14, %cst_4 : tensor<64x1xi32, #blocked1> - %16 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %16 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> %17 = tt.broadcast %15 : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> %18 = tt.broadcast %16 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> %19 = arith.addi %17, %18 : tensor<64x256xi32, #blocked1> @@ -577,29 +577,29 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %22:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %21) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { %35 = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked> %36 = tt.load %arg6 : tensor<64x256x!tt.ptr, #blocked1> - %37 = triton_gpu.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !triton_gpu.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> - %38 = triton_gpu.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !triton_gpu.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> - // CHECK: triton_gpu.local_alloc + %37 = ttg.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !ttg.memdesc<128x64xf8E5M2, #shared, #ttg.shared_memory> + %38 = ttg.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !ttg.memdesc<64x256xf8E5M2, #shared1, #ttg.shared_memory> + // CHECK: ttg.local_alloc // CHECK: scf.for - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait - %39 = triton_nvidia_gpu.warp_group_dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !triton_gpu.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait + %39 = ttng.warp_group_dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x64xf8E5M2, #shared, #ttg.shared_memory> * !ttg.memdesc<64x256xf8E5M2, #shared1, #ttg.shared_memory> -> tensor<128x256xf32, #mma> %40 = tt.addptr %arg5, %cst_6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> %41 = tt.addptr %arg6, %cst_5 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> scf.yield %39, %40, %41 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> } %23 = arith.truncf %22#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %25 = tt.expand_dims %24 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %25 
= tt.expand_dims %24 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> %26 = arith.muli %25, %cst_3 : tensor<128x1xi32, #blocked1> %27 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %28 = tt.addptr %27, %26 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %29 = tt.expand_dims %2 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %29 = tt.expand_dims %2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> %30 = tt.broadcast %28 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> %31 = tt.broadcast %29 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> %32 = tt.addptr %30, %31 : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> %33 = tt.fp_to_fp %23 {rounding = 1 : i32} : tensor<128x256xf16, #mma> -> tensor<128x256xf8E5M2, #mma> - %34 = triton_gpu.convert_layout %33 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked1> + %34 = ttg.convert_layout %33 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked1> tt.store %32, %34 : tensor<128x256x!tt.ptr, #blocked1> tt.return } @@ -608,13 +608,13 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- // A dot can be properly async if all its uses follow a synchronous MMAv3 dot. -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: async_following_sync tt.func @async_following_sync(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) { %cst = arith.constant dense<64> : tensor<64x16xi32, #blocked> @@ -624,7 +624,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = 
arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 @@ -643,49 +643,49 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> // CHECK: %[[LOOP:[^ :]+]]{{.*}} scf.for {{.*}} iter_args(%[[PREV_DOT2:[^ ]+]] - // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait - // CHECK: %[[DOT0:.+]] = triton_nvidia_gpu.warp_group_dot - // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait - // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.warp_group_dot - // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK-NOT: ttng.warp_group_dot_wait + // CHECK: %[[DOT0:.+]] = ttng.warp_group_dot + // CHECK-NOT: 
ttng.warp_group_dot_wait + // CHECK: %[[DOT1:.+]] = ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait // CHECK-DAG-SAME: %[[DOT0]] // CHECK-DAG-SAME: %[[DOT1]] // CHECK-DAG-SAME: %[[PREV_DOT2]] // CHECK-SAME: {pendings = 0 : i32} - // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.warp_group_dot - // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK: %[[DOT2:.+]] = ttng.warp_group_dot + // CHECK-NOT: ttng.warp_group_dot_wait // CHECK: scf.yield %[[DOT2]] - // CHECK: triton_nvidia_gpu.warp_group_dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} %17:4 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%prev_dot2 = %cst_3, %arg5 = %16, %prev_dot1 = %cst_2, %prev_dot0 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { // This one can be async. - %dot0 = triton_nvidia_gpu.warp_group_dot %19, %20, %prev_dot1 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %dot0 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> // This can't be async because its result is modified before it's yielded. - %dot1 = triton_nvidia_gpu.warp_group_dot %19, %20, %prev_dot1 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %dot1 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %dot1.1 = arith.addf %dot1, %dot1 : tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %23 = triton_gpu.memdesc_trans %c {order=array} : !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. 
%prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !triton_gpu.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %dot2 = ttng.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } @@ -695,18 +695,18 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- // Test pipelining of experimental_descriptor_store -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_store_pipeline tt.func public @tma_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { %1 = arith.divsi %arg4, %arg2 : i32 - // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} - // CHECK-NEXT: triton_gpu.local_store - // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared - // CHECK-NEXT: triton_nvidia_gpu.tensor_desc_to_tma_ptr - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_copy_local_to_global tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> } tt.return @@ -714,26 +714,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_multiple_store_pipeline tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 - // CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + // CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> // CHECK: scf.for scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { %1 = 
arith.divsi %arg4, %arg2 : i32 %2 = arith.divsi %arg2, %arg4 : i32 - // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} - // CHECK-NEXT: triton_gpu.local_store %{{.+}}, %[[ALLOC]] - // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared - // CHECK-NEXT: triton_nvidia_gpu.tensor_desc_to_tma_ptr - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] - // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} - // CHECK-NEXT: triton_gpu.local_store %{{.+}}, %[[ALLOC]] - // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared - // CHECK-NEXT: triton_nvidia_gpu.tensor_desc_to_tma_ptr - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store %{{.+}}, %[[ALLOC]] + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store %{{.+}}, %[[ALLOC]] + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> tt.experimental_descriptor_store %arg1[%2], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> } @@ -744,28 +744,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: _kernel_matmul_dependency - tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr, #blocked>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} { - %cst = arith.constant dense<0> : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + tt.func public @_kernel_matmul_dependency(%arg0: 
tensor<128x128x!tt.ptr, #blocked>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> %cst_0 = arith.constant 1.000000e+00 : f32 %c8_i32 = arith.constant 8 : i32 %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked1> - %2:4 = scf.for %arg6 = %c8_i32 to %arg3 step %c8_i32 iter_args(%arg7 = %c8_i32, %arg8 = %c8_i32, %arg9 = %cst_1, %arg10 = %arg5) -> (i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) : i32 { + %2:4 = scf.for %arg6 = %c8_i32 to %arg3 step %c8_i32 iter_args(%arg7 = %c8_i32, %arg8 = %c8_i32, %arg9 = %cst_1, %arg10 = %arg5) -> (i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) : i32 { %3 = arith.addi %arg7, %c8_i32 : i32 %4 = arith.cmpi eq, %3, %c8_i32 : i32 - %5:2 = scf.if %4 -> (i32, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) { + %5:2 = scf.if %4 -> (i32, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) { %21 = arith.addi %arg8, %c8_i32 : i32 - scf.yield %21, %arg5 : i32, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + scf.yield %21, %arg5 : i32, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> } else { - scf.yield %arg8, %arg10 : i32, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + scf.yield %arg8, %arg10 : i32, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> } %6 = arith.cmpi eq, %3, %c8_i32 : i32 %7 = scf.if %6 -> (f32) { @@ -774,16 +774,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %21 = tt.load %arg4 : !tt.ptr scf.yield %21 : f32 } - %8 = tt.splat %3 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %9 = arith.addi %8, %0 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %8 = tt.splat %3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %9 = arith.addi %8, %0 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> %11 = tt.broadcast %10 : tensor<128x1xi32, #blocked1> -> tensor<128x128xi32, #blocked1> %12 = tt.addptr %1, %11 : tensor<128x128x!tt.ptr, #blocked1>, tensor<128x128xi32, #blocked1> %13 = tt.load %arg0 : tensor<128x128x!tt.ptr, #blocked> - %14 = triton_gpu.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !triton_gpu.memdesc<128x128xf8E4M3FNUZ, #shared> + %14 = ttg.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared> %15 = tt.load %12 : tensor<128x128x!tt.ptr, #blocked1> - %16 = triton_gpu.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !triton_gpu.memdesc<128x128xf8E4M3FNUZ, #shared1> - %17 = 
triton_nvidia_gpu.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !triton_gpu.memdesc<128x128xf8E4M3FNUZ, #shared> * !triton_gpu.memdesc<128x128xf8E4M3FNUZ, #shared1> -> tensor<128x128xf32, #mma> + %16 = ttg.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1> + %17 = ttng.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FNUZ, #shared> * !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1> -> tensor<128x128xf32, #mma> %18 = tt.splat %7 : f32 -> tensor<128x128xf32, #mma> %19 = arith.mulf %17, %18 : tensor<128x128xf32, #mma> %20 = scf.if %6 -> (tensor<128x128xf32, #mma>) { @@ -791,7 +791,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : } else { scf.yield %19 : tensor<128x128xf32, #mma> } - scf.yield %3, %5#0, %20, %5#1 : i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + scf.yield %3, %5#0, %20, %5#1 : i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> } tt.return } @@ -800,13 +800,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- // Pipeline the if ops at the beginning and the end of the loop -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: dot_prologue_epilogue // COMMON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} tt.func @dot_prologue_epilogue(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { @@ -820,14 +820,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = 
#blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> @@ -852,9 +852,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : scf.yield %arg5 : tensor<64x16x!tt.ptr, #blocked> } %18 = tt.load %inc_ptr : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> scf.yield %acc_zero : tensor<128x16xf32, #mma1> @@ -872,13 +872,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // Verify that uses of the ops scheduled in particular place of the loop (like epilogue if) are correctly scheduled too.
-#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NOCANON-LABEL: pipeline_downstream_dependencies // CHECK-NOCANON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} tt.func @pipeline_downstream_dependencies(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { @@ -893,14 +893,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> 
%15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> @@ -917,9 +917,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !triton_gpu.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %if_ret:2 = scf.if %cnd -> (tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -937,12 +937,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: dot_lhs_registers tt.func @dot_lhs_registers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma> { %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> @@ -952,44 +952,44 @@ module attributes {"triton_gpu.target" 
= "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> %cst_3 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> - %cst_4 = arith.constant dense<2.0> : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_4 = arith.constant dense<2.0> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: scf.for - // CHECK: triton_gpu.local_load - // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} - // CHECK: triton_gpu.async_copy_global_to_local - // CHECK: triton_gpu.async_commit_group - // CHECK: triton_gpu.async_copy_global_to_local - // CHECK: triton_gpu.async_commit_group + // CHECK: ttg.local_load + // CHECK: ttg.async_wait {{.*}} {num = 2 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group // CHECK: scf.yield %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %8, %arg6 = %16) -> (tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x16x!tt.ptr, #blocked>) : i32 { %a_block = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked1> 
%b_block = tt.load %arg6 : tensor<64x16x!tt.ptr, #blocked> - %a_dotop = triton_gpu.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %b_smem = triton_gpu.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %21 = triton_nvidia_gpu.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !triton_gpu.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma> + %a_dotop = ttg.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %b_smem = ttg.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + %21 = ttng.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma> %25 = tt.addptr %arg5, %cst_3 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %26 = tt.addptr %arg6, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %21, %25, %26 : tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x16x!tt.ptr, #blocked> diff --git a/test/TritonGPU/loop-pipeline-indirect-load.mlir b/test/TritonGPU/loop-pipeline-indirect-load.mlir index 3fa89657c9..af260c65c8 100644 --- a/test/TritonGPU/loop-pipeline-indirect-load.mlir +++ b/test/TritonGPU/loop-pipeline-indirect-load.mlir @@ -6,11 +6,11 @@ // CHECK: async_copy_global_to_local // CHECK: async_copy_global_to_local -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @indirect_load_two_stages(%arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32) attributes {noinline = false} { %c32_i32 = arith.constant 32 : i32 %c16_i32 = arith.constant 16 : i32 @@ -22,68 +22,68 @@ module attributes {"triton_gpu.num-ctas" 
= 1 : i32, "triton_gpu.num-warps" = 2 : %7 = tt.get_program_id x : i32 %8 = arith.muli %7, %c16_i32 : i32 - %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %15 = tt.splat %8 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %18 = arith.addi %15, %10 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %15 = tt.splat %8 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %18 = arith.addi %15, %10 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> - %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> %34 = arith.extsi %arg12 : i32 to i64 %35 = arith.muli %2, %34 : i64 %36 = tt.addptr %arg2, %35 : !tt.ptr, i64 - %47 = tt.splat %arg4 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %48 = tt.addptr %47, %20 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %47 = tt.splat %arg4 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>> + %48 = tt.addptr %47, %20 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %61 = arith.extsi %59 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %63 = tt.expand_dims %61 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi64, #blocked3> + %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %61 = arith.extsi %59 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked3}>> + %63 = tt.expand_dims %61 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi64, #blocked3> - %85 = arith.extsi %22 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> to tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %85 = arith.extsi %22 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> to tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> %107 = tt.splat %36 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked3> %108 = tt.splat %34 : i64 -> tensor<32x1xi64, #blocked3> %109 = tt.broadcast %63 : tensor<1x128xi64, #blocked3> -> tensor<32x128xi64, #blocked3> %101 = tt.splat %arg5 : !tt.ptr -> tensor<16x32x!tt.ptr, #blocked1> %111:1 = scf.for %arg28 = %arg18 to %arg19 step %c32_i32 iter_args(%arg29 = %cst) -> (tensor<16x128xf32, #blocked>) : i32 { - %129 = tt.splat %arg28 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %160 = tt.addptr %48, %129 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, 
tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %161 = tt.load %160 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %162 = tt.expand_dims %161 {axis = 0 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi64, #blocked1> + %129 = tt.splat %arg28 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %160 = tt.addptr %48, %129 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %161 = tt.load %160 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>> + %162 = tt.expand_dims %161 {axis = 0 : i32} : tensor<32xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi64, #blocked1> %163 = tt.broadcast %162 : tensor<1x32xi64, #blocked1> -> tensor<16x32xi64, #blocked1> %182 = tt.addptr %101, %163 : tensor<16x32x!tt.ptr, #blocked1>, tensor<16x32xi64, #blocked1> %183 = tt.load %182 : tensor<16x32x!tt.ptr, #blocked1> %197 = arith.extsi %arg28 : i32 to i64 - %198 = tt.splat %197 : i64 -> tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %199 = arith.addi %198, %85 : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %200 = tt.expand_dims %199 {axis = 1 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi64, #blocked3> + %198 = tt.splat %197 : i64 -> tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %199 = arith.addi %198, %85 : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %200 = tt.expand_dims %199 {axis = 1 : i32} : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi64, #blocked3> %201 = arith.muli %200, %108 : tensor<32x1xi64, #blocked3> %202 = tt.broadcast %201 : tensor<32x1xi64, #blocked3> -> tensor<32x128xi64, #blocked3> %203 = arith.addi %202, %109 : tensor<32x128xi64, #blocked3> %204 = tt.addptr %107, %203 : tensor<32x128x!tt.ptr, #blocked3>, tensor<32x128xi64, #blocked3> %209 = tt.load %204 : tensor<32x128x!tt.ptr, #blocked3> - %210 = triton_gpu.convert_layout %183 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %211 = triton_gpu.convert_layout %209 : tensor<32x128xf32, #blocked3> -> tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %212 = tt.dot %210, %211, %arg29 : tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x128xf32, #blocked> + %210 = ttg.convert_layout %183 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %211 = ttg.convert_layout %209 : tensor<32x128xf32, #blocked3> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %212 = tt.dot %210, %211, %arg29 : tensor<16x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x128xf32, #blocked> scf.yield %212 : tensor<16x128xf32, #blocked> } - %112 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<16x1xi32, #blocked3> + %112 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<16x1xi32, #blocked3> %113 = tt.splat %2 : i64 -> tensor<16x1xi64, #blocked3> %114 = arith.extsi %112 : tensor<16x1xi32, #blocked3> to tensor<16x1xi64, #blocked3> %115 = arith.addi %113, %114 : tensor<16x1xi64, #blocked3> %116 = 
arith.extsi %arg17 : i32 to i64 %117 = tt.splat %116 : i64 -> tensor<16x1xi64, #blocked3> %118 = arith.muli %115, %117 : tensor<16x1xi64, #blocked3> - %119 = tt.expand_dims %59 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> + %119 = tt.expand_dims %59 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> %120 = tt.broadcast %118 : tensor<16x1xi64, #blocked3> -> tensor<16x128xi64, #blocked3> %121 = arith.extsi %119 : tensor<1x128xi32, #blocked3> to tensor<1x128xi64, #blocked3> %122 = tt.broadcast %121 : tensor<1x128xi64, #blocked3> -> tensor<16x128xi64, #blocked3> %123 = arith.addi %120, %122 : tensor<16x128xi64, #blocked3> %124 = tt.splat %arg7 : !tt.ptr -> tensor<16x128x!tt.ptr, #blocked3> %125 = tt.addptr %124, %123 : tensor<16x128x!tt.ptr, #blocked3>, tensor<16x128xi64, #blocked3> - %128 = triton_gpu.convert_layout %111#0 : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #blocked3> + %128 = ttg.convert_layout %111#0 : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #blocked3> tt.store %125, %128 : tensor<16x128x!tt.ptr, #blocked3> tt.return } diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 5d0cc41a66..ebdccd3b79 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -4,57 +4,57 @@ // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> // CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] // 
CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[BSUB:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} +// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[ASUB1:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_A1:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] +// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_A1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[BSUB1:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B1:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] -// CHECK-DAG: %[[A0:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.local_load %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[arg_b0]] // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: %[[ASUB3:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.async_copy_global_to_local {{.*}}, %[[ASUB3]] -// CHECK: %[[BSUB3:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.async_copy_global_to_local {{.*}}, %[[BSUB3]] 
+// CHECK: %[[ASUB3:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[ASUB3]] +// CHECK: %[[BSUB3:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[BSUB3]] // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_A:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[NEXT_A:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] // AMD-LABEL: tt.func @matmul_loop @@ -66,18 +66,18 @@ // AMD: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG6]], %{{.*}} // AMD: %[[ADDPTR_35:.*]] = tt.addptr %[[ARG7]], %{{.*}} // AMD: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]] -// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG10]] +// AMD: %[[LOCAL_LOAD_37:.*]] = ttg.local_load %[[ARG10]] // AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_35]] -// AMD: %[[LOCAL_LOAD_39:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_39:.*]] = ttg.local_load %[[ARG11]] // AMD: %[[MULF_40:.*]] = arith.mulf %[[LOCAL_LOAD_39]], %{{.*}} // AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[MULF_40]], %[[ARG8]] // AMD: %[[ADDI_42:.*]] = arith.addi %[[ARG9]], %{{.*}} // AMD: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}} // AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_45]] -// AMD: %[[MEMDESC_SUBVIEW_46:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_46]] +// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_45]] +// AMD: %[[MEMDESC_SUBVIEW_46:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_46]] // AMD: scf.yield %[[ADDPTR_34]], %[[ADDPTR_35]], %[[DOT_41]], %[[SELECT_44]], %[[MEMDESC_SUBVIEW_45]], %[[MEMDESC_SUBVIEW_46]] // AMD: } // AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]] @@ -87,8 +87,8 @@ // AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]] // AMD: %[[DIVSI_26:.*]] = arith.divsi %[[ADDI_25]], %[[STEP]] // AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %[[DIVSI_26]], %{{.*}} -// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#4 -// AMD: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[LOCAL_LOAD_28:.*]] = ttg.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_29:.*]] = ttg.local_load %{{.*}}#5 // AMD: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} // AMD: %[[IF_31:.*]] = scf.if %[[CMPI_27]] // AMD: %[[DOT_33:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %{{.*}}#2 @@ -97,38 +97,38 
@@ // AMD: scf.yield %{{.*}}#2 // AMD: } // AMD: %[[SELECT_32:.*]] = arith.select %[[CMPI_27]], %[[IF_31]], %{{.*}}#2 -// AMD: triton_gpu.local_dealloc %{{.*}} -// AMD: triton_gpu.local_dealloc %{{.*}} +// AMD: ttg.local_dealloc %{{.*}} +// AMD: ttg.local_dealloc %{{.*}} // Prefetch pipelining adds another stage in between global load and compute. // This stage will local_store, then local_load, creating a prefetch from shared // memory into a register buffer for compute. // // AMD_PREFETCH-LABEL: tt.func @matmul_loop -// AMD_PREFETCH: triton_gpu.local_alloc -// AMD_PREFETCH: triton_gpu.local_alloc +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: ttg.local_alloc // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: ttg.local_load // AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.dot -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: ttg.local_load // AMD_PREFETCH: scf.yield // AMD_PREFETCH: tt.dot // AMD_PREFETCH: tt.dot // AMD_PREFETCH: tt.return -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -159,9 +159,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b__ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b_ = triton_gpu.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b_ = ttg.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -177,75 +177,75 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview 
%[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK: scf.for -// CHECK-DAG: %[[A0:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.local_load %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_A:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]] -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[NEXT_A:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]] +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] -// CHECK: triton_gpu.async_wait {num = 0 : i32} -// CHECK: 
triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK scf.yield // AMD-LABEL: tt.func @matmul_loop_nested // AMD: scf.for -// AMD-COUNT-2: triton_gpu.local_alloc +// AMD-COUNT-2: ttg.local_alloc // AMD-COUNT-2: tt.load -// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] -// AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW1]] // AMD: %[[FOR:.*]]:6 = scf.for // AMD-COUNT-2: tt.addptr // AMD: tt.load -// AMD: triton_gpu.local_load +// AMD: ttg.local_load // AMD: tt.load -// AMD: triton_gpu.local_load +// AMD: ttg.local_load // AMD: tt.dot -// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] -// AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW1]] // AMD: scf.yield -// AMD-COUNT-2: triton_gpu.local_load +// AMD-COUNT-2: ttg.local_load // AMD: %[[IF1:.*]] = scf.if // AMD: %[[DOT1:.*]] = tt.dot // AMD: scf.yield %[[DOT1]] // AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[IF1]], %[[FOR]]#2 -// AMD-COUNT-2: triton_gpu.local_dealloc +// AMD-COUNT-2: ttg.local_dealloc // AMD: scf.yield %[[SEL1]] // AMD_PREFETCH-LABEL: tt.func @matmul_loop_nested @@ -279,9 +279,9 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %loop2:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> 
tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -299,62 +299,62 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_b0_dot_op:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] // AMD-LABEL: tt.func @matmul_loop_single_pipeline // AMD: %[[LOAD_10:.*]] = tt.load %{{.*}} -// AMD: %[[CONVERT_LAYOUT_11:.*]] = triton_gpu.convert_layout %[[LOAD_10]] -// AMD: %[[LOCAL_ALLOC_12:.*]] = triton_gpu.local_alloc +// AMD: %[[CONVERT_LAYOUT_11:.*]] = ttg.convert_layout %[[LOAD_10]] +// AMD: %[[LOCAL_ALLOC_12:.*]] = ttg.local_alloc // AMD: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} // AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_13]] // AMD: %[[LOAD_15:.*]] = tt.load %{{.*}}, %[[SPLAT_14]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] +// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = 
ttg.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] // AMD: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} // AMD: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %[[SUBI_17]] step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[MEMDESC_SUBVIEW_16]]) // AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG6]], %{{.*}} // AMD: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_32]] -// AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %[[ARG9]] +// AMD: %[[LOCAL_LOAD_30:.*]] = ttg.local_load %[[ARG9]] // AMD: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_30]], %[[ARG7]] // AMD: %[[ADDI_34:.*]] = arith.addi %[[ARG8]], %{{.*}} // AMD: %[[CMPI_35:.*]] = arith.cmpi slt, %[[ADDI_34]], %{{.*}} // AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_35]], %[[ADDI_34]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_36]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_33]], %[[MEMDESC_SUBVIEW_37]] +// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_36]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_33]], %[[MEMDESC_SUBVIEW_37]] // AMD: scf.yield %[[ADDPTR_32]], %[[DOT_31]], %[[SELECT_36]], %[[MEMDESC_SUBVIEW_37]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_12]] +// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_12]] // AMD_PREFETCH-LABEL: tt.func @matmul_loop_single_pipeline -// AMD_PREFETCH: triton_gpu.local_alloc +// AMD_PREFETCH: ttg.local_alloc // AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: ttg.local_store // AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: ttg.local_load // AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: ttg.local_store // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.dot -// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: ttg.local_load // AMD_PREFETCH: scf.yield // AMD_PREFETCH: tt.dot // AMD_PREFETCH: tt.dot @@ -380,7 +380,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_mask = arith.constant dense : tensor<32x128xi1, #BL> %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> @@ -390,7 +390,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> @@ -399,23 +399,23 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, } // CHECK-LABEL: tt.func @indirect_bmm_scalar -// CHECK: 
triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_commit_group +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] // CHECK: %[[IND_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[IND_BUFFER_1:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_0]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.splat %[[IND_BUFFER_1]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]] -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] -// CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: ttg.async_wait {{.*}} {num = 2 : i32} // AMD-LABEL: tt.func @indirect_bmm_scalar -// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc -// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc // AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} // AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] @@ -426,10 +426,10 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] // AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_12:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_12]] -// AMD: %[[MEMDESC_SUBVIEW_13:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_13]] +// AMD: %[[MEMDESC_SUBVIEW_12:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_12]] +// AMD: %[[MEMDESC_SUBVIEW_13:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_13]] // AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %{{.*}} // AMD: %[[ADDPTR_15:.*]] = tt.addptr %{{.*}}, %{{.*}} // AMD: %[[SPLAT_16:.*]] = tt.splat %[[CMPI_11]] @@ -445,20 +445,20 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // AMD: %[[ADDI_41:.*]] = arith.addi %[[ARG10]], %{{.*}} // AMD: %[[CMPI_42:.*]] = arith.cmpi slt, %[[ADDI_41]], %{{.*}} // AMD: %[[SELECT_43:.*]] = arith.select %[[CMPI_42]], %[[ADDI_41]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_44:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_43]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[ARG11]], %[[MEMDESC_SUBVIEW_44]] -// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_43]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_45]] +// AMD: %[[MEMDESC_SUBVIEW_44:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_43]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[ARG11]], %[[MEMDESC_SUBVIEW_44]] +// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = ttg.memdesc_subview 
%[[LOCAL_ALLOC_1]][%[[SELECT_43]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_45]] // AMD: %[[ADDPTR_46:.*]] = tt.addptr %[[ARG8]], %{{.*}} // AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG9]], %{{.*}} // AMD: %[[LOAD_48:.*]] = tt.load %[[ADDPTR_46]] -// AMD: %[[LOCAL_LOAD_49:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[LOCAL_LOAD_49:.*]] = ttg.local_load %[[ARG13]] // AMD: %[[LOAD_50:.*]] = tt.load %[[ADDPTR_47]] // AMD: %[[MULI_51:.*]] = arith.muli %{{.*}}, %[[LOAD_50]] // AMD: %[[SPLAT_52:.*]] = tt.splat %[[MULI_51]] // AMD: %[[ADDPTR_53:.*]] = tt.addptr %{{.*}}, %[[SPLAT_52]] // AMD: %[[LOAD_54:.*]] = tt.load %[[ADDPTR_53]] -// AMD: %[[LOCAL_LOAD_55:.*]] = triton_gpu.local_load %[[ARG14]] +// AMD: %[[LOCAL_LOAD_55:.*]] = ttg.local_load %[[ARG14]] // AMD: %[[DOT_56:.*]] = tt.dot %[[LOCAL_LOAD_49]], %[[LOCAL_LOAD_55]], %[[ARG7]] // AMD: scf.yield %[[DOT_56]], %[[ADDPTR_46]], %[[ADDPTR_47]], %[[SELECT_43]], %[[LOAD_48]], %[[LOAD_54]], %[[MEMDESC_SUBVIEW_44]], %[[MEMDESC_SUBVIEW_45]] // AMD: } @@ -467,12 +467,12 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // AMD: %[[ADDI_28:.*]] = arith.addi %{{.*}}#3, %{{.*}} // AMD: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} // AMD: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_31:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_30]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %{{.*}}#4, %[[MEMDESC_SUBVIEW_31]] -// AMD: %[[MEMDESC_SUBVIEW_32:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_30]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %{{.*}}#5, %[[MEMDESC_SUBVIEW_32]] -// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %{{.*}}#6 -// AMD: %[[LOCAL_LOAD_34:.*]] = triton_gpu.local_load %{{.*}}#7 +// AMD: %[[MEMDESC_SUBVIEW_31:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_30]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %{{.*}}#4, %[[MEMDESC_SUBVIEW_31]] +// AMD: %[[MEMDESC_SUBVIEW_32:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_30]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %{{.*}}#5, %[[MEMDESC_SUBVIEW_32]] +// AMD: %[[LOCAL_LOAD_33:.*]] = ttg.local_load %{{.*}}#6 +// AMD: %[[LOCAL_LOAD_34:.*]] = ttg.local_load %{{.*}}#7 // AMD: %[[IF_35:.*]] = scf.if %[[CMPI_26]] // AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_33]], %[[LOCAL_LOAD_34]], %{{.*}}#0 // AMD: scf.yield %[[DOT_41]] @@ -480,8 +480,8 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // AMD: scf.yield %{{.*}}#0 // AMD: } // AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_26]], %[[IF_35]], %{{.*}}#0 -// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_31]] -// AMD: %[[LOCAL_LOAD_38:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_32]] +// AMD: %[[LOCAL_LOAD_37:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_31]] +// AMD: %[[LOCAL_LOAD_38:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_32]] // AMD: %[[IF_39:.*]] = scf.if %[[CMPI_27]] // AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[LOCAL_LOAD_38]], %[[SELECT_36]] // AMD: scf.yield %[[DOT_41]] @@ -489,30 +489,30 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // AMD: scf.yield %[[SELECT_36]] // AMD: } // AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_27]], %[[IF_39]], %[[SELECT_36]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] +// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_0]] +// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_1]] // 
AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar -// AMD_PREFETCH: triton_gpu.local_alloc -// AMD_PREFETCH: triton_gpu.local_alloc +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: ttg.local_alloc // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: ttg.local_load // AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.dot -// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: ttg.local_load // AMD_PREFETCH: scf.yield // AMD_PREFETCH: tt.dot // AMD_PREFETCH: tt.dot @@ -538,8 +538,8 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %85 = tt.splat %84 : i64 -> tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 @@ -549,20 +549,20 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, } // CHECK-LABEL: tt.func @indirect_bmm_scalar_dist_one -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_commit_group +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group // CHECK: scf.for %{{.*}} iter_args(%{{[^,]*}}, %{{[^,]*}}, %{{[^,]*}}, %[[IND_BUFFER_PREV:[^,]*]] = {{[^,]*}} // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] // CHECK: %[[IND_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[IND_BUFFER_1:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_PREV]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.splat %[[IND_BUFFER_1]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]] -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] -// CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[IND_BUFFER_0]] // AMD-LABEL: tt.func @indirect_bmm_scalar_dist_one @@ -570,7 +570,7 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, // AMD: scf.for // AMD: tt.load // AMD: tt.dot -// AMD: triton_gpu.local_store +// 
AMD: ttg.local_store // AMD: scf.yield // AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar_dist_one @@ -596,8 +596,8 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, %85 = tt.splat %84 : i64 -> tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 @@ -607,30 +607,30 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, } // CHECK-LABEL: tt.func @indirect_bmm_vector -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group +// CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.for // CHECK: tt.dot // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] -// CHECK-DAG: %[[IND_BUFFER_WAIT_TOKEN:.*]] = triton_gpu.async_wait {{.*}} {num = 1 : i32} -// CHECK-DAG: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview -// CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] token %[[IND_BUFFER_WAIT_TOKEN]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK-DAG: %[[IND_BUFFER_WAIT_TOKEN:.*]] = ttg.async_wait {{.*}} {num = 1 : i32} +// CHECK-DAG: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview +// CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] token %[[IND_BUFFER_WAIT_TOKEN]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] // CHECK: %[[IND_BUFFER_4:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_3]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] -// CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield // AMD-LABEL: tt.func @indirect_bmm_vector -// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc -// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc // AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} // AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] @@ -646,31 +646,31 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %[[MULI_13]] // AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_16:.*]] = tt.load 
%[[ADDPTR_14]], %[[SPLAT_15]] -// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] -// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]] +// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] +// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]] // AMD: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} // AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[LOAD_10]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_18]]) // AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} // AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} // AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] -// AMD: %[[LOCAL_LOAD_50:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_50:.*]] = ttg.local_load %[[ARG11]] // AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] // AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} // AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] // AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] // AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] // AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] -// AMD: %[[LOCAL_LOAD_57:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[LOCAL_LOAD_57:.*]] = ttg.local_load %[[ARG13]] // AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] // AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} // AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} // AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] -// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] // AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] // AMD_PREFETCH-LABEL: tt.func @indirect_bmm_vector @@ -696,8 +696,8 @@ tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : 
tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> @@ -746,9 +746,9 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %114 = arith.cmpi "slt", %66, %113 : tensor<32x1xi32, #AL> %115 = tt.broadcast %114 : tensor<32x1xi1, #AL> -> tensor<32x32xi1, #AL> %116 = tt.load %arg12, %115, %cst_0 : tensor<32x32x!tt.ptr, #AL> - %117 = triton_gpu.convert_layout %112 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> - %118 = triton_gpu.convert_layout %116 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %119 = tt.dot %117, %118, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %117 = ttg.convert_layout %112 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %118 = ttg.convert_layout %116 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %119 = tt.dot %117, %118, %arg10 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %131 = arith.index_cast %arg9 : index to i32 %120 = arith.addi %131, %c1_i32 : i32 %121 = arith.muli %120, %c32_i32 : i32 @@ -764,7 +764,7 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // COMMON-LABEL: tt.func @cross_iter_dep // TODO: enable pipelining with distance of 2 -// COMMON-NOT: triton_gpu.async_commit_group +// COMMON-NOT: ttg.async_commit_group // COMMON: scf.for // COMMON: scf.yield @@ -804,9 +804,9 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %148 = arith.cmpi "slt", %88, %147 : tensor<32x1xi32, #AL> %149 = tt.broadcast %148 : tensor<32x1xi1, #AL> -> tensor<32x32xi1, #AL> %150 = tt.load %arg12, %149, %cst_1 : tensor<32x32x!tt.ptr, #AL> - %151 = triton_gpu.convert_layout %146 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> - %152 = triton_gpu.convert_layout %150 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %153 = tt.dot %151, %152, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %151 = ttg.convert_layout %146 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %152 = ttg.convert_layout %150 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %153 = tt.dot %151, %152, %arg10 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %162 = arith.index_cast %arg9 : index to i32 %154 = arith.addi %162, %c2_i32 : i32 %155 = arith.muli %154, %c32_i32 : i32 @@ -832,10 +832,10 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: 
!tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { %23 = arith.constant 100 : index %c64 = arith.constant 64 : i64 - %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %57 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %58 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> - %83 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %57 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %58 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #BL}>> + %83 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> %85 = tt.splat %c64 : i64 -> tensor<1x32xi64, #AL> %86 = tt.splat %c64 : i64 -> tensor<1x32xi64, #AL> %68 = tt.splat %arg0 : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> @@ -848,17 +848,17 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %90 = tt.splat %c64 : i64 -> tensor<32x128xi64, #BL> %92 = tt.addptr %arg1, %c32_i32 : !tt.ptr, i32 %c0_index = arith.constant 0 : index - %91:5 = scf.for %arg19 = %c0_index to %23 step %c32_index iter_args(%arg20 = %68, %arg21 = %83, %arg22 = %92, %arg23 = %cst, %arg24 = %80) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL>) { + %91:5 = scf.for %arg19 = %c0_index to %23 step %c32_index iter_args(%arg20 = %68, %arg21 = %83, %arg22 = %92, %arg23 = %cst, %arg24 = %80) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL>) { %1750 = arith.subi %23, %arg19 : index %175 = arith.index_cast %1750 : index to i32 - %176 = tt.splat %175 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %177 = tt.splat %175 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> - %178 = arith.cmpi "slt", %57, %176 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %179 = arith.cmpi "slt", %58, %177 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> - %180 = tt.expand_dims %178 {axis = 0 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi1, #AL> - %181 = tt.expand_dims %179 {axis = 1 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 1, parent = #BL}>> -> tensor<32x1xi1, #BL> - %182 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> - %183 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> + %176 = tt.splat %175 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %177 = tt.splat %175 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #BL}>> + %178 = arith.cmpi "slt", %57, %176 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %179 = arith.cmpi "slt", %58, %177 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #BL}>> + %180 = tt.expand_dims %178 {axis = 0 : i32} : tensor<32xi1, #ttg.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi1, #AL> + %181 = tt.expand_dims %179 {axis = 1 : i32} : tensor<32xi1, 
#ttg.slice<{dim = 1, parent = #BL}>> -> tensor<32x1xi1, #BL> + %182 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> + %183 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> %184 = arith.extsi %182 : tensor<1x32xi32, #AL> to tensor<1x32xi64, #AL> %185 = arith.extsi %183 : tensor<1x32xi32, #AL> to tensor<1x32xi64, #AL> %186 = arith.muli %184, %85 : tensor<1x32xi64, #AL> @@ -869,17 +869,17 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %191 = tt.addptr %arg20, %189 : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi64, #AL> %192 = tt.broadcast %180 : tensor<1x32xi1, #AL> -> tensor<128x32xi1, #AL> %193 = tt.load %191, %192 : tensor<128x32x!tt.ptr, #AL> - %194 = tt.splat %arg22 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %195 = tt.addptr %194, %56 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #AL}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %196 = tt.load %195 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %194 = tt.splat %arg22 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #AL}>> + %195 = tt.addptr %194, %56 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #AL}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %196 = tt.load %195 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #AL}>> %197 = tt.addptr %arg22, %c32_i32 : !tt.ptr, i32 %198 = tt.broadcast %181 : tensor<32x1xi1, #BL> -> tensor<32x128xi1, #BL> %199 = tt.load %arg24, %198, %88 : tensor<32x128x!tt.ptr, #BL> - %200 = triton_gpu.convert_layout %193 : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> - %201 = triton_gpu.convert_layout %199 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> - %202 = tt.dot %200, %201, %arg23 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> + %200 = ttg.convert_layout %193 : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> + %201 = ttg.convert_layout %199 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> + %202 = tt.dot %200, %201, %arg23 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> %203 = tt.addptr %arg24, %90 : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi64, #BL> - scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL> + scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL> } tt.return %91#3 : tensor<128x128xf32, #C> } @@ -887,12 +887,12 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, 
versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: tt.func @load_two_users_incompatible_layouts tt.func @load_two_users_incompatible_layouts(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> @@ -907,16 +907,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> @@ -925,15 +925,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // COMMON: scf.for %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 
= %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = ttg.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> - %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !triton_gpu.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> - %25 = triton_gpu.memdesc_trans %24 {order=array} : !triton_gpu.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> - %26 = triton_gpu.local_load %25 : !triton_gpu.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -943,22 +943,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: nested_loops -// CHECK: triton_gpu.local_alloc +// CHECK: ttg.local_alloc // CHECK: scf.for -// CHECK-NOT: triton_gpu.local_alloc +// CHECK-NOT: ttg.local_alloc // CHECK: scf.for // CHECK: scf.yield -// CHECK: triton_gpu.async_wait {num = 0 : i32} -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: 
triton_gpu.async_commit_group +// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group // CHECK: scf.yield // AMD-LABEL: tt.func public @nested_loops // AMD: scf.for -// AMD: triton_gpu.local_alloc -// AMD-NOT: triton_gpu.local_alloc +// AMD: ttg.local_alloc +// AMD-NOT: ttg.local_alloc // AMD: scf.for // AMD: scf.yield // AMD-DIS: scf.yield @@ -979,9 +979,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // For CUDA, we pipeline the inner loop first then pipeline the outer // loop to prefetch the async copy after the inner loop. // For HIP, we only pipeline the inner loop for now. -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst_0 = arith.constant dense<320> : tensor<32x1xi32, #blocked> @@ -989,9 +989,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c1_i32 = arith.constant 1 : i32 %c32_i32 = arith.constant 32 : i32 %c10_i32 = arith.constant 10 : i32 - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %3 = arith.muli %2, %cst_0 : tensor<32x1xi32, #blocked> %4 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %5 = tt.addptr %4, %3 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> @@ -1000,15 +1000,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %8 = tt.splat %arg3 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> scf.for %arg4 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { %9 = arith.muli %arg4, %c32_i32 : i32 - %10 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %11 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %12 = arith.addi %10, %0 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %13 = arith.addi %11, %1 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %14 = 
tt.expand_dims %12 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %10 = tt.splat %9 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %11 = tt.splat %9 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %12 = arith.addi %10, %0 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %13 = arith.addi %11, %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %14 = tt.expand_dims %12 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %15 = tt.broadcast %14 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> %16 = tt.addptr %6, %15 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> %17 = tt.load %16 : tensor<32x32x!tt.ptr, #blocked> - %18 = tt.expand_dims %13 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %18 = tt.expand_dims %13 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %19 = arith.muli %18, %cst_0 : tensor<32x1xi32, #blocked> %20 = tt.addptr %7, %19 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> %21 = tt.broadcast %20 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> @@ -1016,17 +1016,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %23 = tt.broadcast %22 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> scf.for %arg5 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { %24 = arith.muli %arg5, %c32_i32 : i32 - %25 = tt.splat %24 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %26 = arith.addi %25, %0 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %27 = tt.expand_dims %26 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %25 = tt.splat %24 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %26 = arith.addi %25, %0 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %27 = tt.expand_dims %26 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %28 = tt.broadcast %27 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> %29 = tt.addptr %21, %28 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> %30 = tt.load %29 : tensor<32x32x!tt.ptr, #blocked> - %31 = triton_gpu.convert_layout %30 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %32 = triton_gpu.convert_layout %17 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %33 = tt.dot %31, %32, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %31 = ttg.convert_layout %30 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %32 = ttg.convert_layout %17 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %33 = tt.dot %31, %32, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %34 = tt.addptr %23, %28 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, 
#blocked> - %35 = triton_gpu.convert_layout %33 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %35 = ttg.convert_layout %33 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %34, %35 : tensor<32x32x!tt.ptr, #blocked> } } @@ -1036,44 +1036,44 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// CHECK: #[[$SHARED_LAYOUT:shared.*]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> // CHECK-LABEL: tt.func @indirect_load_shared_layout // CHECK: scf.for // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] -// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview {{.*}} : !triton_gpu.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable> -> !triton_gpu.memdesc<16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable> -// CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview {{.*}} : !ttg.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #ttg.shared_memory, mutable> -> !ttg.memdesc<16xi64, #[[$SHARED_LAYOUT]], #ttg.shared_memory, mutable> +// CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] // CHECK: %[[IND_BUFFER_4:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_3]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] -// CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: ttg.async_wait {{.*}} {num = 1 : i32} -// AMD-DIS: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// AMD-DIS: #[[$SHARED_LAYOUT:shared.*]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> // AMD-LABEL: tt.func @indirect_load_shared_layout -// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc -// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc // AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) // AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} // AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} // AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] -// AMD: %[[LOCAL_LOAD_50:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_50:.*]] = ttg.local_load %[[ARG11]] // AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] // AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} // AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] // AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] // AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] // AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] -// AMD: %[[LOCAL_LOAD_57:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[LOCAL_LOAD_57:.*]] = 
ttg.local_load %[[ARG13]] // AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] // AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} // AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} // AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] -// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] // AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] // AMD: } // AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} @@ -1081,14 +1081,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: %[[ADDPTR_23:.*]] = tt.addptr %{{.*}}#1, %{{.*}} // AMD: %[[SPLAT_24:.*]] = tt.splat %[[CMPI_22]] // AMD: %[[LOAD_25:.*]] = tt.load %[[ADDPTR_23]], %[[SPLAT_24]] -// AMD: %[[LOCAL_LOAD_26:.*]] = triton_gpu.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_26:.*]] = ttg.local_load %{{.*}}#4 // AMD: %[[EXPAND_DIMS_27:.*]] = tt.expand_dims %{{.*}}#5 {axis = 1 : i32} // AMD: %[[BROADCAST_28:.*]] = tt.broadcast %[[EXPAND_DIMS_27]] // AMD: %[[MULI_29:.*]] = arith.muli %{{.*}}, %[[BROADCAST_28]] // AMD: %[[ADDPTR_30:.*]] = tt.addptr %{{.*}}, %[[MULI_29]] // AMD: %[[SPLAT_31:.*]] = tt.splat %[[CMPI_22]] // AMD: %[[LOAD_32:.*]] = tt.load %[[ADDPTR_30]], %[[SPLAT_31]] -// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %{{.*}}#6 +// AMD: %[[LOCAL_LOAD_33:.*]] = ttg.local_load %{{.*}}#6 // AMD: %[[IF_34:.*]] = scf.if %[[CMPI_21]] // AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_26]], %[[LOCAL_LOAD_33]], %{{.*}}#0 // AMD: scf.yield %[[DOT_45]] @@ -1098,13 +1098,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} // AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} // AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_38]] -// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_32]], %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_32]], %[[MEMDESC_SUBVIEW_39]] // AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_34]], %{{.*}}#0 -// AMD: %[[LOCAL_LOAD_41:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_38]] -// AMD: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[LOCAL_LOAD_41:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_38]] +// AMD: 
%[[LOCAL_LOAD_42:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_39]] // AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]] // AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], %[[SELECT_40]] // AMD: scf.yield %[[DOT_45]] @@ -1112,16 +1112,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: scf.yield %[[SELECT_40]] // AMD: } // AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]] -// AMD: triton_gpu.local_dealloc %{{.*}} -// AMD: triton_gpu.local_dealloc %{{.*}} - -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// AMD: ttg.local_dealloc %{{.*}} +// AMD: ttg.local_dealloc %{{.*}} + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -1143,8 +1143,8 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> @@ -1158,25 +1158,25 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit // ----- // CHECK-LABEL: @kernel_yield_constant -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview // CHECK: scf.for -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview // CHECK: tt.return // AMD-LABEL: @kernel_yield_constant // AMD: tt.load -// AMD: triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store +// AMD: ttg.memdesc_subview +// AMD: ttg.local_store // AMD: scf.for // AMD: tt.load -// AMD: triton_gpu.memdesc_subview -// AMD: 
triton_gpu.local_store +// AMD: ttg.memdesc_subview +// AMD: ttg.local_store // AMD: tt.return -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @kernel_yield_constant(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst1 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #mma> @@ -1185,12 +1185,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> %c32_i32 = arith.constant 32 : i32 %c31_i32 = arith.constant 31 : i32 - %cst_1 = arith.constant dense<2.000000e+00> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %cst_1 = arith.constant dense<2.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %0 = tt.get_program_id x : i32 - %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %12 = arith.addi %arg4, %c31_i32 : i32 %13 = arith.divsi %12, %c32_i32 : i32 - %14 = tt.expand_dims %7 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %14 = tt.expand_dims %7 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %22 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> %34 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> %42 = scf.for %arg7 = %c0_i32 to %13 step %c1_i32 iter_args(%arg8 = %cst) -> (tensor<32x32xf32, #mma>) : i32 { @@ -1203,9 +1203,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %49 = arith.cmpi slt, %14, %48 : tensor<32x1xi32, #blocked> %50 = tt.broadcast %49 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> %51 = tt.load %46, %50, %cst_0 : tensor<32x32x!tt.ptr, #blocked> - %52 = triton_gpu.convert_layout %51 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %53 = tt.dot %cst_1, %52, %arg8 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - %54 = triton_gpu.convert_layout %53 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %52 = ttg.convert_layout %51 : tensor<32x32xf32, 
#blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %53 = tt.dot %cst_1, %52, %arg8 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %54 = ttg.convert_layout %53 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %34, %54 : tensor<32x32x!tt.ptr, #blocked> scf.yield %cst1 : tensor<32x32xf32, #mma> } @@ -1219,16 +1219,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: @add_kernel // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[A0BUFFER]] -// CHECK: %[[B0BUFFER:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[B0BUFFER]] -// CHECK: %[[A1BUFFER:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[A1BUFFER]] -// CHECK: %[[B1BUFFER:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[B1BUFFER]] +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: %[[A0BUFFER:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[A0BUFFER]] +// CHECK: %[[B0BUFFER:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[B0BUFFER]] +// CHECK: %[[A1BUFFER:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[A1BUFFER]] +// CHECK: %[[B1BUFFER:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[B1BUFFER]] // CHECK: scf.for // AMD-LABEL: tt.func public @add_kernel @@ -1244,8 +1244,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: %[[ADDPTR_20:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] // AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_20]], %[[CMPI_17]] // AMD: scf.for -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -1280,95 +1280,95 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: @nested_loops // CHECK: tt.addptr %{{.*}}, {{.*}} // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: %[[BUFFER_1:.*]] = 
triton_gpu.local_alloc -// CHECK: %[[SUBVIEW_1:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_1:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_1]] -// CHECK: triton_gpu.async_commit_group %[[ASYNC_COPY_1]] -// CHECK: %[[SUBVIEW_2:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_2:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_2]] -// CHECK: triton_gpu.async_commit_group %[[ASYNC_COPY_2]] +// CHECK: %[[BUFFER_1:.*]] = ttg.local_alloc +// CHECK: %[[SUBVIEW_1:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_1:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_1]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_1]] +// CHECK: %[[SUBVIEW_2:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_2:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_2]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_2]] // CHECK: scf.for // CHECK: %[[LOAD_1:.*]] = tt.load %[[NEXT_BUFFER_1]] -// CHECK: %[[BUFFER_2:.*]] = triton_gpu.local_alloc %[[LOAD_1]] -// CHECK: %[[TRANS:.*]] = triton_gpu.memdesc_trans %[[BUFFER_2]] -// CHECK: %[[LOCAL_LOAD_1:.*]] = triton_gpu.local_load %[[TRANS]] -// CHECK: triton_gpu.async_wait -// CHECK: triton_gpu.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[BUFFER_2:.*]] = ttg.local_alloc %[[LOAD_1]] +// CHECK: %[[TRANS:.*]] = ttg.memdesc_trans %[[BUFFER_2]] +// CHECK: %[[LOCAL_LOAD_1:.*]] = ttg.local_load %[[TRANS]] +// CHECK: ttg.async_wait +// CHECK: ttg.memdesc_subview %[[BUFFER_1]] // CHECK: scf.for -// CHECK: %[[LOCAL_LOAD_2:.*]] = triton_gpu.local_load +// CHECK: %[[LOCAL_LOAD_2:.*]] = ttg.local_load // CHECK: %[[DOT:.*]] = tt.dot %[[LOCAL_LOAD_2]], %[[LOCAL_LOAD_1]] -// CHECK: %[[CONVERT_LAYOUT_3:.*]] = triton_gpu.convert_layout %[[DOT]] -// CHECK: %[[SUBVIEW_4:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_3:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_4]] -// CHECK: triton_gpu.async_commit_group %[[ASYNC_COPY_3]] -// CHECK: triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[SUBVIEW_6:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_4:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_6]] mask -// CHECK: %[[COMMIT_1:.*]] = triton_gpu.async_commit_group %[[ASYNC_COPY_4]] -// CHECK: %[[SUBVIEW_7:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_5:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_7]] mask -// CHECK: %[[COMMIT_2:.*]] = triton_gpu.async_commit_group %[[ASYNC_COPY_5]] +// CHECK: %[[CONVERT_LAYOUT_3:.*]] = ttg.convert_layout %[[DOT]] +// CHECK: %[[SUBVIEW_4:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_3:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_4]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_3]] +// CHECK: ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[SUBVIEW_6:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_4:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_6]] mask +// CHECK: %[[COMMIT_1:.*]] = ttg.async_commit_group %[[ASYNC_COPY_4]] +// CHECK: %[[SUBVIEW_7:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_5:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_7]] mask +// CHECK: %[[COMMIT_2:.*]] = ttg.async_commit_group %[[ASYNC_COPY_5]] // CHECK: scf.yield %[[COMMIT_1]], %[[COMMIT_2]] -// CHECK: triton_gpu.local_dealloc 
%[[BUFFER_1]] +// CHECK: ttg.local_dealloc %[[BUFFER_1]] // AMD-LABEL: tt.func public @nested_loops -// AMD-NOT: triton_gpu.local_alloc +// AMD-NOT: ttg.local_alloc // AMD: scf.for -// AMD: triton_gpu.local_alloc +// AMD: ttg.local_alloc // AMD: scf.for -// AMD: triton_gpu.local_load +// AMD: ttg.local_load // AMD: tt.dot -// AMD: triton_gpu.local_store +// AMD: ttg.local_store // AMD: scf.yield -// AMD: triton_gpu.local_dealloc +// AMD: ttg.local_dealloc // AMD_PREFETCH-LABEL: tt.func public @nested_loops -// AMD_PREFETCH-NOT: triton_gpu.local_alloc +// AMD_PREFETCH-NOT: ttg.local_alloc // AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_alloc +// AMD_PREFETCH: ttg.local_alloc // AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: ttg.local_store // AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: ttg.local_load // AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: ttg.local_store // AMD_PREFETCH: tt.load // AMD_PREFETCH: tt.dot -// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: ttg.local_load // AMD_PREFETCH: scf.yield -// AMD_PREFETCH: triton_gpu.local_dealloc +// AMD_PREFETCH: ttg.local_dealloc -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 %c0_i32 = arith.constant 0 : i32 %cst_0 = arith.constant dense<16> : tensor<16x1xi32, #blocked> - %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> %2 = arith.muli %1, %cst_0 : tensor<16x1xi32, #blocked> %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked> %4 = tt.addptr %3, %2 : tensor<16x1x!tt.ptr, #blocked>, tensor<16x1xi32, #blocked> - %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> 
tensor<1x16xi32, #blocked> + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> %7 = tt.broadcast %4 : tensor<16x1x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked> %8 = tt.broadcast %6 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked> %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %10 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> - %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !triton_gpu.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> - %12 = triton_gpu.memdesc_trans %11 {order = array} : !triton_gpu.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> -> !triton_gpu.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> - %13 = triton_gpu.local_load %12 : !triton_gpu.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %11 = ttg.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !ttg.memdesc<16x16xf32, #shared, #ttg.shared_memory> + %12 = ttg.memdesc_trans %11 {order = array} : !ttg.memdesc<16x16xf32, #shared, #ttg.shared_memory> -> !ttg.memdesc<16x16xf32, #shared1, #ttg.shared_memory> + %13 = ttg.local_load %12 : !ttg.memdesc<16x16xf32, #shared1, #ttg.shared_memory> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %14 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> - %15 = triton_gpu.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %16 = tt.dot %15, %13, %cst : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma> - %17 = triton_gpu.convert_layout %16 : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked> + %15 = ttg.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %16 = tt.dot %15, %13, %cst : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma> + %17 = ttg.convert_layout %16 : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked> tt.store %9, %17 : tensor<16x16x!tt.ptr, #blocked> } } @@ -1379,14 +1379,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- // CHECK-LABEL: @int4_matmul_ampere -#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [16, 1, 2], threadsPerWarp = [4, 8, 1], warpsPerCTA = [1, 8, 1], order = [2, 0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [16, 2, 1], threadsPerWarp = [4, 1, 8], warpsPerCTA = [1, 1, 8], order = [1, 0, 2]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = 
#triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [16, 1, 2], threadsPerWarp = [4, 8, 1], warpsPerCTA = [1, 8, 1], order = [2, 0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [16, 2, 1], threadsPerWarp = [4, 1, 8], warpsPerCTA = [1, 1, 8], order = [1, 0, 2]}> +#blocked5 = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { tt.func public @int4_matmul_ampere( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32} @@ -1404,14 +1404,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %cst_2 = arith.constant dense<4> : tensor<64x256xi8, #blocked> %cst_3 = arith.constant dense<0.000000e+00> : tensor<16x256xf32, #mma> - %35 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %36 = tt.expand_dims %35 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %35 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %36 = tt.expand_dims %35 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> %38 = tt.broadcast %36 : tensor<1x128xi32, #blocked1> -> tensor<16x128xi32, #blocked1> %40 = tt.splat %arg0 : !tt.ptr -> tensor<16x128x!tt.ptr, #blocked1> %41 = tt.addptr %40, %38 : tensor<16x128x!tt.ptr, #blocked1>, tensor<16x128xi32, #blocked1> - %42 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %43 = tt.expand_dims %42 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %42 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %43 = tt.expand_dims %42 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %47 = tt.broadcast %43 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> %50 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> %51 = tt.addptr %50, %47 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> @@ -1419,9 +1419,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // Check that both loads in the loop are pipelined. 
// CHECK: scf.for // CHECK-NOT: tt.load - // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: ttg.async_copy_global_to_local // CHECK-NOT: tt.load - // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: ttg.async_copy_global_to_local // CHECK-NOT: tt.load // CHECK: scf.yield %54:3 = scf.for %arg9 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %41, %arg12 = %51) -> (tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>) : i32 { @@ -1435,9 +1435,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> tensor<64x256x2xf16, #blocked3> %86 = tt.trans %85 {order = array} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4> %87 = tt.reshape %86 : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> - %88 = triton_gpu.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %89 = triton_gpu.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma> + %88 = ttg.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %89 = ttg.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma> %91 = tt.addptr %arg11, %cst_0 : tensor<16x128x!tt.ptr, #blocked1>, tensor<16x128xi32, #blocked1> %92 = tt.addptr %arg12, %cst : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> scf.yield %90, %91, %92 : tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked> @@ -1452,16 +1452,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // This test triggered some failure in the verifier, so we only // included a simple check for the kernel name. 
// COMMON-LABEL: @load_convert_layout -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -1486,8 +1486,8 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> @@ -1503,18 +1503,18 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 // This test captured some ICE in MatmulLoopPipeline pass, so we only // included a simple check for the kernel name. 
// COMMON-LABEL: @matmul_indirect_pipeline -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func public @matmul_indirect_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 %c0_i32 = arith.constant 0 : i32 - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> - %3 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %3 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %4 = tt.broadcast %2 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> %5 = tt.broadcast %3 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> %6 = arith.addi %4, %5 : tensor<32x32xi32, #blocked> @@ -1523,20 +1523,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : %9 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> %10 = tt.splat %arg3 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> %11 = tt.addptr %10, %6 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - %12 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %13 = tt.addptr %12, %0 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + %13 = tt.addptr %12, %0 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { - %15 = tt.load %13 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %16 = tt.addptr %14, %15 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = 
#blocked}>>, tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %17 = tt.load %16 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %18 = tt.expand_dims %17 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xf32, #blocked> + %15 = tt.load %13 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + %16 = tt.addptr %14, %15 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<32xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %17 = tt.load %16 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + %18 = tt.expand_dims %17 {axis = 0 : i32} : tensor<32xf32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xf32, #blocked> %19 = tt.broadcast %18 : tensor<1x32xf32, #blocked> -> tensor<32x32xf32, #blocked> %20 = arith.addf %9, %19 : tensor<32x32xf32, #blocked> - %21 = triton_gpu.convert_layout %9 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %22 = triton_gpu.convert_layout %20 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %23 = tt.dot %21, %22, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - %24 = triton_gpu.convert_layout %23 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %21 = ttg.convert_layout %9 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %22 = ttg.convert_layout %20 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %23 = tt.dot %21, %22, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %24 = ttg.convert_layout %23 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %11, %24 : tensor<32x32x!tt.ptr, #blocked> } {tt.num_stages = 3 : i32} tt.return @@ -1548,20 +1548,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // COMMON-LABEL: @dont_pipeline_128x1 // AMD-NOT: local_load{{.*}}128x1 // CHECK: local_load{{.*}}128x1 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @dont_pipeline_128x1(%arg6: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %c128_i32 = arith.constant 128 : i32 %c0_i32 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 - %cst_4 = arith.constant dense<-1.000000e+30> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst_4 = arith.constant dense<-1.000000e+30> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> - %99:1 = scf.for %arg25 = %c0_i32 to %c128_i32 step %c64_i32 
iter_args(%arg31 = %cst_4) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) : i32 { + %99:1 = scf.for %arg25 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg31 = %cst_4) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) : i32 { %94 = tt.splat %arg6 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> %151 = tt.load %94 : tensor<128x1x!tt.ptr, #blocked> - %161 = triton_gpu.convert_layout %151 : tensor<128x1xi32, #blocked> -> tensor<128x1xi32, #mma> + %161 = ttg.convert_layout %151 : tensor<128x1xi32, #blocked> -> tensor<128x1xi32, #mma> %162 = tt.broadcast %161 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma> %170 = arith.sitofp %162 : tensor<128x64xi32, #mma> to tensor<128x64xf32, #mma> @@ -1569,17 +1569,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : ^bb0(%arg33: f32, %arg34: f32): %207 = arith.maxnumf %arg33, %arg34 : f32 tt.reduce.return %207 : f32 - }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %175 = arith.maxnumf %arg31, %173 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %175 = arith.maxnumf %arg31, %173 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %201 = arith.truncf %170 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - %202 = triton_gpu.convert_layout %201 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %202 = ttg.convert_layout %201 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %192 = arith.constant dense<0.> : tensor<128x64xf32, #mma> - %203 = arith.constant dense<0.> : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %204 = tt.dot %202, %203, %192 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %203 = arith.constant dense<0.> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %204 = tt.dot %202, %203, %192 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> - scf.yield %175 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + scf.yield %175 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> } tt.return } @@ -1590,18 +1590,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Check that the dependencies across ops of different nesting does not cause crash or // incorrect schedule that fails to pipeline. 
// COMMON-LABEL: @matmul_nested_ops -// COMMON: triton_gpu.local_load - -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> - -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// COMMON: ttg.local_load + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}, @@ -1628,7 +1628,7 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x128xf32, #C>) { %cnd = arith.cmpi slt, %iv, %ext : index @@ -1639,7 +1639,7 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, scf.yield %a_ptr : tensor<128x32x!tt.ptr, #AL> } %a_ = tt.load %inc_a_ptr : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -1655,9 +1655,9 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, // CHECK-LABEL: @masked_add_kernel // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> // CHECK: scf.for -// CHECK: %[[A:.*]] = triton_gpu.local_load +// CHECK: %[[A:.*]] = ttg.local_load // CHECK: arith.select {{.*}}, %[[A]], %[[CONSTANT]] -// CHECK: %[[B:.*]] = triton_gpu.local_load +// CHECK: %[[B:.*]] = ttg.local_load // CHECK: arith.select {{.*}}, %[[B]], %[[CONSTANT]] // AMD-LABEL: @masked_add_kernel @@ -1687,8 +1687,8 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, // AMD_PREFETCH: tt.store // AMD_PREFETCH: tt.store -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" 
= 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @masked_add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 diff --git a/test/TritonGPU/loop-schedule.mlir b/test/TritonGPU/loop-schedule.mlir index a73f70d3e3..adf7050da3 100644 --- a/test/TritonGPU/loop-schedule.mlir +++ b/test/TritonGPU/loop-schedule.mlir @@ -1,14 +1,14 @@ // RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 | FileCheck %s -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#CLs0 = #triton_gpu.slice<{parent=#C, dim=0}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#CLs0 = #ttg.slice<{parent=#C, dim=0}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABLE: @matmul_loop_load_acc // CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} // CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} @@ -45,9 +45,9 @@ tt.func @matmul_loop_load_acc(%lb : index, %ub : index, %step : index, %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %c_ptr = %c_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128x!tt.ptr, #C>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c_ = tt.load %c_ptr : tensor<128x128x!tt.ptr, #C> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> diff --git a/test/TritonGPU/matmul-loop-pipeline.mlir b/test/TritonGPU/matmul-loop-pipeline.mlir index 84a857fc72..f8042feee9 100644 --- a/test/TritonGPU/matmul-loop-pipeline.mlir +++ b/test/TritonGPU/matmul-loop-pipeline.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s -tritongpu-pipeline | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = 
[32], warpsPerCTA = [8], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @softmax_kernel tt.func public @softmax_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { @@ -17,7 +17,7 @@ tt.func public @softmax_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, scf.for %arg6 = %0 to %arg4 step %1 : i32 { %5 = tt.splat %arg1 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.ptr -> tensor<128x!tt.ptr, #blocked> %6 = tt.addptr %5, %2 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr, #blocked>, tensor<128xi32, #blocked> - // CHECK: [[RESULT:%.*]] = triton_gpu.local_load + // CHECK: [[RESULT:%.*]] = ttg.local_load // CHECK-NEXT: arith.select [[MASK]], [[RESULT]], %cst %7 = tt.load %6, %4, %cst {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr, #blocked> %8 = tt.splat %arg0 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : !tt.ptr -> tensor<128x!tt.ptr, #blocked> diff --git a/test/TritonGPU/ops.mlir b/test/TritonGPU/ops.mlir index 70c1a315e7..1513ac60e8 100644 --- a/test/TritonGPU/ops.mlir +++ b/test/TritonGPU/ops.mlir @@ -1,49 +1,49 @@ // RUN: triton-opt --split-input-file %s | FileCheck %s -// CHECK: #[[$WMMA_GEN1:.*]] = #triton_gpu.amd_wmma<{{.*}}version = 1{{.*}}> -// CHECK: #[[$WMMA_GEN2:.*]] = #triton_gpu.amd_wmma<{{.*}}version = 2{{.*}}> -#blocked = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +// CHECK: #[[$WMMA_GEN1:.*]] = #ttg.amd_wmma<{{.*}}version = 1{{.*}}> +// CHECK: #[[$WMMA_GEN2:.*]] = #ttg.amd_wmma<{{.*}}version = 2{{.*}}> +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_layout tt.func @wmma_layout(%0: tensor<16x16xf16, #blocked>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN1]]> + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN1]]> tt.return } // CHECK-LABEL: wmma_dot_op_layout - tt.func @wmma_dot_op_layout(%0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA 
= [1, 1]}>, kWidth = 16}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN1]], kWidth = 16}>> + tt.func @wmma_dot_op_layout(%0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) { + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>, kWidth = 16}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN1]], kWidth = 16}>> tt.return } // CHECK-LABEL: wmma_gen2_layout tt.func @wmma_gen2_layout(%0: tensor<16x16xf16, #blocked>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN2]]> + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN2]]> tt.return } // CHECK-LABEL: wmma_gen2_dot_op_layout - tt.func @wmma_gen2_dot_op_layout(%0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>, kWidth = 8}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN2]], kWidth = 8}>> + tt.func @wmma_gen2_dot_op_layout(%0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) { + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>, kWidth = 8}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN2]], kWidth = 8}>> tt.return } } // ----- -#blocked= #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[$LINEAR:.*]] = #triton_gpu.linear<{{.*}}> +#blocked= #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[$LINEAR:.*]] = #ttg.linear<{{.*}}> -module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @blocked_to_linear tt.func @blocked_to_linear(%input: tensor<32x4xi8, #blocked>) { // The layout is the basic layout generated by DecomposeScaledBlocked - %output = triton_gpu.convert_layout %input {allocation.offset = 0 : i32} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #triton_gpu.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 
0]], warp = [[0, 0], [16, 0]], block = []}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #[[$LINEAR]]> + %output = ttg.convert_layout %input {allocation.offset = 0 : i32} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #ttg.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #[[$LINEAR]]> tt.return } } diff --git a/test/TritonGPU/optimize-locality.mlir b/test/TritonGPU/optimize-locality.mlir index 5442998671..25fa8fbcb5 100644 --- a/test/TritonGPU/optimize-locality.mlir +++ b/test/TritonGPU/optimize-locality.mlir @@ -10,11 +10,11 @@ // CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @negative_zero_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -23,16 +23,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<-0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<-0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : 
tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -40,11 +40,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.addf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -63,11 +63,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: arith.addf %[[CVT_OUTPUT]], %[[CST]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @positive_zero_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -76,16 +76,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> 
(tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -93,11 +93,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.addf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -112,11 +112,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: arith.addf // CHECK: arith.addf // CHECK-NEXT: scf.yield -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[LOOP_OUTPUT]] -#blocked3d = #triton_gpu.blocked<{sizePerThread = [1, 4, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> -#slice2d = #triton_gpu.slice<{dim = 2, parent = #blocked3d}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[LOOP_OUTPUT]] +#blocked3d = #ttg.blocked<{sizePerThread = [1, 4, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#slice2d = #ttg.slice<{dim = 2, parent = #blocked3d}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @slice_layout( %arg0: !tt.ptr 
{tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -125,16 +125,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> -> tensor<1x128xi32, #slice2d> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> -> tensor<1x128xi32, #slice2d> %31 = tt.broadcast %30 : tensor<1x128xi32, #slice2d> -> tensor<32x128xi32, #slice2d> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #slice2d>, tensor<32x128xi32, #slice2d> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #slice2d> @@ -142,11 +142,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.addf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #slice2d>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> - %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> + }) : (tensor<32x128xf32, #slice2d>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -161,11 +161,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: arith.addf // CHECK: arith.addf // CHECK-NEXT: scf.yield -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[LOOP_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> 
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[LOOP_OUTPUT]] +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @mma_layout( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -174,16 +174,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x128xi32, #mma> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x128xi32, #mma> %31 = tt.broadcast %30 : tensor<1x128xi32, #mma> -> tensor<32x128xi32, #mma> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #mma>, tensor<32x128xi32, #mma> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #mma> @@ -191,11 +191,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.addf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + }) : (tensor<32x128xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, 
#triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -214,11 +214,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.maximumf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @max_reduce( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -227,16 +227,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0xFF800000> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0xFF800000> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -244,11 +244,11 @@ module 
attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.maximumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -268,11 +268,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.maximumf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: arith.maximumf %[[CVT_OUTPUT]], %[[CST]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @max_reduce_zero_int_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -281,16 +281,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent 
= #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -298,11 +298,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.maximumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -321,11 +321,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.minimumf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @min_reduce( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -334,16 +334,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0x7F800000> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0x7F800000> : 
tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -351,11 +351,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.minimumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -375,11 +375,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.minimumf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: arith.minimumf %[[CVT_OUTPUT]], %[[CST]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 
32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @min_reduce_zero_int_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -388,16 +388,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -405,11 +405,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.minimumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -428,11 +428,11 @@ module attributes {"triton_gpu.target" = "cuda:80", 
"triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.mulf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @mul_reduce( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -441,16 +441,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -458,11 +458,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.mulf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.mulf %arg4, %34 : tensor<32xf32, 
#triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.mulf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -482,11 +482,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.mulf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: arith.mulf %[[CVT_OUTPUT]], %[[CST]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @mul_reduce_zero_int_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -495,16 +495,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> 
+ %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -512,11 +512,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.mulf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.mulf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.mulf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -534,9 +534,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: arith.maximumf // CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] // CHECK-NEXT: scf.yield -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @remains_unchanged( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -545,16 +545,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to 
%11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -563,11 +563,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.maximumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -575,32 +575,32 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK-DAG: #[[$BLOCK0:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> -// CHECK-DAG: #[[$BLOCK1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> -// CHECK-DAG: #[[$BLOCK2:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +// CHECK-DAG: #[[$BLOCK0:.+]] = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCK1:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCK2:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> // CHECK-LABEL: optimize_view_layout // CHECK: %[[R:.+]] = tt.reshape {{.*}} allow_reorder efficient_layout : tensor<8x128xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK2]]> -// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK2]]> -> tensor<64x16xf32, #[[$BLOCK1]]> +// CHECK: %[[C:.+]] = ttg.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK2]]> -> tensor<64x16xf32, #[[$BLOCK1]]> // CHECK: "tt.reduce"(%[[C]]) -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], 
threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @optimize_view_layout(%arg0: tensor<8x128xf32, #blocked>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @optimize_view_layout(%arg0: tensor<8x128xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> { %0 = tt.reshape %arg0 allow_reorder : tensor<8x128xf32, #blocked> -> tensor<64x16xf32, #blocked1> %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %2 = arith.maximumf %arg1, %arg2 : f32 tt.reduce.return %2 : f32 - }) : (tensor<64x16xf32, #blocked1>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - tt.return %1 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + }) : (tensor<64x16xf32, #blocked1>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + tt.return %1 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#slice = #triton_gpu.slice<{dim = 1, parent = #blocked}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#slice = #ttg.slice<{dim = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func public @reduce_for_arg(%arg: tensor<64x128xf32, #blocked>, %arg1: !tt.ptr) { %c0_i32 = arith.constant 0 : i32 %c128_i32 = arith.constant 128 : i32 @@ -611,8 +611,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : ^bb0(%arg31: f32, %arg32: f32): %160 = arith.maxnumf %arg31, %arg32 : f32 tt.reduce.return %160 : f32 - }) : (tensor<64x128xf32, #blocked>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %75 = triton_gpu.convert_layout %129 : tensor<64xf32, #slice> -> tensor<64xf32, #blocked1> + }) : (tensor<64x128xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %75 = ttg.convert_layout %129 : tensor<64xf32, #slice> -> tensor<64xf32, #blocked1> %79 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked1> %80 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> %81 = tt.addptr %80, %79 : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> diff --git a/test/TritonGPU/optimize_epilogue.mlir b/test/TritonGPU/optimize_epilogue.mlir index d990b14e85..142ec762fb 100644 --- a/test/TritonGPU/optimize_epilogue.mlir +++ b/test/TritonGPU/optimize_epilogue.mlir @@ -1,14 +1,14 @@ // RUN: 
triton-opt %s -split-input-file --tritonamdgpu-optimize-epilogue | FileCheck --check-prefixes=GCN %s -#mfma = #triton_gpu.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { // GCN-LABEL: mfma_epilogue_simple // CHECK-LABEL: mfma_epilogue_simple tt.func public @mfma_epilogue_simple(%data: tensor<64x64xf16, #mfma>, %ptr: tensor<64x64x!tt.ptr, #blocked>) { - // GCN: [[PTR:%[a-z0-9]+]] = triton_gpu.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> + // GCN: [[PTR:%[a-z0-9]+]] = ttg.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> // GCN: tt.store [[PTR]], {{.*}} : tensor<{{.*}}, #mma> - %converted_data = triton_gpu.convert_layout %data : tensor<64x64xf16, #mfma> -> tensor<64x64xf16, #blocked> + %converted_data = ttg.convert_layout %data : tensor<64x64xf16, #mfma> -> tensor<64x64xf16, #blocked> tt.store %ptr, %converted_data : tensor<64x64x!tt.ptr, #blocked> tt.return } @@ -16,15 +16,15 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // ----- -#mfma = #triton_gpu.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { // GCN-LABEL: mfma_epilogue_chained_elementwise // CHECK-LABEL: mfma_epilogue_chained_elementwise tt.func public @mfma_epilogue_chained_elementwise(%data: tensor<64x64xf32, #mfma>, %ptr: tensor<64x64x!tt.ptr, #blocked>) { - // GCN: [[PTR:%[a-z0-9]+]] = triton_gpu.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> + // GCN: [[PTR:%[a-z0-9]+]] = ttg.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> // GCN: tt.store [[PTR]], {{.*}} : tensor<{{.*}}, #mma> - %converted_data = triton_gpu.convert_layout %data : tensor<64x64xf32, #mfma> -> tensor<64x64xf32, #blocked> + %converted_data = ttg.convert_layout %data : tensor<64x64xf32, #mfma> -> tensor<64x64xf32, #blocked> %trunked = arith.truncf %converted_data : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> tt.store %ptr, %trunked : tensor<64x64x!tt.ptr, #blocked> tt.return diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index cd93be2c47..03c2e07323 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -1,21 +1,21 @@ // RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> 
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: two_dependent_dot tt.func public @two_dependent_dot(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %c0_i32 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 - %cst_1 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %cst_2 = arith.constant 
dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst_1 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %cst_4 = arith.constant 1.44269502 : f32 @@ -34,25 +34,25 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %9 = arith.extsi %3 : i32 to i64 %10 = arith.extsi %c0_i32 : i32 to i64 %11 = arith.muli %0, %c128_i32 : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> - %15 = tt.splat %11 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %16 = tt.splat %11 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %15 = tt.splat %11 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.splat %11 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> %17 = tt.splat %11 : i32 -> tensor<128xi32, #blocked1> - %18 = arith.addi %15, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %19 = arith.addi %16, %13 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %18 = arith.addi %15, %12 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %19 = arith.addi %16, %13 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> %20 = arith.addi %17, %14 : tensor<128xi32, #blocked1> %21 = arith.mulf %arg3, %cst_4 : f32 %22 = tt.addptr %arg0, %2 : !tt.ptr, i32 - %23 = tt.expand_dims %18 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %24 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> + %23 = tt.expand_dims %18 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %24 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> %25 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked> %26 = arith.muli %23, %25 : tensor<128x1xi32, #blocked> %27 = tt.splat %22 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> %28 = tt.addptr %27, %26 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %29 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %29 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %28 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> %32 = tt.broadcast %30 : 
tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> %33 = tt.addptr %31, %32 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> @@ -63,23 +63,23 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %38 = arith.truncf %37 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> %39 = arith.addi %0, %c1_i32 : i32 %40 = arith.muli %39, %c128_i32 : i32 - %41:7 = scf.for %arg22 = %c0_i32 to %40 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %7, %arg28 = %9, %arg29 = %10) -> (tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64) : i32 { + %41:7 = scf.for %arg22 = %c0_i32 to %40 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %7, %arg28 = %9, %arg29 = %10) -> (tensor<128x128xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64) : i32 { %69 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked2> - %70 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %71 = arith.extsi %70 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %72 = tt.splat %arg26 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %73 = arith.addi %71, %72 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %74 = tt.expand_dims %73 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi64, #blocked2> + %70 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %71 = arith.extsi %70 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> + %72 = tt.splat %arg26 : i64 -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> + %73 = arith.addi %71, %72 : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> + %74 = tt.expand_dims %73 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi64, #blocked2> %75 = tt.broadcast %74 : tensor<128x1xi64, #blocked2> -> tensor<128x64xi64, #blocked2> %76 = tt.splat %c1_i64 : i64 -> tensor<128x64xi64, #blocked2> %77 = arith.muli %75, %76 : tensor<128x64xi64, #blocked2> %78 = tt.broadcast %77 : tensor<128x64xi64, #blocked2> -> tensor<128x64xi64, #blocked2> %79 = tt.addptr %69, %78 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi64, #blocked2> - %80 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %81 = arith.extsi %80 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %82 = tt.splat %arg27 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %83 = arith.addi %81, %82 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %84 = tt.expand_dims %83 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2> + %80 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %81 = arith.extsi %80 : tensor<64xi32, 
#ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> + %82 = tt.splat %arg27 : i64 -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.addi %81, %82 : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2> %85 = tt.broadcast %84 : tensor<1x64xi64, #blocked2> -> tensor<128x64xi64, #blocked2> %86 = tt.splat %5 : i64 -> tensor<128x64xi64, #blocked2> %87 = arith.muli %85, %86 : tensor<128x64xi64, #blocked2> @@ -87,43 +87,43 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %89 = tt.addptr %79, %88 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi64, #blocked2> %90 = tt.load %89 : tensor<128x64x!tt.ptr, #blocked2> %91 = tt.splat %arg2 : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked> - %92 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %93 = arith.extsi %92 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %94 = tt.splat %arg28 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %95 = arith.addi %93, %94 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %96 = tt.expand_dims %95 {axis = 1 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi64, #blocked> + %92 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %93 = arith.extsi %92 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %94 = tt.splat %arg28 : i64 -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %95 = arith.addi %93, %94 : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %96 = tt.expand_dims %95 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi64, #blocked> %97 = tt.broadcast %96 : tensor<64x1xi64, #blocked> -> tensor<64x128xi64, #blocked> %98 = tt.splat %8 : i64 -> tensor<64x128xi64, #blocked> %99 = arith.muli %97, %98 : tensor<64x128xi64, #blocked> %100 = tt.broadcast %99 : tensor<64x128xi64, #blocked> -> tensor<64x128xi64, #blocked> %101 = tt.addptr %91, %100 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi64, #blocked> - %102 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %103 = arith.extsi %102 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %104 = tt.splat %arg29 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %105 = arith.addi %103, %104 : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %106 = tt.expand_dims %105 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> + %102 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %103 = arith.extsi %102 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %104 = tt.splat %arg29 : i64 -> tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %105 = arith.addi %103, %104 : tensor<128xi64, #ttg.slice<{dim = 0, parent = 
#blocked}>> + %106 = tt.expand_dims %105 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> %107 = tt.broadcast %106 : tensor<1x128xi64, #blocked> -> tensor<64x128xi64, #blocked> %108 = tt.splat %c1_i64 : i64 -> tensor<64x128xi64, #blocked> %109 = arith.muli %107, %108 : tensor<64x128xi64, #blocked> %110 = tt.broadcast %109 : tensor<64x128xi64, #blocked> -> tensor<64x128xi64, #blocked> %111 = tt.addptr %101, %110 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi64, #blocked> %112 = tt.load %111 : tensor<64x128x!tt.ptr, #blocked> - %113 = triton_gpu.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !triton_gpu.memdesc<128x128xf16, #shared> - %114 = triton_gpu.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !triton_gpu.memdesc<128x64xf16, #shared1> - %115 = triton_nvidia_gpu.warp_group_dot %113, %114, %cst :!triton_gpu.memdesc<128x128xf16, #shared> * !triton_gpu.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %113 = ttg.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared> + %114 = ttg.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1> + %115 = ttng.warp_group_dot %113, %114, %cst :!ttg.memdesc<128x128xf16, #shared> * !ttg.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - %117 = triton_gpu.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !triton_gpu.memdesc<64x128xf16, #shared> - %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %117 = ttg.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared> + %118 = ttg.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // The first dot gets converted to dot-async + wait. The second one // doesn't have a wait because the first wait is sufficient. 
- // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK: ttng.warp_group_dot + // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NOT: ttng.warp_group_dot_wait // CHECK: scf.yield - %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !triton_gpu.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> - %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %121 = arith.addf %120, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %119 = ttng.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> + %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %121 = arith.addf %120, %arg25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 %123 = arith.addi %arg26, %122 : i64 %124 = arith.extsi %c64_i32 : i32 to i64 @@ -132,30 +132,30 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %127 = arith.addi %arg28, %126 : i64 %128 = arith.extsi %c0_i32 : i32 to i64 %129 = arith.addi %arg29, %128 : i64 - scf.yield %119, %121, %arg25, %123, %125, %127, %129 : tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64 + scf.yield %119, %121, %arg25, %123, %125, %127, %129 : tensor<128x128xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64 } %42 = arith.addi %3, %11 : i32 %43 = arith.extsi %arg17 : i32 to i64 %44 = arith.extsi %42 : i32 to i64 %45 = arith.extsi %c0_i32 : i32 to i64 %46 = arith.truncf %41#0 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1> - %47 = triton_gpu.convert_layout %46 : tensor<128x128xf16, #mma1> -> tensor<128x128xf16, #blocked> + %47 = ttg.convert_layout %46 : tensor<128x128xf16, #mma1> -> tensor<128x128xf16, #blocked> %48 = tt.splat %arg5 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked> - %49 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %50 = arith.extsi %49 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %51 = tt.splat %44 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %52 = arith.addi %50, %51 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %53 = tt.expand_dims %52 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> + %49 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %50 = arith.extsi %49 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %51 = tt.splat %44 : i64 -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %52 = arith.addi %50, %51 : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %53 = tt.expand_dims %52 {axis = 1 : i32} : 
tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> %54 = tt.broadcast %53 : tensor<128x1xi64, #blocked> -> tensor<128x128xi64, #blocked> %55 = tt.splat %43 : i64 -> tensor<128x128xi64, #blocked> %56 = arith.muli %54, %55 : tensor<128x128xi64, #blocked> %57 = tt.broadcast %56 : tensor<128x128xi64, #blocked> -> tensor<128x128xi64, #blocked> %58 = tt.addptr %48, %57 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi64, #blocked> - %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %60 = arith.extsi %59 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %61 = tt.splat %45 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %62 = arith.addi %60, %61 : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %63 = tt.expand_dims %62 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> + %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %60 = arith.extsi %59 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %61 = tt.splat %45 : i64 -> tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %62 = arith.addi %60, %61 : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %63 = tt.expand_dims %62 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> %64 = tt.broadcast %63 : tensor<1x128xi64, #blocked> -> tensor<128x128xi64, #blocked> %65 = tt.splat %c1_i64 : i64 -> tensor<128x128xi64, #blocked> %66 = arith.muli %64, %65 : tensor<128x128xi64, #blocked> diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 1c0eeeb666..ab070a73a5 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -2,38 +2,38 @@ // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_OP = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_OP = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> // CHECK: tt.func @matmul_loop_mixed // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.local_load 
%[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.local_load %[[B0_PREFETCH_SMEM]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] -// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.local_load %[[A_REM_SMEM]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] // CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] -// CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] -// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.local_load %[[B_REM_SMEM]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] -module attributes { "triton_gpu.num-warps" = 4 : i32 } { +module attributes { "ttg.num-warps" = 4 : i32 } { tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> @@ -48,24 +48,24 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a_init = triton_gpu.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !triton_gpu.memdesc<128x32xf8E5M2, #A> + %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A> %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b_init = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !triton_gpu.memdesc<32x128xf16, #B> + %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, 
#BL>) -> !ttg.memdesc<32x128xf16, #B> - %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !triton_gpu.memdesc<128x32xf8E5M2, #A>, !triton_gpu.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op_ = triton_gpu.local_load %a : !triton_gpu.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A>, !ttg.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> - %b_op = triton_gpu.local_load %b : !triton_gpu.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %next_a = triton_gpu.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !triton_gpu.memdesc<128x32xf8E5M2, #A> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %next_b = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !triton_gpu.memdesc<32x128xf16, #B> + %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B> - scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !triton_gpu.memdesc<128x32xf8E5M2, #A>, !triton_gpu.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A>, !ttg.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C> } tt.return %loop#4 : tensor<128x128xf32, #C> } @@ -75,20 +75,20 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr // matmul: 128x16 @ 16x128 -> 128x128 // CHECK: tt.func @matmul_loop_mixed // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.local_load %[[B0_PREFETCH_SMEM]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[a0_prefetch:.*]] = 
%[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[a0_prefetch]], %[[b0_prefetch]], {{.*}} // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] -module attributes { "triton_gpu.num-warps" = 4 : i32 } { +module attributes { "ttg.num-warps" = 4 : i32 } { tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x16x!tt.ptr, #AL> %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<16x128x!tt.ptr, #BL> @@ -103,24 +103,24 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr %b_off = arith.constant dense<4> : tensor<16x128xi32, #BL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x16x!tt.ptr, #AL> - %a_init = triton_gpu.local_alloc %a_ : (tensor<128x16xf8E5M2, #AL>) -> !triton_gpu.memdesc<128x16xf8E5M2, #A> + %a_init = ttg.local_alloc %a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A> %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> - %b_init = triton_gpu.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !triton_gpu.memdesc<16x128xf16, #B> + %b_init = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B> - %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !triton_gpu.memdesc<128x16xf8E5M2, #A>, !triton_gpu.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op_ = triton_gpu.local_load %a : !triton_gpu.memdesc<128x16xf8E5M2, #A> -> tensor<128x16xf8E5M2, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A>, !ttg.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x16xf8E5M2, #A> -> tensor<128x16xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x16xf8E5M2, #A_OP> -> tensor<128x16xf16, #A_OP> - %b_op = triton_gpu.local_load %b : !triton_gpu.memdesc<16x128xf16, #B> -> tensor<16x128xf16, #B_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<16x128xf16, #B> -> tensor<16x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x16xf16, #A_OP> * tensor<16x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x16x!tt.ptr, #AL>, tensor<128x16xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<16x128x!tt.ptr, #BL>, tensor<16x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : 
tensor<128x16x!tt.ptr, #AL> - %next_a = triton_gpu.local_alloc %next_a_ : (tensor<128x16xf8E5M2, #AL>) -> !triton_gpu.memdesc<128x16xf8E5M2, #A> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> - %next_b = triton_gpu.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !triton_gpu.memdesc<16x128xf16, #B> + %next_b = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B> - scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !triton_gpu.memdesc<128x16xf8E5M2, #A>, !triton_gpu.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A>, !ttg.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C> } tt.return %loop#4 : tensor<128x128xf32, #C> } @@ -132,10 +132,10 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr // CHECK: scf.if // CHECK: tt.store // CHECK-NOT: scf.yield -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:86", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:86", "ttg.threads-per-warp" = 32 : i32} { tt.func @matmul_loop_yield_no_operand(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %c32_i32 = arith.constant 32 : i32 @@ -155,19 +155,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %10 = arith.remsi %9, %2 : i32 %11 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> %12 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %14 = triton_gpu.convert_layout %11 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %15 = tt.dot %13, %14, %cst, inputPrecision = tf32 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf16, #blocked> -> 
tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %14 = ttg.convert_layout %11 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %15 = tt.dot %13, %14, %cst, inputPrecision = tf32 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> %16 = arith.cmpi sgt, %10, %c0_i32 : i32 %17 = scf.if %16 -> (tensor<32x32xf32, #mma>) { - %21 = tt.dot %13, %14, %15, inputPrecision = tf32 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %21 = tt.dot %13, %14, %15, inputPrecision = tf32 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> scf.yield %21 : tensor<32x32xf32, #mma> } else { scf.yield %15 : tensor<32x32xf32, #mma> } %18 = tt.splat %arg5 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> %19 = arith.truncf %17 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> - %20 = triton_gpu.convert_layout %19 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked1> + %20 = ttg.convert_layout %19 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked1> tt.store %18, %20 : tensor<32x32x!tt.ptr, #blocked1> } tt.return @@ -176,37 +176,37 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = false}> -#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = false}> +#A_OP = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_OP = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> // CHECK: tt.func @matmul_loop_mixed_amd // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.local_load 
%[[B0_PREFETCH_SMEM]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] -// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.local_load %[[A_REM_SMEM]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] // CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] -// CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] -// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.local_load %[[B_REM_SMEM]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] -module attributes { "triton_gpu.num-warps" = 4 : i32 } { +module attributes { "ttg.num-warps" = 4 : i32 } { tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> @@ -221,24 +221,24 @@ tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a_init = triton_gpu.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !triton_gpu.memdesc<128x32xf8E5M2, #A> + %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A> %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b_init = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !triton_gpu.memdesc<32x128xf16, #B> + %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B> - %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !triton_gpu.memdesc<128x32xf8E5M2, #A>, !triton_gpu.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op_ = triton_gpu.local_load %a : 
!triton_gpu.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A>, !ttg.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> - %b_op = triton_gpu.local_load %b : !triton_gpu.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %next_a = triton_gpu.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !triton_gpu.memdesc<128x32xf8E5M2, #A> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %next_b = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !triton_gpu.memdesc<32x128xf16, #B> + %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B> - scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !triton_gpu.memdesc<128x32xf8E5M2, #A>, !triton_gpu.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A>, !ttg.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C> } tt.return %loop#4 : tensor<128x128xf32, #C> } diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir index 67bf5bdbcc..bbb0de2ad1 100644 --- a/test/TritonGPU/reduce-data-duplication.mlir +++ b/test/TritonGPU/reduce-data-duplication.mlir @@ -1,14 +1,14 @@ // RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s -// CHECK: #[[$SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} +// CHECK: #[[$SHARED:.*]] = #ttg.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} // CHECK-LABEL: apply_swizzle -// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !triton_gpu.memdesc<16x256xf16, #[[$SHARED]], #triton_gpu.shared_memory> +// CHECK: %{{.*}} = ttg.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !ttg.memdesc<16x256xf16, #[[$SHARED]], #ttg.shared_memory> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = 
#ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @apply_swizzle(%arg0: tensor<16x256xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<16x256xf16, #blocked> -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %0 = ttg.convert_layout %arg0 : tensor<16x256xf16, #blocked> -> tensor<16x256xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } @@ -16,13 +16,13 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- // CHECK-LABEL: conversion_shortcut_blocked_dotop_warp32 -// CHECK-NOT: triton_gpu.local_alloc -// CHECK: triton_gpu.convert_layout -// CHECK-NOT: triton_gpu.local_alloc -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [0, 1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-NOT: ttg.local_alloc +// CHECK: ttg.convert_layout +// CHECK-NOT: ttg.local_alloc +#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @conversion_shortcut_blocked_dotop_warp32(%arg0: tensor<64x64xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -30,13 +30,13 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- // CHECK-LABEL: conversion_shortcut_blocked_dotop_warp64 -// CHECK-NOT: triton_gpu.local_alloc -// CHECK: triton_gpu.convert_layout -// CHECK-NOT: triton_gpu.local_alloc -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [0, 1]}> -module attributes {"triton_gpu.target" = "hip:gfx940", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK-NOT: ttg.local_alloc +// CHECK: ttg.convert_layout +// CHECK-NOT: ttg.local_alloc +#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"ttg.target" = "hip:gfx940", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func @conversion_shortcut_blocked_dotop_warp64(%arg0: tensor<64x64xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } diff --git a/test/TritonGPU/reorder-instructions.mlir b/test/TritonGPU/reorder-instructions.mlir index 28f8d385cf..c95a1cedc5 100644 --- a/test/TritonGPU/reorder-instructions.mlir +++ b/test/TritonGPU/reorder-instructions.mlir @@ -2,21 +2,21 @@ // check that we don't hoist convert_layout above its operand definition. 
// CHECK-LABEL: convert_cannot_hoist -// CHECK: %[[CVTS:.+]] = triton_gpu.local_alloc -// CHECK: triton_gpu.local_load %[[CVTS]] +// CHECK: %[[CVTS:.+]] = ttg.local_alloc +// CHECK: ttg.local_load %[[CVTS]] // CHECK: tt.dot -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @convert_cannot_hoist(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !triton_gpu.memdesc<32x32xf32, #shared> - %11 = triton_gpu.local_load %10 : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %10 = ttg.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> + %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> tt.return } @@ -25,21 +25,21 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- // CHECK-LABEL: sink_convert_dealloc -// CHECK: triton_gpu.async_wait {num = 0 : i32} -// CHECK: triton_gpu.local_dealloc %0 : !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> -// CHECK: triton_gpu.local_dealloc %1 : !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> -// CHECK: %3 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// 
CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, mutable> +// CHECK: ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, mutable> +// CHECK: %3 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { - %0 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> - %1 = triton_gpu.local_alloc : () -> !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> - %2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> - triton_gpu.async_wait {num = 0 : i32} - triton_gpu.local_dealloc %0 : !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> - triton_gpu.local_dealloc %1 : !triton_gpu.memdesc<4x128x64xf16, #shared, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable> + %2 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, mutable> %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> tt.return } @@ -48,24 +48,24 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- // CHECK-LABEL: sink_convert_idx_1 -// CHECK: triton_gpu.local_load %{{.*}} : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -// CHECK: triton_gpu.local_load %{{.*}} : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> // CHECK: tt.dot -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sink_convert_idx_1(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %B = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %BS = 
triton_gpu.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !triton_gpu.memdesc<32x32xf32, #shared> - %BD = triton_gpu.local_load %BS : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> + %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %A = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %AS = triton_gpu.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !triton_gpu.memdesc<32x32xf32, #shared> - %AD = triton_gpu.local_load %AS : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %AD, %BD, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %AS = ttg.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> + %AD = ttg.local_load %AS : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %12 = tt.dot %AD, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> tt.return } @@ -75,28 +75,28 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // check that we don't sink convert_layout if it has multi users // CHECK-LABEL: convert_cannot_sink -// CHECK: triton_gpu.local_load %{{.*}} : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -// CHECK: triton_gpu.local_load %{{.*}} : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> // CHECK: tt.dot -// CHECK: triton_gpu.local_load %{{.*}} : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> // CHECK: tt.dot -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 
1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @convert_cannot_sink(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %B = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %BS = triton_gpu.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !triton_gpu.memdesc<32x32xf32, #shared> - %BD = triton_gpu.local_load %BS : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> + %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %A0 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %AS0 = triton_gpu.local_alloc %A0 : (tensor<32x32xf32, #blocked>) -> !triton_gpu.memdesc<32x32xf32, #shared> - %AD0 = triton_gpu.local_load %AS0 : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %AD0, %BD, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %AS0 = ttg.local_alloc %A0 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> + %AD0 = ttg.local_load %AS0 : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %12 = tt.dot %AD0, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %A1 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %AS1 = triton_gpu.local_alloc %A1 : (tensor<32x32xf32, #blocked>) -> !triton_gpu.memdesc<32x32xf32, #shared> - %AD1 = triton_gpu.local_load %AS1 : !triton_gpu.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %13 = tt.dot %AD1, %BD, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %AS1 = ttg.local_alloc %A1 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> + %AD1 = ttg.local_load %AS1 : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %13 = tt.dot %AD1, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> tt.return } } diff --git a/test/TritonGPU/tritongpu_ops.mlir b/test/TritonGPU/tritongpu_ops.mlir index 3fc0585b12..d3dc5277e1 100644 --- a/test/TritonGPU/tritongpu_ops.mlir +++ b/test/TritonGPU/tritongpu_ops.mlir @@ -1,11 +1,11 @@ // RUN: triton-opt %s | triton-opt | FileCheck %s -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#shared0 = #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], 
CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> -module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: memdesc - // CHECK-SAME: !triton_gpu.memdesc<1x64x16xf16, #{{.+}}> - tt.func @memdesc(%d : !triton_gpu.memdesc<1x64x16xf16, #shared0>) { + // CHECK-SAME: !ttg.memdesc<1x64x16xf16, #{{.+}}> + tt.func @memdesc(%d : !ttg.memdesc<1x64x16xf16, #shared0>) { tt.return } } diff --git a/test/TritonGPU/verify-blocked-layout.mlir b/test/TritonGPU/verify-blocked-layout.mlir index ec39b26d10..3c1d016cd5 100644 --- a/test/TritonGPU/verify-blocked-layout.mlir +++ b/test/TritonGPU/verify-blocked-layout.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt --split-input-file %s --verify-diagnostics -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[16, 1], warpsPerCTA=[4, 1], @@ -10,9 +10,9 @@ CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: !tt.ptr) { // expected-error @+1 {{threads per warp}} @@ -23,7 +23,7 @@ module attributes { // ----- -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 2], @@ -33,9 +33,9 @@ module attributes { CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: !tt.ptr) { // expected-error @+1 {{warps per CTA}} @@ -46,7 +46,7 @@ module attributes { // ----- -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], @@ -56,9 +56,9 @@ module attributes { CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: !tt.ptr) { // expected-error @+1 {{CTAs per CGA}} @@ -69,7 +69,7 @@ module attributes { // ----- -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], @@ -79,9 +79,9 @@ module attributes { CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: !tt.ptr) { // Note it's a 3d tensor here, but #blocked is 2D. 
@@ -93,7 +93,7 @@ module attributes { // ----- -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], @@ -103,9 +103,9 @@ module attributes { CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: tensor<8xf32, #blocked>) { // expected-error @+1 {{rank}} diff --git a/test/TritonNvidiaGPU/membar.mlir b/test/TritonNvidiaGPU/membar.mlir index 924216222a..2085cbf213 100644 --- a/test/TritonNvidiaGPU/membar.mlir +++ b/test/TritonNvidiaGPU/membar.mlir @@ -1,25 +1,25 @@ // RUN: triton-opt %s -split-input-file --triton-nvidia-tma-lowering --allocate-shared-memory -test-print-membar | FileCheck %s -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: init_barrier // CHECK: local_alloc // CHECK-NEXT: gpu.barrier // CHECK-NEXT: init_barrier tt.func @init_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !triton_gpu.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !triton_gpu.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: inval_barrier // CHECK: local_alloc // CHECK-NEXT: gpu.barrier @@ -28,18 +28,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: inval_barrier tt.func @inval_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !triton_gpu.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier 
%alloc, 1 : !triton_gpu.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.inval_barrier %alloc : !triton_gpu.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> + ttng.inval_barrier %alloc : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: barrier_expect // CHECK: local_alloc // CHECK-NEXT: gpu.barrier @@ -48,18 +48,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: barrier_expect tt.func @barrier_expect(%pred : i1) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !triton_gpu.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !triton_gpu.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> + ttng.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, #ttg.shared_memory, mutable> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: wait_barrier // CHECK: local_alloc // CHECK-NEXT: gpu.barrier @@ -68,9 +68,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: wait_barrier tt.func @wait_barrier(%phase : i32) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !triton_gpu.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !triton_gpu.memdesc<1xi64, 
#shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %alloc, %phase : <1xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> + ttng.wait_barrier %alloc, %phase : <1xi64, #shared0, #ttg.shared_memory, mutable> tt.return } } @@ -78,9 +78,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked0> { // CHECK-LABEL: tma_load // CHECK: local_dealloc @@ -89,8 +89,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: gpu.barrier // CHECK-NEXT: init_barrier %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !triton_gpu.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %alloc : !triton_gpu.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #ttg.shared_memory, mutable> + ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #ttg.shared_memory, mutable> %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked0> tt.return %l : tensor<128x64xf16, #blocked0> } @@ -98,18 +98,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_store -// CHECK: triton_gpu.local_alloc -// CHECK-NEXT: triton_gpu.local_dealloc +// CHECK: ttg.local_alloc +// CHECK-NEXT: ttg.local_dealloc // CHECK-NEXT: gpu.barrier -// CHECK-NEXT: triton_gpu.local_alloc +// CHECK-NEXT: ttg.local_alloc 
tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !triton_gpu.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %alloc : !triton_gpu.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #ttg.shared_memory, mutable> + ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #ttg.shared_memory, mutable> tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked0> tt.return } diff --git a/test/TritonNvidiaGPU/tma_lowering.mlir b/test/TritonNvidiaGPU/tma_lowering.mlir index dc8113ca8a..dbde678e55 100644 --- a/test/TritonNvidiaGPU/tma_lowering.mlir +++ b/test/TritonNvidiaGPU/tma_lowering.mlir @@ -1,16 +1,16 @@ // RUN: triton-opt %s -split-input-file --triton-nvidia-tma-lowering | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_load -// CHECK: triton_gpu.local_alloc : () -// CHECK: triton_gpu.local_alloc : () -// CHECK: triton_nvidia_gpu.init_barrier -// CHECK: triton_nvidia_gpu.tensor_desc_to_tma_ptr -// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local -// CHECK: triton_nvidia_gpu.wait_barrier -// CHECK: triton_nvidia_gpu.inval_barrier -// CHECK: triton_gpu.local_load +// CHECK: ttg.local_alloc : () +// CHECK: ttg.local_alloc : () +// CHECK: ttng.init_barrier +// CHECK: ttng.tensor_desc_to_tma_ptr +// CHECK: ttng.async_tma_copy_global_to_local +// CHECK: ttng.wait_barrier +// CHECK: ttng.inval_barrier +// CHECK: ttg.local_load tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked> { %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> tt.return %l : tensor<128x64xf16, #blocked> @@ -19,13 +19,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_store -// CHECK: triton_gpu.local_alloc -// CHECK: triton_nvidia_gpu.fence_async_shared {bCluster = false} -// CHECK: triton_nvidia_gpu.tensor_desc_to_tma_ptr -// CHECK: triton_nvidia_gpu.async_tma_copy_local_to_global +// CHECK: ttg.local_alloc +// CHECK: ttng.fence_async_shared {bCluster = false} +// CHECK: ttng.tensor_desc_to_tma_ptr 
+// CHECK: ttng.async_tma_copy_local_to_global
   tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked>) {
     tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked>
     tt.return
@@ -34,10 +34,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

 // -----

-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: make_tensor_descriptor
 // CHECK: %0 = arith.extsi %arg2 : i32 to i64
- // CHECK: %1 = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr
+ // CHECK: %1 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr
 // CHECK: %2 = arith.shrsi %0, %c4_i64 : i64
 // CHECK: tt.experimental_tensormap_create %1, %arg0, [%c32_i32, %c8_i32], [%arg2, %arg1], [%2], [%c1_i32, %c1_i32] {elem_type = 0 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 1 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> ()
 // CHECK: tt.experimental_tensormap_fenceproxy_acquire %1 : !tt.ptr
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index 0029c76ec3..81b07f2e7d 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -339,7 +339,7 @@ def make_llir(src, metadata, options):
     llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)

     # Get some metadata
-    metadata["shared"] = src.get_int_attr("triton_gpu.shared")
+    metadata["shared"] = src.get_int_attr("ttg.shared")
     amd.cleanup_bitcode_metadata(llvm_mod)
     # Disable inlining of print related functions,
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
index 0b865cd1b8..b2f857e40a 100644
--- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
+++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
@@ -65,11 +65,11 @@ def ExtractSliceOp
     Example 1:

     ```mlir
-    #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8],
+    #blocked = #ttg.blocked<{sizePerThread = [1, 8],
         threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
-    #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8],
+    #blocked1 = #ttg.blocked<{sizePerThread = [1, 8],
         threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}>
-    %1 = triton_gpu.convert_layout %0 : tensor<128x128xf16, #blocked>
+    %1 = ttg.convert_layout %0 : tensor<128x128xf16, #blocked>
         -> tensor<128x128xf16, #blocked1>
     // create a slice of base tensor %1 with static offsets
     %2 = amdgpu.extract_slice %0 [0, 0] :
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
index eec62e3503..502e4ca4a6 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -95,9 +95,9 @@ struct ConvertTritonAMDGPUToLLVM
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

     // Hack: WSMaterialization may have changed the effective number of warps,
-    // in a way that isn't reflected in triton_gpu.num-warps.  If so, we have to
+    // in a way that isn't reflected in ttg.num-warps.  If so, we have to
     // respect that here.
-    if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) {
+    if (Attribute attr = mod->getAttr("ttg.num-warp-groups-per-cta")) {
       numWarps *= cast(attr).getInt();
     }
diff --git a/third_party/amd/python/test/test_extract_slice.py b/third_party/amd/python/test/test_extract_slice.py
index bca3603840..a9c7df4754 100644
--- a/third_party/amd/python/test/test_extract_slice.py
+++ b/third_party/amd/python/test/test_extract_slice.py
@@ -11,7 +11,7 @@

 num_ctas_list = [1]

-GPU_DIALECT = "triton_gpu"
+GPU_DIALECT = "ttg"

 if is_hip():
     THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size
@@ -67,33 +67,33 @@ def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_t
     ir = f"""
     #blocked = {blocked_layout}
     #extract_layout = {extract_layout}
-    module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {str(64)} : i32}} {{
+    module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {str(64)} : i32}} {{
     tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{
     %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
     %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked>
-    %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
-    %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
-    %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
+    %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>>
+    %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>>
+    %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>>
     %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked>
-    %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
-    %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked>
+    %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
+    %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked>
     %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked>
     %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked>
-    %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked>
+    %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked>
     %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
     %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
     %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked>
-    %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
+    %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>>
     %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>
-    %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked>
+    %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked>
     %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
     %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
     %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
     %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked>
     %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked>
-    %12 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout>
+    %12 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout>
     %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #extract_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout>
-    %14 = triton_gpu.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked>
+    %14 = ttg.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked>
     %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
     tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>
     tt.return
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 6d6d70fc87..f5970f5956 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -252,7 +252,7 @@ def make_llir(src, metadata, options, capability):
     ptx_version = get_ptx_version_from_options(options)

     # warp-specialization mutates num_warps
-    num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
+    num_warp_groups = src.get_int_attr("ttg.num-warp-groups-per-cta")
     if num_warp_groups is not None:
         metadata["num_warps"] *= num_warp_groups
     mod = src
@@ -303,9 +303,9 @@ def make_llir(src, metadata, options, capability):
     llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

     # Get some metadata
-    metadata["shared"] = src.get_int_attr("triton_gpu.shared")
-    metadata["global_scratch_size"] = src.get_int_attr("triton_gpu.global_scratch_memory_size")
-    metadata["global_scratch_align"] = src.get_int_attr("triton_gpu.global_scratch_memory_alignment")
+    metadata["shared"] = src.get_int_attr("ttg.shared")
+    metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
+    metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
     ret = str(llvm_mod)
     del llvm_mod
     del context
diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp
index 779bc1b788..7a34955a96 100644
--- a/unittest/Dialect/TritonGPU/DialectTest.cpp
+++ b/unittest/Dialect/TritonGPU/DialectTest.cpp
@@ -43,7 +43,7 @@ std::string strReplace(std::string s, const std::string &from,
 // We use some abbreviations when spelling out MLIR types.
 std::string expandTyStr(std::string s) {
   s = strReplace(s, "T<", "tensor<");
-  s = strReplace(s, "#B", "#triton_gpu.blocked");
+  s = strReplace(s, "#B", "#ttg.blocked");
   s = strReplace(s, "spt", "sizePerThread");
   s = strReplace(s, "tpw", "threadsPerWarp");
   s = strReplace(s, "wpc", "warpsPerCTA");