Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add propagate_precision pass #2853

Open
wants to merge 28 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ add_library(migraphx
process.cpp
program.cpp
propagate_constant.cpp
propagate_precision.cpp
promote_literals.cpp
quantization.cpp
quantize_fp16.cpp
Expand Down
45 changes: 45 additions & 0 deletions src/include/migraphx/propagate_precision.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

License

#define MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why the include guard is named differently?


#include <migraphx/config.hpp>
#include <string>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module_pass_manager;

/// This pass will propagate higher precision through more adjacent operators.
struct MIGRAPHX_EXPORT propagate_precision
{
std::string name() const { return "propagate_precision"; }
void apply(module_pass_manager& mpm) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP
6 changes: 3 additions & 3 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,9 @@ struct MIGRAPHX_EXPORT shape

std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; }

auto is_integral() const { return std::is_integral<type>{}; }
auto is_signed() const { return std::is_signed<type>{}; }
auto is_unsigned() const { return std::is_unsigned<type>{}; }
bool is_integral() const { return std::is_integral<type>{}; }
bool is_signed() const { return std::is_signed<type>{}; }
bool is_unsigned() const { return std::is_unsigned<type>{}; }
Comment on lines +365 to +367
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why auto was changed to bool ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is for readability/clarity. I don't see why these shouldn't resolve to anything but bool? Unless we want to use value here specified by stl?

https://en.cppreference.com/w/cpp/types/is_unsigned

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing x.is_integral() != y.is_integral() will fail to compile because they will be different types using auto. So I explicitly convert it to bool instead.


template <class U>
type* from(U* buffer, std::size_t n = 0) const
Expand Down
233 changes: 233 additions & 0 deletions src/propagate_precision.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#include <migraphx/propagate_precision.hpp>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add license here too

#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/eliminate_convert.hpp>
#include <unordered_set>
#include <unordered_map>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace {
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-function"
#endif
// Class wrappper so we can compare precision using comparison operators
struct precision
{
shape::type_t type;

friend bool operator==(const precision& xp, const precision& yp) { return xp.type == yp.type; }
friend bool operator<(const precision& xp, const precision& yp)
{
bool is_less = false;
shape::visit(xp.type, [&](auto x) {
shape::visit(yp.type, [&](auto y) {
if(x.is_integral() != y.is_integral())
return;
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
if(x.is_integral())
{
if(x.is_unsigned() != y.is_unsigned() and x.size() == y.size())
is_less = y.is_unsigned();
else
is_less = x.size() < y.size();
}
else
{
is_less = x.size() < y.size();
}
});
});
return is_less;
}
friend bool operator!=(const precision& xp, const precision& yp) { return not(xp == yp); }
friend bool operator>(const precision& xp, const precision& yp) { return yp < xp; }
// This is not totally ordered
friend bool operator<=(const precision& xp, const precision& yp)
{
return (xp < yp) or (xp == yp);
}
friend bool operator>=(const precision& xp, const precision& yp)
{
return (xp > yp) or (xp == yp);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might seem like an odd ask but why not make these xor vs or? If one is true then the other doesn't matter as the result shouldn't be true. Anyway adding xor instead of or here can speed things up / check for errors?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xor wont short circuit.

};
#ifdef __clang__
#pragma clang diagnostic pop
#endif
} // namespace

static bool is_pointwise_or_reduce(instruction_ref ins)
{
return contains(ins->name(), "reduce") or
ins->get_operator().attributes().get("pointwise", false);
}
// Check if its not a scalar constant
static bool is_non_scalar_const(instruction_ref ins)
{
return not(ins->get_shape().scalar() and ins->can_eval());
}
// Get the next input instruction otherwise return a nullopt
static std::optional<instruction_ref> get_next_input(instruction_ref ins)
{
if(ins->inputs().size() == 1)
return ins->inputs().front();
if(ins->inputs().size() > 1)
{
std::unordered_set<instruction_ref> non_scalars;
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::inserter(non_scalars, non_scalars.end()),
&is_non_scalar_const);
if(non_scalars.size() == 1)
return *non_scalars.begin();
}
return nullopt;
}

