-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add propagate_precision pass #2853
base: develop
Are you sure you want to change the base?
Changes from all commits
e8d3984
e521d2c
3a1f86a
4ea8c98
e2527b0
180ef75
b5b1e52
0d40cc9
b0b7cc8
b6a4b7e
da637b6
d1f9756
5bafe27
5381b68
2c490de
00df7bb
e6cf3df
3c39d1f
1a35ba5
da8471d
e2a69d2
5f7da8b
22e7d58
05e62d4
f25bd18
565bde1
0f785bd
8217360
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* The MIT License (MIT) | ||
* | ||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in | ||
* all copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
* THE SOFTWARE. | ||
* | ||
*/ | ||
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP | ||
#define MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason why the include guard is named differently? |
||
|
||
#include <migraphx/config.hpp> | ||
#include <string> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
|
||
struct module_pass_manager; | ||
|
||
/// This pass will propagate higher precision through more adjacent operators. | ||
struct MIGRAPHX_EXPORT propagate_precision | ||
{ | ||
std::string name() const { return "propagate_precision"; } | ||
void apply(module_pass_manager& mpm) const; | ||
}; | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx | ||
#endif // MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -362,9 +362,9 @@ struct MIGRAPHX_EXPORT shape | |
|
||
std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; } | ||
|
||
auto is_integral() const { return std::is_integral<type>{}; } | ||
auto is_signed() const { return std::is_signed<type>{}; } | ||
auto is_unsigned() const { return std::is_unsigned<type>{}; } | ||
bool is_integral() const { return std::is_integral<type>{}; } | ||
bool is_signed() const { return std::is_signed<type>{}; } | ||
bool is_unsigned() const { return std::is_unsigned<type>{}; } | ||
Comment on lines
+365
to
+367
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is for readability/clarity. I don't see why these shouldn't resolve to anything but bool? Unless we want to use value here specified by stl? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doing |
||
|
||
template <class U> | ||
type* from(U* buffer, std::size_t n = 0) const | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
/* | ||
* The MIT License (MIT) | ||
* | ||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in | ||
* all copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
* THE SOFTWARE. | ||
* | ||
*/ | ||
#include <migraphx/propagate_precision.hpp> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add license here too |
||
#include <migraphx/module.hpp> | ||
#include <migraphx/iterator_for.hpp> | ||
#include <migraphx/instruction.hpp> | ||
#include <migraphx/pass_manager.hpp> | ||
#include <migraphx/make_op.hpp> | ||
#include <migraphx/functional.hpp> | ||
#include <migraphx/ranges.hpp> | ||
#include <migraphx/eliminate_convert.hpp> | ||
#include <unordered_set> | ||
#include <unordered_map> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
|
||
namespace { | ||
#ifdef __clang__ | ||
#pragma clang diagnostic push | ||
#pragma clang diagnostic ignored "-Wunused-function" | ||
#endif | ||
// Class wrappper so we can compare precision using comparison operators | ||
struct precision | ||
{ | ||
shape::type_t type; | ||
|
||
friend bool operator==(const precision& xp, const precision& yp) { return xp.type == yp.type; } | ||
friend bool operator<(const precision& xp, const precision& yp) | ||
{ | ||
bool is_less = false; | ||
shape::visit(xp.type, [&](auto x) { | ||
shape::visit(yp.type, [&](auto y) { | ||
if(x.is_integral() != y.is_integral()) | ||
return; | ||
umangyadav marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if(x.is_integral()) | ||
{ | ||
if(x.is_unsigned() != y.is_unsigned() and x.size() == y.size()) | ||
is_less = y.is_unsigned(); | ||
else | ||
is_less = x.size() < y.size(); | ||
} | ||
else | ||
{ | ||
is_less = x.size() < y.size(); | ||
} | ||
}); | ||
}); | ||
return is_less; | ||
} | ||
friend bool operator!=(const precision& xp, const precision& yp) { return not(xp == yp); } | ||
friend bool operator>(const precision& xp, const precision& yp) { return yp < xp; } | ||
// This is not totally ordered | ||
friend bool operator<=(const precision& xp, const precision& yp) | ||
{ | ||
return (xp < yp) or (xp == yp); | ||
} | ||
friend bool operator>=(const precision& xp, const precision& yp) | ||
{ | ||
return (xp > yp) or (xp == yp); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might seem like an odd ask but why not make these xor vs or? If one is true then the other doesn't matter as the result shouldn't be true. Anyway adding xor instead of or here can speed things up / check for errors? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
}; | ||
#ifdef __clang__ | ||
#pragma clang diagnostic pop | ||
#endif | ||
} // namespace | ||
|
||
static bool is_pointwise_or_reduce(instruction_ref ins) | ||
{ | ||
return contains(ins->name(), "reduce") or | ||
ins->get_operator().attributes().get("pointwise", false); | ||
} | ||
// Check if its not a scalar constant | ||
static bool is_non_scalar_const(instruction_ref ins) | ||
{ | ||
return not(ins->get_shape().scalar() and ins->can_eval()); | ||
} | ||
// Get the next input instruction otherwise return a nullopt | ||
static std::optional<instruction_ref> get_next_input(instruction_ref ins) | ||
{ | ||
if(ins->inputs().size() == 1) | ||
return ins->inputs().front(); | ||
if(ins->inputs().size() > 1) | ||
{ | ||
std::unordered_set<instruction_ref> non_scalars; | ||
std::copy_if(ins->inputs().begin(), | ||
ins->inputs().end(), | ||
std::inserter(non_scalars, non_scalars.end()), | ||
&is_non_scalar_const); | ||
if(non_scalars.size() == 1) | ||
return *non_scalars.begin(); | ||
} | ||
return nullopt; | ||
} | ||
|
||
// Find all adjacent instructions that could be upgraded with higher precision | ||
// by traversing the inputs from a convert | ||
|
||
static std::unordered_set<instruction_ref> find_adjacent_inputs(instruction_ref start) | ||
{ | ||
std::unordered_set<instruction_ref> result; | ||
// Promote inputs | ||
fix([&](auto self, instruction_ref ins) { | ||
if(not is_pointwise_or_reduce(ins)) | ||
return; | ||
if(contains(result, ins)) | ||
return; | ||
auto next = get_next_input(ins); | ||
if(not next.has_value()) | ||
return; | ||
result.insert(ins); | ||
self(*next); | ||
})(start->inputs().front()); | ||
return result; | ||
} | ||
|
||
// Find all adjacent instructions that could be upgraded with higher precision | ||
// by traversing the outputs from a convert | ||
static std::unordered_set<instruction_ref> find_adjacent_outputs(instruction_ref start) | ||
{ | ||
std::unordered_set<instruction_ref> result; | ||
// Promote outputs | ||
fix([&](auto self, instruction_ref ins) { | ||
for(auto output : ins->outputs()) | ||
{ | ||
if(not is_pointwise_or_reduce(output)) | ||
continue; | ||
if(contains(result, output)) | ||
continue; | ||
auto next = get_next_input(output); | ||
if(not next.has_value()) | ||
continue; | ||
if(*next != ins) | ||
continue; | ||
result.insert(output); | ||
self(output); | ||
} | ||
})(start); | ||
return result; | ||
} | ||
|
||
// Insert the instructions to upgrade into the map. If the map already has the | ||
// instruction then choose the highest precision | ||
template <class Map, class Instructions> | ||
static void | ||
insert_instructions_to_upgrade(Map& m, const Instructions& instructions, shape::type_t t) | ||
{ | ||
for(auto ins : instructions) | ||
{ | ||
auto it = m.find(ins); | ||
if(it == m.end()) | ||
{ | ||
m[ins] = t; | ||
} | ||
else | ||
{ | ||
it->second = std::max(precision{t}, precision{it->second}).type; | ||
} | ||
} | ||
} | ||
|
||
// Find adjacent instructions from a convert to upgrade to use a higher | ||
// precision | ||
static std::unordered_map<instruction_ref, shape::type_t> find_instruction_to_upgrade(module& m) | ||
{ | ||
std::unordered_map<instruction_ref, shape::type_t> result; | ||
for(auto ins : iterator_for(m)) | ||
{ | ||
if(ins->name() != "convert") | ||
continue; | ||
auto output = precision{ins->get_shape().type()}; | ||
auto input = precision{ins->inputs().front()->get_shape().type()}; | ||
if(output.type == shape::type_t::bool_type) | ||
continue; | ||
if(input < output) | ||
{ | ||
insert_instructions_to_upgrade(result, find_adjacent_inputs(ins), output.type); | ||
} | ||
else if(input > output) | ||
{ | ||
insert_instructions_to_upgrade(result, find_adjacent_outputs(ins), input.type); | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
void propagate_precision::apply(module_pass_manager& mpm) const | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you write docstrings for all of these functions that describes what they are supposed to do and how they work? |
||
{ | ||
auto upgrade = find_instruction_to_upgrade(mpm.get_module()); | ||
for(const auto& p : upgrade) | ||
{ | ||
auto ins = p.first; | ||
auto t = p.second; | ||
auto convert1 = mpm.get_module().insert_instruction( | ||
std::next(ins), make_op("convert", {{"target_type", ins->get_shape().type()}}), ins); | ||
mpm.get_module().replace_instruction(ins, convert1); | ||
std::vector<instruction_ref> inputs; | ||
std::transform(ins->inputs().begin(), | ||
ins->inputs().end(), | ||
std::back_inserter(inputs), | ||
[&](auto input) { | ||
return mpm.get_module().insert_instruction( | ||
ins, make_op("convert", {{"target_type", t}}), input); | ||
}); | ||
mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs); | ||
} | ||
mpm.run_pass(eliminate_convert{}); | ||
} | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
License