Skip to content

Commit

Permalink
Merge branch 'Samsung:master' into rope/luci/lang
Browse files Browse the repository at this point in the history
  • Loading branch information
ys44kim authored Sep 24, 2024
2 parents 619a733 + 8958dc0 commit 4c09892
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
10 changes: 8 additions & 2 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,14 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
bool is_static_shape = true;
for (uint32_t i = 0; i < input_shape.rank(); ++i)
input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
{
if (input_shape.dim(i).known())
input_element_count *= input_shape.dim(i).value();
else
is_static_shape = false;
}
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
Expand All @@ -153,7 +159,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
output_element_count *= dim_value;
}
}
if (unknown_dim_index != UINT32_MAX)
if (unknown_dim_index != UINT32_MAX && is_static_shape)
{
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
}
Expand Down
32 changes: 32 additions & 0 deletions compiler/luci/service/src/Nodes/CircleReshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,35 @@ TEST(ShapeRuleTest, reshape_by_input_const_dynamic)
ASSERT_EQ(6, output_shape.dim(0).value());
ASSERT_EQ(4, output_shape.dim(1).value());
}

TEST(ShapeRuleTest, reshape_should_infer)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({0, 3, 4});
tensor_input->dim(0).unset();
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = -1;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));

ASSERT_EQ(2, output_shape.rank());
ASSERT_FALSE(output_shape.dim(0).known());
ASSERT_TRUE(output_shape.dim(1).known());
ASSERT_EQ(4, output_shape.dim(1).value());
}

0 comments on commit 4c09892

Please sign in to comment.