Skip to content

Commit

Permalink
Fix biasScaleShape of GroupNormalizationV21 to support ranks > 4 (#3030)
Browse files Browse the repository at this point in the history
Signed-off-by: Rickert, Jonas <[email protected]>
  • Loading branch information
jorickert authored Jan 2, 2025
1 parent f6d0806 commit c0d447f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/Transforms/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,8 @@ LogicalResult ONNXGroupNormalizationCommon(

// Calculate the (possible) dynamic dimensions for biasScaleShape
Value NGShape = create.onnx.constantInt64({numGroups});
Value oneDimShape = create.onnx.constantInt64({1, 1});
Value oneDimShape =
create.onnx.constantInt64(SmallVector<int64_t>(spacialRank, 1));
Type biasScaleShapeType =
RankedTensorType::get({inputRank}, rewriter.getI64Type());
Value biasScaleShape = create.onnx.concat(
Expand Down
28 changes: 14 additions & 14 deletions test/mlir/onnx/onnx_decompose.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -565,24 +565,24 @@ func.func @test_groupnorm_v18(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2xf32>,
}
// -----

func.func @test_groupnorm_v21(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x2x2xf32> {
%0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x2x2xf32>
func.func @test_groupnorm_v21(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<3x4x2x2xf32> {
%0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x2x2xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<3x4x2x2xf32>
onnx.Return %0 : tensor<3x4x2x2xf32>
// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_groupnorm_v21
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x2x2xf32>, [[PARAM_1_:%.+]]: tensor<2xf32>, [[PARAM_2_:%.+]]: tensor<2xf32>) -> tensor<3x4x2x2xf32> {
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x2x2xf32>, [[PARAM_1_:%.+]]: tensor<4xf32>, [[PARAM_2_:%.+]]: tensor<4xf32>) -> tensor<3x4x2x2xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<2> : tensor<1xi64>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<2xi64>
// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<2xi64>) -> tensor<4xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x2x1x1xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_2_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x2x1x1xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<4xf32>, tensor<4xi64>) -> tensor<2x2x1x1xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_2_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<4xf32>, tensor<4xi64>) -> tensor<2x2x1x1xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64>
// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 2 : si64} : (tensor<3x4x2x2xf32>) -> tensor<2xi64>
// CHECK: [[VAR_8_:%.+]] = "onnx.Concat"([[VAR_5_]], [[VAR_6_]], [[VAR_7_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5xi64>
// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<3x4x2x2xf32>, tensor<5xi64>) -> tensor<3x2x2x2x2xf32>
// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_]], [[Mean_]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_9_]], [[VAR_3_]], [[VAR_4_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x2x2xf32>, tensor<2x2x1x1xf32>, tensor<2x2x1x1xf32>) -> (tensor<3x2x2x2x2xf32>, none, none)
// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_9_]], [[VAR_3_]], [[VAR_4_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x2x2xf32>, tensor<2x2x1x1xf32>, tensor<2x2x1x1xf32>) -> (tensor<3x2x2x2x2xf32>, none, none)
// CHECK: [[VAR_11_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<4xi64>
// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[Y_]], [[VAR_11_]]) {allowzero = 0 : si64} : (tensor<3x2x2x2x2xf32>, tensor<4xi64>) -> tensor<3x4x2x2xf32>
// CHECK: onnx.Return [[VAR_12_]] : tensor<3x4x2x2xf32>
Expand Down Expand Up @@ -614,24 +614,24 @@ func.func @group_norm5d_v18(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<2xf32>,

// -----

func.func @group_norm5d_v21(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> {
%0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x6x8x16xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x6x8x16xf32>
func.func @group_norm5d_v21(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<3x4x6x8x16xf32> {
%0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x6x8x16xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<3x4x6x8x16xf32>
onnx.Return %0 : tensor<3x4x6x8x16xf32>
// mlir2FileCheck.py
// CHECK-LABEL: func.func @group_norm5d_v21
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x6x8x16xf32>, [[PARAM_1_:%.+]]: tensor<2xf32>, [[PARAM_2_:%.+]]: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> {
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x6x8x16xf32>, [[PARAM_1_:%.+]]: tensor<4xf32>, [[PARAM_2_:%.+]]: tensor<4xf32>) -> tensor<3x4x6x8x16xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<2> : tensor<1xi64>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<2xi64>
// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<2xi64>) -> tensor<5xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<2xf32>, tensor<5xi64>) -> tensor<2x2x1x1x1xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_2_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<2xf32>, tensor<5xi64>) -> tensor<2x2x1x1x1xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<3xi64>
// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<3xi64>) -> tensor<5xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<4xf32>, tensor<5xi64>) -> tensor<2x2x1x1x1xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_2_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<4xf32>, tensor<5xi64>) -> tensor<2x2x1x1x1xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64>
// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 2 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<3xi64>
// CHECK: [[VAR_8_:%.+]] = "onnx.Concat"([[VAR_5_]], [[VAR_6_]], [[VAR_7_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<3xi64>) -> tensor<6xi64>
// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<3x4x6x8x16xf32>, tensor<6xi64>) -> tensor<3x2x2x6x8x16xf32>
// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_]], [[Mean_]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_9_]], [[VAR_3_]], [[VAR_4_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<2x2x1x1x1xf32>, tensor<2x2x1x1x1xf32>) -> (tensor<3x2x2x6x8x16xf32>, none, none)
// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_9_]], [[VAR_3_]], [[VAR_4_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<2x2x1x1x1xf32>, tensor<2x2x1x1x1xf32>) -> (tensor<3x2x2x6x8x16xf32>, none, none)
// CHECK: [[VAR_11_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<5xi64>
// CHECK: [[VAR_12_:%.+]] = "onnx.Reshape"([[Y_]], [[VAR_11_]]) {allowzero = 0 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<5xi64>) -> tensor<3x4x6x8x16xf32>
// CHECK: onnx.Return [[VAR_12_]] : tensor<3x4x6x8x16xf32>
Expand Down

0 comments on commit c0d447f

Please sign in to comment.