diff --git a/compiler/luci/service/src/Nodes/CircleRange.cpp b/compiler/luci/service/src/Nodes/CircleRange.cpp index 87ad6f59aeb..1b40bb6af92 100644 --- a/compiler/luci/service/src/Nodes/CircleRange.cpp +++ b/compiler/luci/service/src/Nodes/CircleRange.cpp @@ -43,21 +43,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleRange *node) if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr) { - // We use shape from the node itself - loco::TensorShape shape; - shape.rank(node->rank()); - for (uint32_t r = 0; r < node->rank(); ++r) - { - // TODO remove this copy from `use_own(node);` - // Shape inference rules in this file did not consider unknown dimension. - // If some node has unknown dimension, 0 is inserted and wrong shape - // inference was done as a result. - // To fix this, new shape inference algorithm is being implemented. - // Until new inference algorithm is fully implemented, unknown dimension - // would be represented as 1 along with TFLite expression. - shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1; - } - return shape; + return output_shape; } double start = 0, limit = 0, delta = 0; diff --git a/compiler/luci/service/src/Nodes/CircleRange.test.cpp b/compiler/luci/service/src/Nodes/CircleRange.test.cpp index b67d287d1ab..3de3f05a4e0 100644 --- a/compiler/luci/service/src/Nodes/CircleRange.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleRange.test.cpp @@ -95,3 +95,56 @@ TEST(ShapeRuleTest, range_zero_delta_NEG) ASSERT_ANY_THROW(shape_inf_rule.infer(&range, shape)); } + +TEST(ShapeRuleTest, range_non_const_param) +{ + luci::CircleInput start, limit, delta; + luci::CircleRange range; + + start.dtype(loco::DataType::S32); + start.shape({1}); + start.shape_status(luci::ShapeStatus::VALID); + + limit.dtype(loco::DataType::S32); + limit.shape({1}); + limit.shape_status(luci::ShapeStatus::VALID); + + delta.dtype(loco::DataType::S32); + delta.shape({1}); + delta.shape_status(luci::ShapeStatus::VALID); + + range.start(&start); + range.limit(&limit); + range.delta(&delta); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&range, shape)); + ASSERT_EQ(1, shape.rank()); + ASSERT_FALSE(shape.dim(0).known()); + ASSERT_EQ(0, shape.dim(0).value()); +} + +TEST(ShapeRuleTest, range_nullptr_start_NEG) +{ + luci::CircleInput limit, delta; + luci::CircleRange range; + + limit.dtype(loco::DataType::S32); + limit.shape({1}); + limit.shape_status(luci::ShapeStatus::VALID); + + delta.dtype(loco::DataType::S32); + delta.shape({1}); + delta.shape_status(luci::ShapeStatus::VALID); + + range.start(nullptr); + range.limit(&limit); + range.delta(&delta); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_ANY_THROW(shape_inf_rule.infer(&range, shape)); +}