Skip to content

Commit

Permalink
[luci/service] Support shape inference for non-const paddings (#13949)
Browse files Browse the repository at this point in the history
This PR supports infer shape of pad operation with dynamic paddings.

ONE-DCO-1.0-Signed-off-by: JuYoung Lee [email protected]
Co-authored-by: SaeHie Park <[email protected]>
  • Loading branch information
icodo98 and seanshpark authored Sep 11, 2024
1 parent 2693ed2 commit 0f8808e
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 10 deletions.
17 changes: 11 additions & 6 deletions compiler/luci/service/src/CircleShapeInferenceHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::Tensor
return output_shape;
}

loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleConst *paddings)
loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings)
{
const loco::DataType S32 = loco::DataType::S32;
const loco::DataType S64 = loco::DataType::S64;
Expand All @@ -180,6 +180,11 @@ loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::Ci
loco::TensorShape output_shape;

output_shape.rank(input_shape.rank());

auto const_padding = dynamic_cast<const luci::CircleConst *>(paddings);
if (const_padding == nullptr)
return output_shape;

for (int32_t ni = 0; ni < n; ++ni)
{
if (not input_shape.dim(ni).known())
Expand All @@ -189,15 +194,15 @@ loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::Ci
}
int32_t idx = ni * 2;
int value = input_shape.dim(ni).value();
if (paddings->dtype() == S32)
if (const_padding->dtype() == S32)
{
value += paddings->at<S32>(idx + 0); // left
value += paddings->at<S32>(idx + 1); // right
value += const_padding->at<S32>(idx + 0); // left
value += const_padding->at<S32>(idx + 1); // right
}
else
{
auto pl = paddings->at<S64>(idx + 0);
auto pr = paddings->at<S64>(idx + 1);
auto pl = const_padding->at<S64>(idx + 0);
auto pr = const_padding->at<S64>(idx + 1);
auto max = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
auto low = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit");
Expand Down
4 changes: 2 additions & 2 deletions compiler/luci/service/src/CircleShapeInferenceHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ loco::TensorShape circle_shape(const luci::CircleNode *node);
loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y);

// Return shape of pad ops using paddings.
loco::TensorShape pad_shape(const loco::TensorShape &input_shape,
const luci::CircleConst *paddings);
// If paddings is not static, return the shape filled with unknown dimensions.
loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings);

/**
* @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
Expand Down
3 changes: 1 addition & 2 deletions compiler/luci/service/src/Nodes/CirclePad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ namespace sinf

loco::TensorShape Algorithm::visit(const luci::CirclePad *node)
{
// TODO support non-const case
auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
auto paddings = loco::must_cast<const luci::CircleNode *>(node->paddings());
auto circle_input = loco::must_cast<const luci::CircleNode *>(node->input());
auto input_shape = circle_shape(circle_input);
return pad_shape(input_shape, paddings);
Expand Down
34 changes: 34 additions & 0 deletions compiler/luci/service/src/Nodes/CirclePad.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,37 @@ TEST(ShapeRuleTest, pad_without_padding_NEG)

ASSERT_ANY_THROW(shape_inf_rule.infer(node_pad, shape));
}

TEST(ShapeRuleTest, pad_non_const_paddings)
{
auto g = loco::make_graph();
auto node_pad = g->nodes()->create<luci::CirclePad>();

auto node_paddings = g->nodes()->create<luci::CircleInput>();
auto node_input = g->nodes()->create<luci::CircleInput>();

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

node_input->shape({1, 2, 3, 4});
node_input->shape_status(luci::ShapeStatus::VALID);

node_paddings->dtype(loco::DataType::S64);
node_paddings->shape({4, 2});
node_paddings->shape_status(luci::ShapeStatus::VALID);

node_pad->input(node_input);
node_pad->paddings(node_paddings);

ASSERT_TRUE(shape_inf_rule.infer(node_pad, shape));
ASSERT_EQ(shape.rank(), 4);
ASSERT_FALSE(shape.dim(0).known());
ASSERT_FALSE(shape.dim(1).known());
ASSERT_FALSE(shape.dim(2).known());
ASSERT_FALSE(shape.dim(3).known());

ASSERT_EQ(0, shape.dim(0).value());
ASSERT_EQ(0, shape.dim(1).value());
ASSERT_EQ(0, shape.dim(2).value());
ASSERT_EQ(0, shape.dim(3).value());
}

0 comments on commit 0f8808e

Please sign in to comment.