diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index dc61518bbca..b45444def76 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -4133,7 +4133,7 @@ TEST_CASE(dot_slice_a) }; EXPECT(m1.sort() == m2.sort()); } - +/* TEST_CASE(my_optim) { migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}}; @@ -4174,6 +4174,162 @@ TEST_CASE(my_optim) } +*/ +/* +TEST_CASE(my_optim) +{ + migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}}; + + migraphx::module m2; + { + // Parameter + auto a = m2.add_parameter("b", as); + + // Slicing the input into three parts + auto slice_a = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), a); + auto slice_b = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), a); + auto slice_c = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), a); + + // Literals properly broadcasted to match slice shapes + auto one = m2.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 3.0f)); + auto two = m2.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f)); + auto three = m2.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f)); + auto four = m2.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f)); + + // Operations + auto add1 = m2.add_instruction(migraphx::make_op("add"), slice_a, one); + auto mul1 = m2.add_instruction(migraphx::make_op("mul"), add1, four); + + auto add2 = m2.add_instruction(migraphx::make_op("add"), slice_b, two); + auto add3 = m2.add_instruction(migraphx::make_op("add"), slice_c, three); + + // Returning results + m2.add_return({mul1, add2, add3}); + }; + + // Copying and running optimization pass + migraphx::module m1 = m2; + run_pass(m2); + + // Validate that the optimized module is equivalent to the original + EXPECT(m1.sort() == m2.sort()); +}*/ + +/* +TEST_CASE(my_optim) +{ + migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}}; + + migraphx::module m2; + { + // Parameter + auto a = m2.add_parameter("b", as); + + // Slicing the input into three parts + auto slice_a = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), a); + auto slice_b = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), a); + auto slice_c = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), a); + + // Literals properly broadcasted to match slice shapes + auto one = m2.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 3.0f)); + auto two = m2.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 2.0f)); + auto three = m2.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 4.0f)); + auto four = m2.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 2.0f)); + + // Operations + auto add1 = m2.add_instruction(migraphx::make_op("add"), slice_a, one); + auto mul1 = m2.add_instruction(migraphx::make_op("mul"), add1, four); + + auto add2 = m2.add_instruction(migraphx::make_op("add"), slice_b, two); + auto add3 = m2.add_instruction(migraphx::make_op("add"), slice_c, three); + + // Returning results + m2.add_return({mul1, add2, add3}); + }; + + // Copying and running optimization pass + migraphx::module m1 = m2; + run_pass(m2); + + + + // Validate that the optimized module is equivalent to the original + EXPECT(m2==m1); +} +*/ + + +TEST_CASE(skip_mul_add_for_horizontal_add_fusion) +{ + migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}}; + + migraphx::module m1; + { + auto a = m1.add_parameter("a", as); + + auto slice_a = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), a); + auto slice_b = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), a); + auto slice_c = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), a); + + auto one = m1.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 3.0f)); + auto two = m1.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 2.0f)); + auto three = m1.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 4.0f)); + + auto add1 = m1.add_instruction(migraphx::make_op("add"), slice_a, one); + + auto add2 = m1.add_instruction(migraphx::make_op("add"), slice_b, two); + auto add3 = m1.add_instruction(migraphx::make_op("add"), slice_c, three); + + m1.add_return({add1}); + + }; + + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_parameter("a", as); + + auto one = m2.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 3.0f)); + auto two = m2.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 2.0f)); + auto three = m2.add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {1, 77, 768}}, 4.0f)); + + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), one, two, three); + + auto add = m2.add_instruction(migraphx::make_op("add"), a, concat); + + auto slice_a = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), add); + + m2.add_return({slice_a}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(dot_slice_ab) { migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; @@ -4240,7 +4396,7 @@ TEST_CASE(dot_slice_batch_dims) TEST_CASE(complex_graph_operations) { migraphx::module m; - + auto x_0 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 0)); auto x_1 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 1)); auto x_2 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 2)); @@ -4305,7 +4461,8 @@ TEST_CASE(dot_slice_not_applicable_1) {{"axes", {1, 2}}, {"starts", {64, 64}}, {"ends", {128, 128}}}), dot); auto slice2 = m1.add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), dot); + migraphx::make_op("slice", {{"axes", {0}}, {"starts" + , {0}}, {"ends", {1}}}), dot); m1.add_return({slice1, slice2}); };