[MLIR][Vector] Replace vector.transpose with vector.shape_cast #125966
Conversation
Suppose the permutation width is defined as one past the last index in the permutation array whose value differs from its position. This pattern applies to transpose operations where the input vector has at most one non-unit dimension among the leading dimensions up to the permutation width. The pattern replaces the transpose operation with a shape cast. For example, %0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32> is replaced by %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>, since the permutation width here is 2.
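For reference, here is a minimal standalone sketch of the match condition described above, written as plain C++ with illustrative names; it is not part of the patch, which implements the same check inside the rewrite pattern shown in the diff below.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sketch only (not the actual pattern): the transpose can be
// replaced by a shape_cast when the input shape has at most one non-unit
// dimension among the leading dims up to the permutation width.
bool canReplaceTransposeWithShapeCast(const std::vector<int64_t> &shape,
                                      const std::vector<int64_t> &perm) {
  // Permutation width: one past the last index whose value differs from its
  // position.
  int64_t permWidth = 1;
  for (size_t idx = 0; idx < perm.size(); ++idx)
    if (static_cast<int64_t>(idx) != perm[idx])
      permWidth = static_cast<int64_t>(idx) + 1;

  // Count non-unit dimensions among the leading permWidth dims.
  int64_t nonUnitDims = 0;
  for (int64_t i = 0; i < permWidth; ++i)
    if (shape[i] != 1)
      ++nonUnitDims;
  return nonUnitDims <= 1;
}

int main() {
  // vector<4x1x1xi32>, permutation [1, 0, 2]: width 2, one non-unit dim -> ok.
  assert(canReplaceTransposeWithShapeCast({4, 1, 1}, {1, 0, 2}));
  // vector<4x2x8xi32>, permutation [1, 0, 2]: width 2, two non-unit dims -> no.
  assert(!canReplaceTransposeWithShapeCast({4, 2, 8}, {1, 0, 2}));
  return 0;
}
```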
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir
Author: Hyunsung Lee (ita9naiwa)
Changes: this work (#94912) is credited to @pashu123.
Full diff: https://github.com/llvm/llvm-project/pull/125966.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80f..a29ba47b28cde15 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -450,6 +450,59 @@ class Transpose2DWithUnitDimToShapeCast
}
};
+// Suppose the permutation width is defined as one past the last index in the
+// permutation array whose value differs from its position. This pattern
+// applies to transpose operations where the input vector has at most one
+// non-unit dimension among the leading dimensions up to the permutation
+// width. The pattern replaces the transpose operation with a shape cast.
+// For example:
+// %0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32>
+// is replaced by
+// %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
+// given the permutation width is 2.
+class TransposeWithUnitDimToShapeCast
+ : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ TransposeWithUnitDimToShapeCast(MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+ LogicalResult matchAndRewrite(vector::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ Value input = op.getVector();
+ VectorType inputType = op.getSourceVectorType();
+ if (inputType.isScalable())
+ return rewriter.notifyMatchFailure(
+ op, "This lowering does not support scalable vectors");
+ VectorType resType = op.getResultVectorType();
+
+ ArrayRef<int64_t> transp = op.getPermutation();
+
+ // Get the permutation width.
+ int64_t permWidth = 1;
+ for (auto &&[idx, val] : llvm::enumerate(transp)) {
+ if (static_cast<int64_t>(idx) != val)
+ permWidth = idx + 1;
+ }
+
+ // Check that the number of non-unit dims in the input shape, up to the
+ // permutation width, is not greater than one.
+ auto inputShape = inputType.getShape();
+
+ int64_t countNonUnitDims = 0;
+ for (int i = 0; i < permWidth; i++) {
+ if (inputShape[i] != 1)
+ countNonUnitDims++;
+ if (countNonUnitDims > 1)
+ return failure();
+ }
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+ return success();
+ }
+};
+
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -522,8 +575,9 @@ class TransposeOp2DToShuffleLowering
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit) {
- patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
- benefit);
+ patterns
+ .add<Transpose2DWithUnitDimToShapeCast, TransposeWithUnitDimToShapeCast>(
+ patterns.getContext(), benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
options, patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 219a72df52a19c9..68e408488cf06f0 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -23,53 +23,21 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
// CHECK-LABEL: func @transpose102_1x8x8xf32
func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
- // CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32>
+ // CHECK: %0 = vector.shape_cast %arg0 : vector<1x8x8xf32> to vector<8x1x8xf32>
%0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32>
return %0 : vector<8x1x8xf32>
}
// CHECK-LABEL: func @transpose102_8x1x8xf32
func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
- // CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32>
- // CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32>
+ // CHECK: %0 = vector.shape_cast %arg0 : vector<8x1x8xf32> to vector<1x8x8xf32>
%0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32>
return %0 : vector<1x8x8xf32>
}
// CHECK-LABEL: func @transpose1023_1x1x8x8xf32(
func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
- // Note the single 2-D extract/insert pair since 2 and 3 are not transposed!
- // CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
+ // CHECK: return %arg0 : vector<1x1x8x8xf32>
%0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
return %0 : vector<1x1x8x8xf32>
}
@@ -386,6 +354,20 @@ func.func @transpose10_4x1xf32_scalable(%arg0: vector<4x[1]xf32>) -> vector<[1]x
return %0 : vector<[1]x4xf32>
}
+// CHECK-LABEL: func @transpose_nd1
+func.func @transpose_nd1(%arg0: vector<1x2x1x16xf32>) -> vector<1x1x2x16xf32> {
+ // CHECK-NEXT: vector.shape_cast %arg0 : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
+ %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
+ return %0 : vector<1x1x2x16xf32>
+}
+
+// CHECK-LABEL: func @transpose_nd2
+func.func @transpose_nd2(%arg0: vector<1x1x2x16xf32>) -> vector<1x2x1x16xf32> {
+ // CHECK-NEXT: vector.shape_cast %arg0 : vector<1x1x2x16xf32> to vector<1x2x1x16xf32>
+ %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x1x2x16xf32> to vector<1x2x1x16xf32>
+ return %0 : vector<1x2x1x16xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
You can change the commit's author with:
Or add multiple authors with:
Didn't we have people with strong opinions about the transpose -> shape_cast canonicalization because it's supposed to drop some information? Not sure I remember the details... @hanhanW, @MaheshRavishankar?
Thanks @dcaballe for the ping. I think the issue was mainly w.r.t. having this as a canonicalization. As a vector lowering pattern this might be OK. One thing to consider is that this is lowering to a shape_cast, but there is also a
cc @kuhar
The author is trying to address one of my old issues, which proposes dropping unit dims for transpose ops when they are no-ops: iree-org/iree#17593. I revisited the issue because it was quite old, and I found that we already have such a pattern to handle the transpose op; the no-ops get folded away as expected. This kind of pattern falls under the dropXXXUnitDim category, and we already have such support. It landed as da8778e, so I think we no longer need this PR. (Sorry @ita9naiwa, I should have revisited it first.)
This work (#94912) is credited to @pashu123.
I added some tests and fixed existing ones, but I didn't know how to push to his branch, so I created a new PR.