-
Notifications
You must be signed in to change notification settings - Fork 87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pass to convert Uint8 to int8 across operators #2826
Changes from 31 commits
e757f4a
6c02b5a
6eef2ca
182fc60
92316f8
669c850
ce107f3
b4f87d1
32d3c54
6bfcf22
c52a116
0d6094c
4346acc
ce59bd1
f9f515d
5c71546
f9ef07e
f9904e7
d166547
8618a97
60d8671
e8349fd
98a646b
005fa53
22d4dd6
bf3f67e
b05bd34
e146863
fa15eac
7472d76
f45f0dc
b5b98a1
1034c47
6492bf5
2470273
6b94b6d
41ab319
91fc540
2373daa
f5ef46e
fc2c4e8
1fbe2f2
6dc8ef1
eebca5c
2887017
de17954
2cde7cc
ab77c68
0b27323
2654eae
637c22a
63952f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,7 +39,7 @@ struct dot | |
std::string name() const { return "dot"; } | ||
shape compute_shape(std::vector<shape> inputs) const | ||
{ | ||
check_shapes{inputs, *this, true}.same_type().same_ndims().has(2); | ||
check_shapes{inputs, *this, true}.same_ndims().has(2); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think they both need to be same for GPU backend atleast. For the Ref we can allow both to be different types, but if we are going to allow it, we better test it. Can you explain why you made this change ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was due to be using add_instruction and replace_instruction as part of the pass. I've used inserts and moves instead to get the desired result for the 1-d and 2d cases now, and it appears that our quantizelinear op was incorrect as only the first two args (data, scale) need be the same time and zero point must be assigned to the desired output type as zero point and output match. My recent changes should have fixed this as well as the quantizelinear op |
||
const shape& a = inputs.at(0); | ||
const shape& b = inputs.at(1); | ||
auto t = a.type(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
* The MIT License (MIT) | ||
* | ||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in | ||
* all copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
* THE SOFTWARE. | ||
*/ | ||
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYNAMICQUANTIZELINEAR_HPP | ||
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYNAMICQUANTIZELINEAR_HPP | ||
|
||
#include <string> | ||
#include <migraphx/config.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
|
||
struct module; | ||
|
||
/** | ||
* Replaces uint8 ops for int8 quantization of min/max/convert outside of default | ||
*/ | ||
struct MIGRAPHX_EXPORT simplify_dynamicquantizelinear | ||
{ | ||
std::string name() const { return "simplify_dynamicquantizelinear"; } | ||
void apply(module& m) const; | ||
}; | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
/* | ||
* The MIT License (MIT) | ||
* | ||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in | ||
* all copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
* THE SOFTWARE. | ||
*/ | ||
|
||
#include <migraphx/simplify_dynamicquantizelinear.hpp> | ||
#include <migraphx/instruction.hpp> | ||
#include <migraphx/iterator_for.hpp> | ||
#include <migraphx/make_op.hpp> | ||
#include <migraphx/program.hpp> | ||
#include <migraphx/shape.hpp> | ||
#include <migraphx/matcher.hpp> | ||
#include <migraphx/dead_code_elimination.hpp> | ||
#include <migraphx/pass_manager.hpp> | ||
#include <migraphx/register_op.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
|
||
static std::unordered_set<std::string> get_quantizable_op_names() | ||
{ | ||
static std::unordered_set<std::string> s = {"convolution", "dot"}; | ||
return s; | ||
} | ||
|
||
/* | ||
* Dynamicquantizelinear by default adds uint8_t typed zero point into a quantize linear | ||
* which needs to converted to int8 in order to avoid uint8 x int8 operations or uint8 operations | ||
* from occuring on the backend as this isn't supported by MLIR nor how we simplify our quantizable | ||
* ops. | ||
*/ | ||
struct match_find_dynamicquantizelinear_convert_int8_zp | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @nives-vukovic this could be helpful for your problem. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its not a complete solution. Ive got changes built on this PR right now testing this out. It looks like the encoder models that are prequantized also have quant_dot/quant_convolution already added after the fact too so I'll need to add a check as well for them. This will be in a separate PR linked to @nives-vukovic issue. I've got one of the encoders reading in correctly, its just failing on compile now after I've added this updated pass. |
||
{ | ||
auto matcher() const | ||
{ | ||
return match::name(get_quantizable_op_names())( | ||
match::any_arg(0, 1)(skip_broadcast_squeeze(match::name("quantizelinear")( | ||
match::arg(0)(skip_broadcasts(match::any())), | ||
match::arg(2)(skip_broadcasts( | ||
match::name("convert")( | ||
match::has_type(migraphx::shape::uint8_type), | ||
match::arg(0)(match::name("nearbyint")( | ||
match::arg(0)(match::name("clip").bind("saturate"))) | ||
.bind("round"))) | ||
.bind("convert"))))))); | ||
} | ||
|
||
void apply(module& m, const match::matcher_result& r) const | ||
{ | ||
/* Need to modify the uint8 min/max range as well as final convert to convert to int8 */ | ||
auto convert_op = r.instructions["convert"]; | ||
// Ops to get q_min/q_max quickly | ||
auto round_op = r.instructions["round"]; | ||
auto saturate_op = r.instructions["saturate"]; | ||
auto q_min = saturate_op->inputs().at(1); | ||
auto q_max = saturate_op->inputs().at(2); | ||
TedThemistokleous marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// get new desired range defined by int8_t | ||
const auto x_min = std::numeric_limits<int8_t>::min(); | ||
const auto x_max = std::numeric_limits<int8_t>::max(); | ||
|
||
// Replace min/max of uint8 with min/max of int8 - q_range is identical so doesn't need to | ||
// be modified. Need to replace other ops which also take uint8 values first. | ||
auto x_type = q_min->get_shape().type(); | ||
auto q_min_int8 = m.add_literal( | ||
migraphx::literal{migraphx::shape{x_type, q_min->get_shape().lens()}, {x_min}}); | ||
auto q_max_int8 = m.add_literal( | ||
migraphx::literal{migraphx::shape{x_type, q_max->get_shape().lens()}, {x_max}}); | ||
|
||
m.replace_instruction(q_min, q_min_int8); | ||
m.replace_instruction(q_max, q_max_int8); | ||
m.replace_instruction( | ||
convert_op, | ||
migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), | ||
round_op); | ||
} | ||
}; | ||
|
||
void simplify_dynamicquantizelinear::apply(module& m) const | ||
{ | ||
match::find_matches(m, match_find_dynamicquantizelinear_convert_int8_zp{}); | ||
migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); | ||
} | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just put this matcher in the .cpp file. as its specific to the pass.