Skip to content

Commit

Permalink
Fuse unpack_int4 across concat (#3641)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Nov 21, 2024
1 parent 952a257 commit 0ad8073
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,8 @@ struct find_concat_op
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
match::any_of(match::pointwise(),
match::name("broadcast", "multibroadcast", "unpack_int4")),
match::used_once()));
}

Expand All @@ -890,7 +891,7 @@ struct find_concat_op

static bool is_valid_op(const operation& op)
{
return contains({"broadcast", "multibroadcast"}, op.name()) or
return contains({"broadcast", "multibroadcast", "unpack_int4"}, op.name()) or
op.attributes().contains("pointwise");
}

Expand Down
26 changes: 26 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,32 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
}

TEST_CASE(simplify_concat_unpack_int4)
{
auto s = migraphx::shape{migraphx::shape::int8_type, {11008, 2048}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto unpack1 = m1.add_instruction(migraphx::make_op("unpack_int4"), x);
auto unpack2 = m1.add_instruction(migraphx::make_op("unpack_int4"), y);
auto concat =
m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), unpack1, unpack2);
m1.add_return({concat});
}
run_pass(m1);

migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
auto unpack = m2.add_instruction(migraphx::make_op("unpack_int4"), concat);
m2.add_return({unpack});
}
EXPECT(m1 == m2);
}

TEST_CASE(simplify_concat_add_relu)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {1}};
Expand Down

0 comments on commit 0ad8073

Please sign in to comment.