Skip to content

Commit

Permalink
[luci/service] Support Range op Shape Inference for Non-const Param
Browse files Browse the repository at this point in the history
This commit support Range op shape inference for non-const param.

ONE-DCO-1.0-Signed-off-by: bokyeong lee <[email protected]>
  • Loading branch information
kyeong8139 committed Sep 13, 2024
1 parent 03bacfe commit d8f8877
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 15 deletions.
16 changes: 1 addition & 15 deletions compiler/luci/service/src/Nodes/CircleRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleRange *node)

if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
{
// We use shape from the node itself
loco::TensorShape shape;
shape.rank(node->rank());
for (uint32_t r = 0; r < node->rank(); ++r)
{
// TODO remove this copy from `use_own(node);`
// Shape inference rules in this file did not consider unknown dimension.
// If some node has unknown dimension, 0 is inserted and wrong shape
// inference was done as a result.
// To fix this, new shape inference algorithm is being implemented.
// Until new inference algorithm is fully implemented, unknown dimension
// would be represented as 1 along with TFLite expression.
shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
}
return shape;
return output_shape;
}

double start = 0, limit = 0, delta = 0;
Expand Down
53 changes: 53 additions & 0 deletions compiler/luci/service/src/Nodes/CircleRange.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,56 @@ TEST(ShapeRuleTest, range_zero_delta_NEG)

ASSERT_ANY_THROW(shape_inf_rule.infer(&range, shape));
}

TEST(ShapeRuleTest, range_non_const_param)
{
luci::CircleInput start, limit, delta;
luci::CircleRange range;

start.dtype(loco::DataType::S32);
start.shape({1});
start.shape_status(luci::ShapeStatus::VALID);

limit.dtype(loco::DataType::S32);
limit.shape({1});
limit.shape_status(luci::ShapeStatus::VALID);

delta.dtype(loco::DataType::S32);
delta.shape({1});
delta.shape_status(luci::ShapeStatus::VALID);

range.start(&start);
range.limit(&limit);
range.delta(&delta);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&range, shape));
ASSERT_EQ(1, shape.rank());
ASSERT_FALSE(shape.dim(0).known());
ASSERT_EQ(0, shape.dim(0).value());
}

TEST(ShapeRuleTest, range_nullptr_start_NEG)
{
luci::CircleInput limit, delta;
luci::CircleRange range;

limit.dtype(loco::DataType::S32);
limit.shape({1});
limit.shape_status(luci::ShapeStatus::VALID);

delta.dtype(loco::DataType::S32);
delta.shape({1});
delta.shape_status(luci::ShapeStatus::VALID);

range.start(nullptr);
range.limit(&limit);
range.delta(&delta);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&range, shape));
}

0 comments on commit d8f8877

Please sign in to comment.