
[MLIR][Vector] Replace vector.transpose with vector.shape_cast #125966

Closed

Conversation

ita9naiwa
Contributor

Define the permutation width as one past the last index in the permutation array that differs from its own position. This pattern applies to transpose operations whose input vector has at most one non-unit dimension within the permutation width, and replaces the transpose with a shape cast.
For example:
%0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32>
is replaced by
%0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
since the permutation width here is 2.

This work (#94912) is credited to @pashu123.
I added some tests and fixed others, but I didn't know how to push to his branch, so I created a new PR.

pashu123 and others added 2 commits February 5, 2025 21:37
@llvmbot
Member

llvmbot commented Feb 6, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Hyunsung Lee (ita9naiwa)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/125966.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+56-2)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (+17-35)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80f..a29ba47b28cde15 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -450,6 +450,59 @@ class Transpose2DWithUnitDimToShapeCast
   }
 };
 
+// Suppose the permutation width is defined as the last index in the permutation
+// array that is not equal to its index. This pattern is applied to transpose
+// operations where the input vector has a shape with at most one non-unit
+// dimension up to the permutation width. The pattern replaces the transpose
+// operation with a shape cast operation.
+// For example:
+//  %0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32>
+//  is replaced by
+//  %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32>
+//  given the permutation width is 2.
+class TransposeWithUnitDimToShapeCast
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  TransposeWithUnitDimToShapeCast(MLIRContext *context,
+                                  PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getVector();
+    VectorType inputType = op.getSourceVectorType();
+    if (inputType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "This lowering does not support scalable vectors");
+    VectorType resType = op.getResultVectorType();
+
+    ArrayRef<int64_t> transp = op.getPermutation();
+
+    // Get the permutation width.
+    int64_t permWidth = 1;
+    for (auto &&[idx, val] : llvm::enumerate(transp)) {
+      if (static_cast<int64_t>(idx) != val)
+        permWidth = idx + 1;
+    }
+
+    // Check that the number of non-unit dims in the input shape up to the
+    // permutation width is not greater than one.
+    auto inputShape = inputType.getShape();
+
+    int64_t countNonUnitDims = 0;
+    for (int i = 0; i < permWidth; i++) {
+      if (inputShape[i] != 1)
+        countNonUnitDims++;
+      if (countNonUnitDims > 1)
+        return failure();
+    }
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+    return success();
+  }
+};
+
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -522,8 +575,9 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns, VectorTransformsOptions options,
     PatternBenefit benefit) {
-  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
-                                                  benefit);
+  patterns
+      .add<Transpose2DWithUnitDimToShapeCast, TransposeWithUnitDimToShapeCast>(
+          patterns.getContext(), benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       options, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 219a72df52a19c9..68e408488cf06f0 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -23,53 +23,21 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
 
 // CHECK-LABEL: func @transpose102_1x8x8xf32
 func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // CHECK: %0 = vector.shape_cast %arg0 : vector<1x8x8xf32> to vector<8x1x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32>
   return %0 : vector<8x1x8xf32>
 }
 
 // CHECK-LABEL: func @transpose102_8x1x8xf32
 func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32>
+  // CHECK: %0 = vector.shape_cast %arg0 : vector<8x1x8xf32> to vector<1x8x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32>
   return %0 : vector<1x8x8xf32>
 }
 
 // CHECK-LABEL:   func @transpose1023_1x1x8x8xf32(
 func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
-  // Note the single 2-D extract/insert pair since 2 and 3 are not transposed!
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
+  // CHECK: return %arg0 : vector<1x1x8x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
   return %0 : vector<1x1x8x8xf32>
 }
@@ -386,6 +354,20 @@ func.func @transpose10_4x1xf32_scalable(%arg0: vector<4x[1]xf32>) -> vector<[1]x
   return %0 : vector<[1]x4xf32>
 }
 
+// CHECK-LABEL: func @transpose_nd1
+func.func @transpose_nd1(%arg0: vector<1x2x1x16xf32>) -> vector<1x1x2x16xf32> {
+  // CHECK-NEXT: vector.shape_cast %arg0 : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
+  %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x2x1x16xf32> to vector<1x1x2x16xf32>
+  return %0 : vector<1x1x2x16xf32>
+}
+
+// CHECK-LABEL: func @transpose_nd2
+func.func @transpose_nd2(%arg0: vector<1x1x2x16xf32>) -> vector<1x2x1x16xf32> {
+  // CHECK-NEXT: vector.shape_cast %arg0 : vector<1x1x2x16xf32> to vector<1x2x1x16xf32>
+  %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x1x2x16xf32> to vector<1x2x1x16xf32>
+  return %0 : vector<1x2x1x16xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

@dcaballe
Contributor

dcaballe commented Feb 6, 2025

this work(#94912) is credited @pashu123
I added some tests and fixed tests but didn't know how to push on his branch and created new PR.

You can change the commit's author with:

git commit --amend --author="Author Name <[email protected]>" --no-edit

Or add multiple authors with:

https://docs.github.com/en/pull-requests/committing-changes-to-your-project/creating-and-editing-commits/creating-a-commit-with-multiple-authors
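As that page describes, GitHub also recognizes `Co-authored-by:` trailers at the end of a commit message (placeholder name and email below):

```
Replace vector.transpose with vector.shape_cast

Co-authored-by: Author Name <author@example.com>
```

Note the trailer must be preceded by a blank line for GitHub to pick it up.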

@dcaballe
Contributor

dcaballe commented Feb 6, 2025

Didn't we have people with strong opinions about the transpose -> shape_cast canonicalization because it's supposed to drop some information? Not sure I remember the details... @hanhanW, @MaheshRavishankar?

@MaheshRavishankar
Contributor

Thanks @dcaballe for the ping. I think the issue was mainly with having this as a canonicalization; as a vector lowering pattern this might be OK. One thing to consider: this lowers to a shape_cast, but there is also a populateVectorShapeCastLoweringPatterns. Are the shape casts generated here handled by the shape cast lowering? If so, this change seems OK to me.

cc @kuhar

@hanhanW
Contributor

hanhanW commented Feb 7, 2025

The author is trying to address one of my old issues, which proposes dropping unit dims for transpose ops when the transpose is a no-op. iree-org/iree#17593

I revisited the issue because it was quite old, and found that we already have such a pattern to handle the transpose op; the no-ops get folded away as expected. This kind of pattern is categorized as dropXXXUnitDim, and we already have such support. It landed as da8778e, so I think we no longer need this PR.

(Sorry @ita9naiwa, I should have revisited it first.)

@ita9naiwa closed this Feb 7, 2025