Skip to content

Commit

Permalink
[StableHLO] Bugfix to disable reorder around reshape/broadcast with d…
Browse files Browse the repository at this point in the history
…ynamic shape

PiperOrigin-RevId: 719322804
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Jan 24, 2025
1 parent 70ce106 commit ed34ba7
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,27 @@ diff --ruN a/stablehlo/stablehlo/tests/interpret/chlo/ragged_dot.mlir b/stablehl
+ ]> : tensor<2x11x7xf32>
+ func.return
+}
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
@@ -1940,6 +1940,17 @@
return %1 : tensor<12xi64>
}

+// -----
+
+// CHECK-LABEL: @reorder_invalid_with_dynamic_shape
+func.func @reorder_invalid_with_dynamic_shape(%arg0: tensor<1x3x4xf32>) -> (tensor<?x4xf32>) {
+ // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
+ // CHECK-NEXT: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<3x4xf32>) -> tensor<?x4xf32>
+ // CHECK: return %[[CONVERT]]
+ %0 = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
+ %1 = stablehlo.convert %0 : (tensor<3x4xf32>) -> tensor<?x4xf32>
+ return %1 : tensor<?x4xf32>
+}

// -----

diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir
Expand Down Expand Up @@ -728,7 +749,27 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cp
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -1470,6 +1473,9 @@
@@ -1447,12 +1450,18 @@
return rewriter.notifyMatchFailure(
op, "defining operation of unexpected type");

+ // Reshape and broadcast are not allowed to have dynamic shape.
+ Value result = op->getResult(0);
+ if (isa<ReshapeOp, BroadcastOp>(definingOp) &&
+ !cast<ShapedType>(result.getType()).hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ op, "cannot reorder around reshape/broadcast with dynamic shape");
+
// Only reorder if the defining op has no other uses.
if (!llvm::hasSingleElement(definingOp->getResult(0).getUses()))
return rewriter.notifyMatchFailure(op, "operation has more than one use");

Value input = definingOp->getOperand(0);
- Value result = op->getResult(0);
auto intermediateType = cast<ShapedType>(input.getType())
.clone(getElementTypeOrSelf(result.getType()));

@@ -1470,6 +1479,9 @@
struct StablehloAggressiveSimplificationPass final
: impl::StablehloAggressiveSimplificationPassBase<
StablehloAggressiveSimplificationPass> {
Expand All @@ -738,7 +779,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cp
LogicalResult initialize(MLIRContext *context) override {
RewritePatternSet patterns_(context);
populateStablehloCanonicalizationPatterns(context, &patterns_);
@@ -1478,11 +1484,12 @@
@@ -1478,11 +1490,12 @@
}

void runOnOperation() override {
Expand All @@ -752,7 +793,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cp
FrozenRewritePatternSet patterns;
};

@@ -1515,5 +1522,10 @@
@@ -1515,5 +1528,10 @@
DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);
}

Expand Down

0 comments on commit ed34ba7

Please sign in to comment.