Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/service] Support shape inference for non-const paddings #13949

Merged
merged 5 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions compiler/luci/service/src/CircleShapeInferenceHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::Tensor
return output_shape;
}

loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleConst *paddings)
loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings)
{
const loco::DataType S32 = loco::DataType::S32;
const loco::DataType S64 = loco::DataType::S64;
Expand All @@ -180,6 +180,11 @@ loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::Ci
loco::TensorShape output_shape;

output_shape.rank(input_shape.rank());

auto const_padding = dynamic_cast<const luci::CircleConst *>(paddings);
if (const_padding == nullptr)
return output_shape;
shs-park marked this conversation as resolved.
Show resolved Hide resolved

for (int32_t ni = 0; ni < n; ++ni)
{
if (not input_shape.dim(ni).known())
Expand All @@ -189,15 +194,15 @@ loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::Ci
}
int32_t idx = ni * 2;
int value = input_shape.dim(ni).value();
if (paddings->dtype() == S32)
if (const_padding->dtype() == S32)
{
value += paddings->at<S32>(idx + 0); // left
value += paddings->at<S32>(idx + 1); // right
value += const_padding->at<S32>(idx + 0); // left
value += const_padding->at<S32>(idx + 1); // right
}
else
{
auto pl = paddings->at<S64>(idx + 0);
auto pr = paddings->at<S64>(idx + 1);
auto pl = const_padding->at<S64>(idx + 0);
auto pr = const_padding->at<S64>(idx + 1);
auto max = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
auto low = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
LUCI_ASSERT(pl <= max, "paddings is over 32 bit limit");
Expand Down
4 changes: 2 additions & 2 deletions compiler/luci/service/src/CircleShapeInferenceHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ loco::TensorShape circle_shape(const luci::CircleNode *node);
loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y);

// Return shape of pad ops using paddings.
loco::TensorShape pad_shape(const loco::TensorShape &input_shape,
const luci::CircleConst *paddings);
// If paddings is not static, return the shape filled with unknown dimensions.
loco::TensorShape pad_shape(const loco::TensorShape &input_shape, const luci::CircleNode *paddings);

/**
* @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
Expand Down
4 changes: 2 additions & 2 deletions compiler/luci/service/src/Nodes/CirclePad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ namespace sinf

loco::TensorShape Algorithm::visit(const luci::CirclePad *node)
{
// TODO support non-const case
auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
auto paddings = loco::must_cast<const luci::CircleNode *>(node->paddings());
auto circle_input = loco::must_cast<const luci::CircleNode *>(node->input());
auto input_shape = circle_shape(circle_input);

icodo98 marked this conversation as resolved.
Show resolved Hide resolved
return pad_shape(input_shape, paddings);
}

Expand Down
53 changes: 53 additions & 0 deletions compiler/luci/service/src/Nodes/CirclePad.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,56 @@ TEST(ShapeRuleTest, pad_without_padding_NEG)

ASSERT_ANY_THROW(shape_inf_rule.infer(node_pad, shape));
}

TEST(ShapeRuleTest, pad_non_const_paddings)
shs-park marked this conversation as resolved.
Show resolved Hide resolved
{
auto g = loco::make_graph();
auto node_pad = g->nodes()->create<luci::CirclePad>();

auto node_paddings = g->nodes()->create<luci::CircleInput>();
auto node_input = g->nodes()->create<luci::CircleInput>();

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

node_input->shape({1, 2, 3, 4});
node_input->shape_status(luci::ShapeStatus::VALID);

node_paddings->dtype(loco::DataType::S64);
node_paddings->shape({4, 2});
node_paddings->shape_status(luci::ShapeStatus::VALID);

node_pad->input(node_input);
node_pad->paddings(node_paddings);

ASSERT_TRUE(shape_inf_rule.infer(node_pad, shape));
ASSERT_EQ(shape.rank(), 4);
ASSERT_FALSE(shape.dim(0).known());
ASSERT_FALSE(shape.dim(1).known());
ASSERT_FALSE(shape.dim(2).known());
ASSERT_FALSE(shape.dim(3).known());

ASSERT_EQ(0, shape.dim(0).value());
ASSERT_EQ(0, shape.dim(1).value());
ASSERT_EQ(0, shape.dim(2).value());
ASSERT_EQ(0, shape.dim(3).value());
}

TEST(ShapeRuleTest, pad_empty_padding_NEG)
shs-park marked this conversation as resolved.
Show resolved Hide resolved
{
auto g = loco::make_graph();
auto node_pad = g->nodes()->create<luci::CirclePad>();
auto node_input = g->nodes()->create<luci::CircleInput>();
auto node_paddings = g->nodes()->create<luci::CircleAdd>();
icodo98 marked this conversation as resolved.
Show resolved Hide resolved

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

node_input->shape({1, 2, 3, 4});
node_input->shape_status(luci::ShapeStatus::VALID);

node_pad->input(node_input);
node_pad->paddings(node_paddings);

ASSERT_FALSE(shape_inf_rule.infer(node_pad, shape));
}
icodo98 marked this conversation as resolved.
Show resolved Hide resolved