From 57c3dc672a1bd4947ebe5e5184eb6c6420a876aa Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Thu, 5 Dec 2024 17:27:30 +0000
Subject: [PATCH 1/2] Reduce diff with upstream code

Signed-off-by: Tiotto, Ettore
---
 bin/triton-tensor-layout.cpp | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp
index 49a07681ed..7c635dafaa 100644
--- a/bin/triton-tensor-layout.cpp
+++ b/bin/triton-tensor-layout.cpp
@@ -80,9 +80,17 @@ static cl::opt<std::string> TensorStr(
 //===--------------------------------------------------------------------===//
 
 LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
+  StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace();
+
   // Dispatch to the corresponding dialect helper function to print the layout.
-  os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
-  return success();
+  if (dialectName == "ttg") {
+    os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
+    return success();
+  }
+
+  llvm::errs() << "Unsupported tensor layout attribute: "
+               << tensorType.getEncoding() << "\n";
+  return failure();
 }
 
 LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,

From d3935c9533ba9cc7cd04e3a0468fdfe528f4603c Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Thu, 5 Dec 2024 18:07:59 +0000
Subject: [PATCH 2/2] Reduce diff with upstream code

Signed-off-by: Tiotto, Ettore
---
 .../TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
index 1568341deb..d5afb6e2b1 100644
--- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -24,8 +24,8 @@ void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) {
   int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
   module.walk([&](triton::SplatOp splatOp) -> void {
     auto dstType = cast<RankedTensorType>(splatOp.getType());
-    auto shared = dyn_cast_or_null<triton::gpu::SharedEncodingAttr>(
-        dstType.getEncoding());
+    auto shared =
+        dyn_cast<triton::gpu::SharedEncodingAttr>(dstType.getEncoding());
     if (shared) {
       OpBuilder builder(splatOp);
       SmallVector<unsigned> sizePerThread(dstType.getRank(), 1);