// Find all adjacent instructions that could be upgraded with higher precision
// by traversing the inputs from a convert

static std::unordered_set<instruction_ref> find_adjacent_inputs(instruction_ref start)
{
std::unordered_set<instruction_ref> result;
// Promote inputs
fix([&](auto self, instruction_ref ins) {
if(not is_pointwise_or_reduce(ins))
return;
if(contains(result, ins))
return;
auto next = get_next_input(ins);
if(not next.has_value())
return;
result.insert(ins);
self(*next);
})(start->inputs().front());
return result;
}

// Find all adjacent instructions that could be upgraded with higher precision
// by traversing the outputs from a convert
static std::unordered_set<instruction_ref> find_adjacent_outputs(instruction_ref start)
{
std::unordered_set<instruction_ref> result;
// Promote outputs
fix([&](auto self, instruction_ref ins) {
for(auto output : ins->outputs())
{
if(not is_pointwise_or_reduce(output))
continue;
if(contains(result, output))
continue;
auto next = get_next_input(output);
if(not next.has_value())
continue;
if(*next != ins)
continue;
result.insert(output);
self(output);
}
})(start);
return result;
}

// Insert the instructions to upgrade into the map. If the map already has the
// instruction then choose the highest precision
template <class Map, class Instructions>
static void
insert_instructions_to_upgrade(Map& m, const Instructions& instructions, shape::type_t t)
{
for(auto ins : instructions)
{
auto it = m.find(ins);
if(it == m.end())
{
m[ins] = t;
}
else
{
it->second = std::max(precision{t}, precision{it->second}).type;
}
}
}

// Find adjacent instructions from a convert to upgrade to use a higher
// precision
static std::unordered_map<instruction_ref, shape::type_t> find_instruction_to_upgrade(module& m)
{
std::unordered_map<instruction_ref, shape::type_t> result;
for(auto ins : iterator_for(m))
{
if(ins->name() != "convert")
continue;
auto output = precision{ins->get_shape().type()};
auto input = precision{ins->inputs().front()->get_shape().type()};
if(output.type == shape::type_t::bool_type)
continue;
if(input < output)
{
insert_instructions_to_upgrade(result, find_adjacent_inputs(ins), output.type);
}
else if(input > output)
{
insert_instructions_to_upgrade(result, find_adjacent_outputs(ins), input.type);
}
}
return result;
}

void propagate_precision::apply(module_pass_manager& mpm) const
Copy link
Member

@umangyadav umangyadav Mar 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you write docstrings for all of these functions that describes what they are supposed to do and how they work?
Also add how the pass is supposed to work and how it helps with precision or accuracy or performance ?
We can read the code but it is not time efficient for all to get high level understanding.

{
auto upgrade = find_instruction_to_upgrade(mpm.get_module());
for(const auto& p : upgrade)
{
auto ins = p.first;
auto t = p.second;
auto convert1 = mpm.get_module().insert_instruction(
std::next(ins), make_op("convert", {{"target_type", ins->get_shape().type()}}), ins);
mpm.get_module().replace_instruction(ins, convert1);
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) {
return mpm.get_module().insert_instruction(
ins, make_op("convert", {{"target_type", t}}), input);
});
mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs);
}
mpm.run_pass(eliminate_convert{});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
3 changes: 3 additions & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <migraphx/optimize_module.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/promote_literals.hpp>
#include <migraphx/propagate_precision.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_gelu.hpp>
Expand Down Expand Up @@ -191,6 +192,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
rewrite_reduce{},
rewrite_low_precision{},
dead_code_elimination{},
propagate_precision{},
dead_code_elimination{},
optimize_module{},
fuse_pointwise_reduce{},
enable_pass(enabled(MIGRAPHX_ENABLE_SPLIT_REDUCE{}), split_reduce{}),
Expand Down
Loading
Loading