diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 334c9615f68..a8addd1d255 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -96,6 +96,7 @@ add_library(migraphx register_target.cpp replace_allocate.cpp rewrite_reduce.cpp + rewrite_dot.cpp simplify_qdq.cpp split_reduce.cpp sqlite.cpp diff --git a/src/include/migraphx/rewrite_dot.hpp b/src/include/migraphx/rewrite_dot.hpp new file mode 100644 index 00000000000..f756cdf62b0 --- /dev/null +++ b/src/include/migraphx/rewrite_dot.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_REWRITE_DOT_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_REWRITE_DOT_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +struct MIGRAPHX_EXPORT rewrite_dot +{ + std::string name() const { return "rewrite_dot"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_REWRITE_DOT_HPP diff --git a/src/rewrite_dot.cpp b/src/rewrite_dot.cpp new file mode 100644 index 00000000000..50f5cfdbd41 --- /dev/null +++ b/src/rewrite_dot.cpp @@ -0,0 +1,122 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REWRITE_DOT); + +namespace { + +MIGRAPHX_PRED_MATCHER(conv_1x1, instruction_ref ins) +{ + if(ins->name() != "convolution") + return false; + auto v = ins->get_operator().to_value(); + if(v.at("group").to() != 1) + return false; + if(not all_of(v.at("stride"), [](const value& x) { return x.to() == 1; })) + return false; + if(not all_of(v.at("padding"), [](const value& x) { return x.to() == 0; })) + return false; + if(not all_of(v.at("dilation"), [](const value& x) { return x.to() == 1; })) + return false; + auto w = ins->inputs().at(1)->get_shape(); + return std::all_of(w.lens().begin() + 2, w.lens().end(), [](std::size_t i) { return i == 1; }); +} + +struct find_1x1_convolution +{ + auto matcher() const { return conv_1x1(match::arg(1)(match::is_constant())); } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + + auto input = ins->inputs().front(); + auto weights = ins->inputs().back(); + + std::vector sq_axes(ins->get_shape().ndim() - 2); + std::iota(sq_axes.begin(), sq_axes.end(), 2); + auto sq_weights = + m.insert_instruction(ins, make_op("squeeze", {{"axes", sq_axes}}), weights); + + if(ins->get_shape().transposed()) + { + std::vector aperm(ins->get_shape().ndim()); + std::iota(aperm.begin(), aperm.end(), 0); + std::rotate(aperm.begin() + 1, aperm.begin() + 2, aperm.end()); + auto a_mat = + m.insert_instruction(ins, make_op("transpose", {{"permutation", aperm}}), input); + + auto transpose = m.insert_instruction( + ins, make_op("transpose", {{"permutation", {1, 0}}}), sq_weights); + auto b_lens = a_mat->get_shape().lens(); + copy(transpose->get_shape().lens(), b_lens.end() - 2); + auto b_mat = m.insert_instruction( + ins, make_op("multibroadcast", {{"out_lens", b_lens}}), transpose); + + auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); + m.replace_instruction( + ins, make_op("transpose", {{"permutation", invert_permutation(aperm)}}), dot); + } + else + { + auto batch_dim = ins->get_shape().lens().front(); + auto m_dim = std::accumulate(input->get_shape().lens().begin() + 2, + input->get_shape().lens().end(), + 1, + std::multiplies<>{}); + auto n_dim = weights->get_shape().lens()[0]; + auto k_dim = weights->get_shape().lens()[1]; + auto a_mat = m.insert_instruction( + ins, + make_op("multibroadcast", {{"out_lens", {batch_dim, n_dim, k_dim}}}), + sq_weights); + auto b_mat = m.insert_instruction( + ins, make_op("reshape", {{"dims", {batch_dim, k_dim, m_dim}}}), input); + auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); + m.replace_instruction( + ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), dot); + } + } +}; + +} // namespace + +void rewrite_dot::apply(module& m) const +{ + if(enabled(MIGRAPHX_DISABLE_REWRITE_DOT{})) + return; + match::find_matches(m, find_1x1_convolution{}); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 83b602cdef4..a08717ada4d 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -276,13 +276,18 @@ auto is_mlir_dot(mlir_mode mode) return true; auto a = ins->inputs().front()->get_shape(); auto b = ins->inputs().back()->get_shape(); - // auto m = a.lens()[a.lens().size() - 2]; - // auto n = b.lens().back(); + auto g = std::accumulate(a.lens().begin(), a.lens().end() - 2, 1, std::multiplies<>{}); + auto m = a.lens()[a.lens().size() - 2]; + auto n = b.lens().back(); auto k = a.lens().back(); // Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy // to avoid poor-performing GEMM kernels from MLIR // To-do: Investigate a more precise strategy - return k <= 1024; + if (k > 1535) + return false; + if (k < 1024) + return true; + return (g*m*n) < (384*384); }); } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index a0edac5eb17..2bacd458608 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -43,6 +43,7 @@ #include #include #include +#include #include #include #include @@ -198,6 +199,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, rewrite_reduce{}, rewrite_low_precision{}, + rewrite_dot{}, dead_code_elimination{}, optimize_module{}, fuse_pointwise_reduce{}, diff --git a/test/rewrite_dot.cpp b/test/rewrite_dot.cpp new file mode 100644 index 00000000000..e680d84f09b --- /dev/null +++ b/test/rewrite_dot.cpp @@ -0,0 +1,95 @@ + +#include +#include +#include +#include +#include +#include +#include + +void run_pass(migraphx::module& m) +{ + migraphx::run_passes(m, {migraphx::rewrite_dot{}, migraphx::dead_code_elimination{}}); +} + +TEST_CASE(nchw_conv_1x1) +{ + migraphx::shape s1{migraphx::shape::float_type, {64, 128, 28, 28}}; + migraphx::shape s2{migraphx::shape::float_type, {512, 128, 1, 1}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto w = m1.add_literal(migraphx::generate_literal(s2)); + auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); + m1.add_return({conv}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto w = m2.add_literal(migraphx::generate_literal(s2)); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), w); + auto broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {64, 512, 128}}}), squeeze); + auto reshape1 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 128, 784}}}), x); + auto dot = m2.add_instruction(migraphx::make_op("dot"), broadcast, reshape1); + auto reshape2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 512, 28, 28}}}), dot); + m2.add_return({reshape2}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(nhwc_conv_1x1) +{ + auto s1 = migraphx::shape::from_permutation( + migraphx::shape::float_type, {64, 128, 28, 28}, {0, 2, 3, 1}); + auto s2 = migraphx::shape::from_permutation( + migraphx::shape::float_type, {512, 128, 1, 1}, {0, 2, 3, 1}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto w = m1.add_literal(migraphx::generate_literal(s2)); + auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); + m1.add_return({conv}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto w = m2.add_literal(migraphx::generate_literal(s2)); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), w); + auto transpose1 = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), squeeze); + auto broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {64, 28, 128, 512}}}), transpose1); + auto transpose2 = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto dot = m2.add_instruction(migraphx::make_op("dot"), transpose2, broadcast); + auto transpose3 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), dot); + m2.add_return({transpose3}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(nhwc_group_conv_1x1) +{ + auto s1 = migraphx::shape::from_permutation( + migraphx::shape::float_type, {64, 192, 83, 83}, {0, 2, 3, 1}); + auto s2 = migraphx::shape::from_permutation( + migraphx::shape::float_type, {84, 96, 1, 1}, {0, 2, 3, 1}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto w = m1.add_literal(migraphx::generate_literal(s2)); + auto conv = m1.add_instruction(migraphx::make_op("convolution", {{"group", 2}}), x, w); + m1.add_return({conv}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); }