diff --git a/compiler/luci-interpreter/src/kernels/Transpose.cpp b/compiler/luci-interpreter/src/kernels/Transpose.cpp index 725a9523c27..42a064d5ad3 100644 --- a/compiler/luci-interpreter/src/kernels/Transpose.cpp +++ b/compiler/luci-interpreter/src/kernels/Transpose.cpp @@ -39,7 +39,7 @@ void Transpose::configure() int dims = input()->shape().num_dims(); const int32_t *perm_data = getTensorData(perm()); - assert(input()->shape().num_dims() <= 4); + assert(input()->shape().num_dims() <= 5); assert(input()->element_type() == output()->element_type()); assert(perm()->shape().num_dims() == 1); diff --git a/compiler/luci-interpreter/src/kernels/Transpose.test.cpp b/compiler/luci-interpreter/src/kernels/Transpose.test.cpp index a41db63b1c2..2bcc969526c 100644 --- a/compiler/luci-interpreter/src/kernels/Transpose.test.cpp +++ b/compiler/luci-interpreter/src/kernels/Transpose.test.cpp @@ -110,6 +110,19 @@ TYPED_TEST(TransposeTest, Large2D) 70, 82, 94, 106, 118, 11, 23, 35, 47, 59, 71, 83, 95, 107, 119}); } +TYPED_TEST(TransposeTest, Small5D) +{ + Check( + /*input_shape=*/{2, 3, 4, 1, 2}, /*perm_shape=*/{5}, /*output_shape=*/{1, 3, 4, 2, 2}, + /*input_data=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, + /*perm_data=*/{3, 1, 2, 0, 4}, + /*output_data=*/{0, 1, 24, 25, 2, 3, 26, 27, 4, 5, 28, 29, 6, 7, 30, 31, + 8, 9, 32, 33, 10, 11, 34, 35, 12, 13, 36, 37, 14, 15, 38, 39, + 16, 17, 40, 41, 18, 19, 42, 43, 20, 21, 44, 45, 22, 23, 46, 47}); +} + } // namespace } // namespace kernels } // namespace luci_interpreter