diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 81045776537438..e9cad1d6ecfa01 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -397,6 +397,27 @@ diff --ruN a/stablehlo/stablehlo/tests/interpret/chlo/ragged_dot.mlir b/stablehl + ]> : tensor<2x11x7xf32> + func.return +} +diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +--- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir ++++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +@@ -1940,6 +1940,17 @@ + return %1 : tensor<12xi64> + } + ++// ----- ++ ++// CHECK-LABEL: @reorder_invalid_with_dynamic_shape ++func.func @reorder_invalid_with_dynamic_shape(%arg0: tensor<1x3x4xf32>) -> (tensor) { ++ // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> ++ // CHECK-NEXT: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<3x4xf32>) -> tensor ++ // CHECK: return %[[CONVERT]] ++ %0 = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> ++ %1 = stablehlo.convert %0 : (tensor<3x4xf32>) -> tensor ++ return %1 : tensor ++} + + // ----- + diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir --- stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir +++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir @@ -728,7 +749,27 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cp #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -@@ -1470,6 +1473,9 @@ +@@ -1447,12 +1450,18 @@ + return rewriter.notifyMatchFailure( + op, "defining operation of unexpected type"); + ++ // Reshape and broadcast are not allowed to have dynamic shape. ++ Value result = op->getResult(0); ++ if (isa(definingOp) && ++ !cast(result.getType()).hasStaticShape()) ++ return rewriter.notifyMatchFailure( ++ op, "cannot reorder around reshape/broadcast with dynamic shape"); ++ + // Only reorder if the defining op has no other uses. + if (!llvm::hasSingleElement(definingOp->getResult(0).getUses())) + return rewriter.notifyMatchFailure(op, "operation has more than one use"); + + Value input = definingOp->getOperand(0); +- Value result = op->getResult(0); + auto intermediateType = cast(input.getType()) + .clone(getElementTypeOrSelf(result.getType())); + +@@ -1470,6 +1479,9 @@ struct StablehloAggressiveSimplificationPass final : impl::StablehloAggressiveSimplificationPassBase< StablehloAggressiveSimplificationPass> { @@ -738,7 +779,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cp LogicalResult initialize(MLIRContext *context) override { RewritePatternSet patterns_(context); populateStablehloCanonicalizationPatterns(context, &patterns_); -@@ -1478,11 +1484,12 @@ +@@ -1478,11 +1490,12 @@ } void runOnOperation() override { @@ -752,7 +793,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cp FrozenRewritePatternSet patterns; }; -@@ -1515,5 +1522,10 @@ +@@ -1515,5 +1528,10 @@ DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context); }