diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1ffe2eb9f4e..5f8c029bbc2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -87,6 +87,7 @@ add_library(migraphx process.cpp program.cpp propagate_constant.cpp + propagate_precision.cpp promote_literals.cpp quantization.cpp quantize_fp16.cpp diff --git a/src/include/migraphx/propagate_precision.hpp b/src/include/migraphx/propagate_precision.hpp new file mode 100644 index 00000000000..2e839ac1354 --- /dev/null +++ b/src/include/migraphx/propagate_precision.hpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module_pass_manager; + +/// This pass will propagate higher precision through more adjacent operators. +struct MIGRAPHX_EXPORT propagate_precision +{ + std::string name() const { return "propagate_precision"; } + void apply(module_pass_manager& mpm) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 4a5952d5eac..457f9ee9cbd 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -362,9 +362,9 @@ struct MIGRAPHX_EXPORT shape std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; } - auto is_integral() const { return std::is_integral{}; } - auto is_signed() const { return std::is_signed{}; } - auto is_unsigned() const { return std::is_unsigned{}; } + bool is_integral() const { return std::is_integral{}; } + bool is_signed() const { return std::is_signed{}; } + bool is_unsigned() const { return std::is_unsigned{}; } template type* from(U* buffer, std::size_t n = 0) const diff --git a/src/propagate_precision.cpp b/src/propagate_precision.cpp new file mode 100644 index 00000000000..4a025c6882b --- /dev/null +++ b/src/propagate_precision.cpp @@ -0,0 +1,233 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +namespace { +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-function" +#endif +// Class wrappper so we can compare precision using comparison operators +struct precision +{ + shape::type_t type; + + friend bool operator==(const precision& xp, const precision& yp) { return xp.type == yp.type; } + friend bool operator<(const precision& xp, const precision& yp) + { + bool is_less = false; + shape::visit(xp.type, [&](auto x) { + shape::visit(yp.type, [&](auto y) { + if(x.is_integral() != y.is_integral()) + return; + if(x.is_integral()) + { + if(x.is_unsigned() != y.is_unsigned() and x.size() == y.size()) + is_less = y.is_unsigned(); + else + is_less = x.size() < y.size(); + } + else + { + is_less = x.size() < y.size(); + } + }); + }); + return is_less; + } + friend bool operator!=(const precision& xp, const precision& yp) { return not(xp == yp); } + friend bool operator>(const precision& xp, const precision& yp) { return yp < xp; } + // This is not totally ordered + friend bool operator<=(const precision& xp, const precision& yp) + { + return (xp < yp) or (xp == yp); + } + friend bool operator>=(const precision& xp, const precision& yp) + { + return (xp > yp) or (xp == yp); + } +}; +#ifdef __clang__ +#pragma clang diagnostic pop +#endif +} // namespace + +static bool is_pointwise_or_reduce(instruction_ref ins) +{ + return contains(ins->name(), "reduce") or + ins->get_operator().attributes().get("pointwise", false); +} +// Check if its not a scalar constant +static bool is_non_scalar_const(instruction_ref ins) +{ + return not(ins->get_shape().scalar() and ins->can_eval()); +} +// Get the next input instruction otherwise return a nullopt +static std::optional get_next_input(instruction_ref ins) +{ + if(ins->inputs().size() == 1) + return ins->inputs().front(); + if(ins->inputs().size() > 1) + { + std::unordered_set non_scalars; + std::copy_if(ins->inputs().begin(), + ins->inputs().end(), + std::inserter(non_scalars, non_scalars.end()), + &is_non_scalar_const); + if(non_scalars.size() == 1) + return *non_scalars.begin(); + } + return nullopt; +} + +// Find all adjacent instructions that could be upgraded with higher precision +// by traversing the inputs from a convert + +static std::unordered_set find_adjacent_inputs(instruction_ref start) +{ + std::unordered_set result; + // Promote inputs + fix([&](auto self, instruction_ref ins) { + if(not is_pointwise_or_reduce(ins)) + return; + if(contains(result, ins)) + return; + auto next = get_next_input(ins); + if(not next.has_value()) + return; + result.insert(ins); + self(*next); + })(start->inputs().front()); + return result; +} + +// Find all adjacent instructions that could be upgraded with higher precision +// by traversing the outputs from a convert +static std::unordered_set find_adjacent_outputs(instruction_ref start) +{ + std::unordered_set result; + // Promote outputs + fix([&](auto self, instruction_ref ins) { + for(auto output : ins->outputs()) + { + if(not is_pointwise_or_reduce(output)) + continue; + if(contains(result, output)) + continue; + auto next = get_next_input(output); + if(not next.has_value()) + continue; + if(*next != ins) + continue; + result.insert(output); + self(output); + } + })(start); + return result; +} + +// Insert the instructions to upgrade into the map. If the map already has the +// instruction then choose the highest precision +template +static void +insert_instructions_to_upgrade(Map& m, const Instructions& instructions, shape::type_t t) +{ + for(auto ins : instructions) + { + auto it = m.find(ins); + if(it == m.end()) + { + m[ins] = t; + } + else + { + it->second = std::max(precision{t}, precision{it->second}).type; + } + } +} + +// Find adjacent instructions from a convert to upgrade to use a higher +// precision +static std::unordered_map find_instruction_to_upgrade(module& m) +{ + std::unordered_map result; + for(auto ins : iterator_for(m)) + { + if(ins->name() != "convert") + continue; + auto output = precision{ins->get_shape().type()}; + auto input = precision{ins->inputs().front()->get_shape().type()}; + if(output.type == shape::type_t::bool_type) + continue; + if(input < output) + { + insert_instructions_to_upgrade(result, find_adjacent_inputs(ins), output.type); + } + else if(input > output) + { + insert_instructions_to_upgrade(result, find_adjacent_outputs(ins), input.type); + } + } + return result; +} + +void propagate_precision::apply(module_pass_manager& mpm) const +{ + auto upgrade = find_instruction_to_upgrade(mpm.get_module()); + for(const auto& p : upgrade) + { + auto ins = p.first; + auto t = p.second; + auto convert1 = mpm.get_module().insert_instruction( + std::next(ins), make_op("convert", {{"target_type", ins->get_shape().type()}}), ins); + mpm.get_module().replace_instruction(ins, convert1); + std::vector inputs; + std::transform(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(inputs), + [&](auto input) { + return mpm.get_module().insert_instruction( + ins, make_op("convert", {{"target_type", t}}), input); + }); + mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs); + } + mpm.run_pass(eliminate_convert{}); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index b190d11a402..859b3cecfa8 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -191,6 +192,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti rewrite_reduce{}, rewrite_low_precision{}, dead_code_elimination{}, + propagate_precision{}, + dead_code_elimination{}, optimize_module{}, fuse_pointwise_reduce{}, enable_pass(enabled(MIGRAPHX_ENABLE_SPLIT_REDUCE{}), split_reduce{}), diff --git a/test/propagate_precision.cpp b/test/propagate_precision.cpp new file mode 100644 index 00000000000..da9a02a4aaa --- /dev/null +++ b/test/propagate_precision.cpp @@ -0,0 +1,182 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include + +#include + +void run_pass(migraphx::module& m) +{ + migraphx::run_passes(m, + {migraphx::propagate_precision{}, + migraphx::eliminate_common_subexpression{}, + migraphx::dead_code_elimination{}}); +} + +TEST_CASE(propagate_input) +{ + migraphx::shape s1{migraphx::shape::half_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {2, 3}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto two = m1.add_literal(migraphx::literal{{migraphx::shape::half_type}, {2}}); + auto div = migraphx::add_common_op(m1, migraphx::make_op("div"), {x, two}); + auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), div); + auto convert1 = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), sqrt); + auto mul = m1.add_instruction(migraphx::make_op("mul"), convert1, y); + auto convert2 = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), mul); + m1.add_return({convert2}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto convert1 = m2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x); + auto two = m2.add_literal(migraphx::literal{{migraphx::shape::half_type}, {2}}); + auto div = migraphx::add_common_op(m2, migraphx::make_op("div"), {convert1, two}); + auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), div); + auto mul = m2.add_instruction(migraphx::make_op("mul"), sqrt, y); + auto convert2 = m2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), mul); + m2.add_return({convert2}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(propagate_output) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::half_type, {2, 3}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto convert1 = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x); + auto two = m1.add_literal(migraphx::literal{{migraphx::shape::half_type}, {2}}); + auto div = migraphx::add_common_op(m1, migraphx::make_op("div"), {convert1, two}); + auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), div); + auto mul = m1.add_instruction(migraphx::make_op("mul"), sqrt, y); + m1.add_return({mul}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto two = m2.add_literal(migraphx::literal{{migraphx::shape::half_type}, {2}}); + auto div = migraphx::add_common_op(m2, migraphx::make_op("div"), {x, two}); + auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), div); + auto convert1 = m2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), sqrt); + auto mul = m2.add_instruction(migraphx::make_op("mul"), convert1, y); + m2.add_return({mul}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(propagate_conflict) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::double_type, {2, 3}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto convert1 = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x); + auto two = m1.add_literal(migraphx::literal{{migraphx::shape::half_type}, {2}}); + auto div = migraphx::add_common_op(m1, migraphx::make_op("div"), {convert1, two}); + auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), div); + auto convert2 = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::double_type}}), sqrt); + auto mul = m1.add_instruction(migraphx::make_op("mul"), convert2, y); + m1.add_return({mul}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto convert1 = m2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::double_type}}), x); + auto two = m2.add_literal(migraphx::literal{{migraphx::shape::half_type}, {2}}); + auto div = migraphx::add_common_op(m2, migraphx::make_op("div"), {convert1, two}); + auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), div); + auto mul = m2.add_instruction(migraphx::make_op("mul"), sqrt, y); + m2.add_return({mul}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(propagate_reduce) +{ + migraphx::shape s1{migraphx::shape::half_type, {2, 3}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto three = m1.add_literal(migraphx::literal{{migraphx::shape::half_type}, {3}}); + auto squared = m1.add_instruction(migraphx::make_op("mul"), x, x); + auto div = migraphx::add_common_op(m1, migraphx::make_op("div"), {squared, three}); + auto convert1 = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), div); + auto reduce = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", 1}}), convert1); + auto convert2 = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), reduce); + auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), convert2); + auto mul = migraphx::add_common_op(m1, migraphx::make_op("mul"), {x, sqrt}); + m1.add_return({mul}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto convert1 = m2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x); + auto three = m2.add_literal(migraphx::literal{{migraphx::shape::half_type}, {3}}); + auto squared = m2.add_instruction(migraphx::make_op("mul"), convert1, convert1); + auto div = migraphx::add_common_op(m2, migraphx::make_op("div"), {squared, three}); + auto reduce = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", 1}}), div); + auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), reduce); + auto convert2 = m2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), sqrt); + auto mul = migraphx::add_common_op(m2, migraphx::make_op("mul"), {x, convert2}); + m2.add_return({mul}); + } + EXPECT(m1.sort() == m2.sort()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); }