Skip to content

Commit

Permalink
skip_mul_add_for_horizontal_add_fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
aarushjain29 committed Nov 20, 2024
1 parent 4aaa2b9 commit a979c11
Showing 1 changed file with 160 additions and 3 deletions.
163 changes: 160 additions & 3 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4133,7 +4133,7 @@ TEST_CASE(dot_slice_a)
};
EXPECT(m1.sort() == m2.sort());
}

/*
TEST_CASE(my_optim)
{
migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}};
Expand Down Expand Up @@ -4174,6 +4174,162 @@ TEST_CASE(my_optim)
}
*/
/*
TEST_CASE(my_optim)
{
migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}};
migraphx::module m2;
{
// Parameter
auto a = m2.add_parameter("b", as);
// Slicing the input into three parts
auto slice_a = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), a);
auto slice_b = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), a);
auto slice_c = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), a);
// Literals properly broadcasted to match slice shapes
auto one = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f));
auto four = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
// Operations
auto add1 = m2.add_instruction(migraphx::make_op("add"), slice_a, one);
auto mul1 = m2.add_instruction(migraphx::make_op("mul"), add1, four);
auto add2 = m2.add_instruction(migraphx::make_op("add"), slice_b, two);
auto add3 = m2.add_instruction(migraphx::make_op("add"), slice_c, three);
// Returning results
m2.add_return({mul1, add2, add3});
};
// Copying and running optimization pass
migraphx::module m1 = m2;
run_pass(m2);
// Validate that the optimized module is equivalent to the original
EXPECT(m1.sort() == m2.sort());
}*/

/*
TEST_CASE(my_optim)
{
migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}};
migraphx::module m2;
{
// Parameter
auto a = m2.add_parameter("b", as);
// Slicing the input into three parts
auto slice_a = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), a);
auto slice_b = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), a);
auto slice_c = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), a);
// Literals properly broadcasted to match slice shapes
auto one = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 4.0f));
auto four = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
// Operations
auto add1 = m2.add_instruction(migraphx::make_op("add"), slice_a, one);
auto mul1 = m2.add_instruction(migraphx::make_op("mul"), add1, four);
auto add2 = m2.add_instruction(migraphx::make_op("add"), slice_b, two);
auto add3 = m2.add_instruction(migraphx::make_op("add"), slice_c, three);
// Returning results
m2.add_return({mul1, add2, add3});
};
// Copying and running optimization pass
migraphx::module m1 = m2;
run_pass(m2);
// Validate that the optimized module is equivalent to the original
EXPECT(m2==m1);
}
*/


TEST_CASE(skip_mul_add_for_horizontal_add_fusion)
{
migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}};

migraphx::module m1;
{
auto a = m1.add_parameter("a", as);

auto slice_a = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), a);
auto slice_b = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), a);
auto slice_c = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), a);

auto one = m1.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m1.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m1.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 4.0f));

auto add1 = m1.add_instruction(migraphx::make_op("add"), slice_a, one);

auto add2 = m1.add_instruction(migraphx::make_op("add"), slice_b, two);
auto add3 = m1.add_instruction(migraphx::make_op("add"), slice_c, three);

m1.add_return({add1});

};

run_pass(m1);

migraphx::module m2;
{
auto a = m2.add_parameter("a", as);

auto one = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 4.0f));

auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), one, two, three);

auto add = m2.add_instruction(migraphx::make_op("add"), a, concat);

auto slice_a = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), add);

m2.add_return({slice_a});
};

EXPECT(m1.sort() == m2.sort());
}

Check warning on line 4331 in test/simplify_algebra_test.cpp

View workflow job for this annotation

GitHub Actions / tidy

Value stored to 'add2' during its initialization is never read [clang-analyzer-deadcode.DeadStores,-warnings-as-errors]

Check warning on line 4331 in test/simplify_algebra_test.cpp

View workflow job for this annotation

GitHub Actions / tidy

unused variable 'add2' [clang-diagnostic-unused-variable,-warnings-as-errors]

Check warning on line 4332 in test/simplify_algebra_test.cpp

View workflow job for this annotation

GitHub Actions / tidy

Value stored to 'add3' during its initialization is never read [clang-analyzer-deadcode.DeadStores,-warnings-as-errors]

Check warning on line 4332 in test/simplify_algebra_test.cpp

View workflow job for this annotation

GitHub Actions / tidy

unused variable 'add3' [clang-diagnostic-unused-variable,-warnings-as-errors]
TEST_CASE(dot_slice_ab)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
Expand Down Expand Up @@ -4240,7 +4396,7 @@ TEST_CASE(dot_slice_batch_dims)
TEST_CASE(complex_graph_operations)
{
migraphx::module m;

auto x_0 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 0));
auto x_1 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 1));
auto x_2 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 2));
Expand Down Expand Up @@ -4305,7 +4461,8 @@ TEST_CASE(dot_slice_not_applicable_1)
{{"axes", {1, 2}}, {"starts", {64, 64}}, {"ends", {128, 128}}}),
dot);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), dot);
migraphx::make_op("slice", {{"axes", {0}}, {"starts"
, {0}}, {"ends", {1}}}), dot);

m1.add_return({slice1, slice2});
};
Expand Down

0 comments on commit a979c11

Please sign in to comment.