diff --git a/compiler/luci/service/src/CircleShapeInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
index 76867ccafc1..43a8d9cd5c9 100644
--- a/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
@@ -161,7 +161,7 @@ loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::Tensor
   return output_shape;
 }
 
-loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleConst *paddings)
+loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings)
 {
   const loco::DataType S32 = loco::DataType::S32;
   const loco::DataType S64 = loco::DataType::S64;
@@ -180,6 +180,12 @@ loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::Ci
 
   loco::TensorShape output_shape;
   output_shape.rank(input_shape.rank());
+
+  auto const_padding = dynamic_cast<const luci::CircleConst *>(paddings);
+
+  if (const_padding == nullptr)
+    return output_shape;
+
   for (int32_t ni = 0; ni < n; ++ni)
   {
     if (not input_shape.dim(ni).known())
@@ -189,15 +195,15 @@ loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::Ci
     }
     int32_t idx = ni * 2;
     int value = input_shape.dim(ni).value();
-    if (paddings->dtype() == S32)
+    if (const_padding->dtype() == S32)
     {
-      value += paddings->at<S32>(idx + 0); // left
-      value += paddings->at<S32>(idx + 1); // right
+      value += const_padding->at<S32>(idx + 0); // left
+      value += const_padding->at<S32>(idx + 1); // right
     }
     else
     {
-      auto pl = paddings->at<S64>(idx + 0);
-      auto pr = paddings->at<S64>(idx + 1);
+      auto pl = const_padding->at<S64>(idx + 0);
+      auto pr = const_padding->at<S64>(idx + 1);
       auto max = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
       auto low = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
       LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit");
diff --git a/compiler/luci/service/src/CircleShapeInferenceHelper.h b/compiler/luci/service/src/CircleShapeInferenceHelper.h
index 4961e9c40de..a81114b016a 100644
--- a/compiler/luci/service/src/CircleShapeInferenceHelper.h
+++ b/compiler/luci/service/src/CircleShapeInferenceHelper.h
@@ -49,8 +49,8 @@ loco::TensorShape circle_shape(const luci::CircleNode *node);
 loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y);
 
 // Return shape of pad ops using paddings.
-loco::TensorShape pad_shape(const loco::TensorShape &input_shape,
-                            const luci::CircleConst *paddings);
+// If paddings is not static, return the shape filled with unknown dimensions.
+loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings);
 
 /**
  * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
diff --git a/compiler/luci/service/src/Nodes/CirclePad.cpp b/compiler/luci/service/src/Nodes/CirclePad.cpp
index 2f4f90140af..2589de57b20 100644
--- a/compiler/luci/service/src/Nodes/CirclePad.cpp
+++ b/compiler/luci/service/src/Nodes/CirclePad.cpp
@@ -32,8 +32,9 @@ namespace sinf
 
 loco::TensorShape Algorithm::visit(const luci::CirclePad *node)
 {
-  // TODO support non-const case
-  auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
+
+  auto paddings = loco::must_cast<luci::CircleNode *>(node->paddings());
+
   auto circle_input = loco::must_cast<luci::CircleNode *>(node->input());
   auto input_shape = circle_shape(circle_input);
   return pad_shape(input_shape, paddings);
diff --git a/compiler/luci/service/src/Nodes/CirclePad.test.cpp b/compiler/luci/service/src/Nodes/CirclePad.test.cpp
index 996426164f3..b2b8e6fb5a4 100644
--- a/compiler/luci/service/src/Nodes/CirclePad.test.cpp
+++ b/compiler/luci/service/src/Nodes/CirclePad.test.cpp
@@ -90,3 +90,60 @@ TEST(ShapeRuleTest, pad_without_padding_NEG)
 
   ASSERT_ANY_THROW(shape_inf_rule.infer(node_pad, shape));
 }
+
+TEST(ShapeRuleTest, pad_non_const_paddings)
+{
+  auto g = loco::make_graph();
+  auto node_pad = g->nodes()->create<luci::CirclePad>();
+
+  auto node_paddings = g->nodes()->create<luci::CircleInput>();
+  auto node_input = g->nodes()->create<luci::CircleInput>();
+
+  loco::TensorShape shape;
+  luci::sinf::Rule shape_inf_rule;
+
+  node_input->shape({1, 2, 3, 4});
+  node_input->shape_status(luci::ShapeStatus::VALID);
+
+  node_paddings->dtype(loco::DataType::S64);
+  node_paddings->shape({4, 2});
+  node_paddings->shape_status(luci::ShapeStatus::VALID);
+
+  node_pad->input(node_input);
+  node_pad->paddings(node_paddings);
+
+  ASSERT_TRUE(shape_inf_rule.infer(node_pad, shape));
+  ASSERT_EQ(shape.rank(), 4);
+  ASSERT_FALSE(shape.dim(0).known());
+  ASSERT_FALSE(shape.dim(1).known());
+  ASSERT_FALSE(shape.dim(2).known());
+  ASSERT_FALSE(shape.dim(3).known());
+
+  ASSERT_EQ(0, shape.dim(0).value());
+  ASSERT_EQ(0, shape.dim(1).value());
+  ASSERT_EQ(0, shape.dim(2).value());
+  ASSERT_EQ(0, shape.dim(3).value());
+}
+
+TEST(ShapeRuleTest, pad_without_input_NEG)
+{
+  auto g = loco::make_graph();
+  auto node_pad = g->nodes()->create<luci::CirclePad>();
+
+  auto node_paddings = g->nodes()->create<luci::CircleConst>();
+
+  loco::TensorShape shape;
+  luci::sinf::Rule shape_inf_rule;
+
+  node_paddings->dtype(loco::DataType::S64);
+  node_paddings->shape({4, 2});
+  node_paddings->shape_status(luci::ShapeStatus::VALID);
+
+  const loco::DataType S64 = loco::DataType::S64;
+  uint32_t t = 64 * 8;
+  node_paddings->size<S64>(t);
+
+  node_pad->paddings(node_paddings);
+
+  ASSERT_ANY_THROW(shape_inf_rule.infer(node_pad, shape));
+}