diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index fcd5cdee20..fd2fab7c07 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -421,7 +421,10 @@ class DecomposeScaledBlocked : public OpRewritePattern { if (!scale) return v; - return rewriter.create(v.getLoc(), v, scale, elemType); + auto retTy = triton::gpu::UpcastMXFPOp::deduceOutputType( + v, elemType, Builder(v.getContext()).getBF16Type()); + return rewriter.create(v.getLoc(), retTy, v, scale, + elemType); } };