Add pass to convert Uint8 to int8 across operators #2826

Closed
wants to merge 52 commits into from
Changes from 31 commits
Commits
52 commits
e757f4a
Avoid uint8 being added for dynamicQuantizeLinear
Feb 23, 2024
6c02b5a
Update dynamicquantizelinear tests
Feb 23, 2024
6eef2ca
Revert "Update dynamicquantizelinear tests"
Feb 27, 2024
182fc60
Revert "Avoid uint8 being added for dynamicQuantizeLinear"
Feb 27, 2024
92316f8
Add has_type matcher
Feb 28, 2024
669c850
Backup of pass for modifying dynamicquantizelinear
Feb 28, 2024
ce107f3
Fix format
Feb 28, 2024
b4f87d1
Split out dynamic quantize linear as a separate pass to simplify_qdq…
Feb 28, 2024
32d3c54
Cleanup move from has_type matcher
Feb 28, 2024
6bfcf22
Fix license
Feb 28, 2024
c52a116
Add pass to verify test
Feb 28, 2024
0d6094c
Cleanup text in simplify pass
Feb 28, 2024
4346acc
Add reduce axes in parse_dynamicquantizelinear
Feb 28, 2024
ce59bd1
Fix bug found in parse_dynamicquantizelinear
Feb 29, 2024
f9f515d
Fix licensing in updated dynamicquantizelinear related files
Feb 29, 2024
5c71546
Update CPU target with simplify_dynamicquantizelinear
Feb 29, 2024
f9ef07e
Format
Feb 29, 2024
f9904e7
rename x_max/x_min to be more clear
Feb 29, 2024
d166547
Add another verify test for 2d case
Feb 29, 2024
8618a97
Remove input flatten, handle serial min/max ops
Feb 29, 2024
60d8671
Make matcher more robust to target series of ops used in dynamicquant…
Mar 1, 2024
e8349fd
fix format for update matcher
Mar 1, 2024
98a646b
Fix parse tests
Mar 1, 2024
005fa53
Update matcher to remove the use of m_t
Mar 1, 2024
22d4dd6
Rename len_vec to axes for reduce operators
Mar 1, 2024
bf3f67e
Set simplify path to only target supported quantize ops
Mar 2, 2024
b05bd34
Create skip_broadcast_squeeze to handle 1d case
Mar 5, 2024
e146863
Fix format
Mar 5, 2024
fa15eac
Add same_type() back into dot operator. Silence 1d test for now
Mar 5, 2024
7472d76
Fix sanitizer error with int8 overflow in unit test. Remove same_type…
Mar 6, 2024
f45f0dc
Fix pyflakes errors
Mar 6, 2024
b5b98a1
Handle the input to the quantizelinear as part of pass
Mar 14, 2024
1034c47
Remove debug print
Mar 14, 2024
6492bf5
Merge branch 'develop' into fix_dynamic_int8_quant
TedThemistokleous Mar 14, 2024
2470273
Fix tidy errors with readability
Mar 14, 2024
6b94b6d
Adjust matcher to grab next op to our quantize ops
Mar 15, 2024
41ab319
Fix clang error - match op after quant_op
Mar 15, 2024
91fc540
Merge branch 'develop' into fix_dynamic_int8_quant
TedThemistokleous Mar 15, 2024
2373daa
Format / pyflakes fix
Mar 15, 2024
f5ef46e
Fixes to parse DynamicQuantizeLinear
Mar 18, 2024
fc2c4e8
Fix format
Mar 18, 2024
1fbe2f2
Merge branch 'fix_parse_dynamicquantizelinear' into fix_dynamic_int8_…
TedThemistokleous Mar 18, 2024
6dc8ef1
Merge branch 'develop' into fix_parse_dynamicquantizelinear
TedThemistokleous Mar 18, 2024
eebca5c
Merge branch 'fix_parse_dynamicquantizelinear' into fix_dynamic_int8_…
TedThemistokleous Mar 18, 2024
2887017
Fix format
Mar 15, 2024
de17954
Remove extra literal that slipped in when resolving conflict
Mar 18, 2024
2cde7cc
Update tests to use quant_dot to capture correct int8 output. Update …
Mar 18, 2024
ab77c68
Merge branch 'develop' into fix_dynamic_int8_quant
TedThemistokleous Apr 9, 2024
0b27323
Merge branch 'develop' into fix_dynamic_int8_quant
TedThemistokleous Apr 29, 2024
2654eae
Merge branch 'develop' into fix_dynamic_int8_quant
causten May 28, 2024
637c22a
Merge branch 'develop' into fix_dynamic_int8_quant
TedThemistokleous Aug 3, 2024
63952f2
Merge branch 'develop' into fix_dynamic_int8_quant
causten Aug 14, 2024
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -93,6 +93,7 @@ add_library(migraphx
replace_allocate.cpp
rewrite_reduce.cpp
simplify_qdq.cpp
simplify_dynamicquantizelinear.cpp
sqlite.cpp
rewrite_gelu.cpp
rewrite_pooling.cpp
12 changes: 12 additions & 0 deletions src/include/migraphx/matcher.hpp
@@ -689,6 +689,12 @@ inline auto var(std::string s)
});
}

inline auto has_type(shape::type_t t)
{
return make_basic_pred_matcher(
[=](instruction_ref ins) { return ins->get_shape().type() == t; });
}

inline auto name(std::string s)
{
return make_basic_pred_matcher(
@@ -858,6 +864,12 @@ auto skip_broadcasts_converts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}

template <class... Ms>
auto skip_broadcast_squeeze(Ms... ms)
Collaborator:
Just put this matcher in the .cpp file, as it's specific to the pass.

{
return skip(name("broadcast", "multibroadcast", "contiguous", "squeeze", "unsqueeze"))(ms...);
}

template <class F>
inline auto literal_value_checker(F f)
{
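
The new `has_type` and `skip_broadcast_squeeze` helpers are a predicate matcher and a skip combinator. As a rough standalone sketch of how such predicates compose — using simplified stand-in types, not the real MIGraphX API — nested matchers behave like conjoined predicates over an instruction:

```cpp
#include <cassert>
#include <functional>
#include <string>

// Hypothetical stand-ins for MIGraphX's instruction/shape types,
// for illustration only; the real matcher API lives in matcher.hpp.
enum class type_t { uint8_type, int8_type, float_type };

struct instruction
{
    std::string op_name;
    type_t type;
};

using predicate = std::function<bool(const instruction&)>;

// Analogue of match::has_type(t): matches when the output shape's type is t.
predicate has_type(type_t t)
{
    return [=](const instruction& ins) { return ins.type == t; };
}

// Analogue of match::name(s): matches when the operator name is s.
predicate name(const std::string& s)
{
    return [=](const instruction& ins) { return ins.op_name == s; };
}

// Conjunction, mimicking how nested matchers refine a match,
// e.g. match::name("convert")(match::has_type(uint8_type)).
predicate both(predicate a, predicate b)
{
    return [=](const instruction& ins) { return a(ins) && b(ins); };
}
```

In the pass below, the same idea is used to find only `convert` instructions whose output type is uint8, rather than every `convert` in the module.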
2 changes: 1 addition & 1 deletion src/include/migraphx/op/dot.hpp
@@ -39,7 +39,7 @@ struct dot
std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.same_type().same_ndims().has(2);
check_shapes{inputs, *this, true}.same_ndims().has(2);
@umangyadav (Member), Mar 8, 2024:
I think they both need to be the same for the GPU backend at least. For the ref target we can allow the two types to differ, but if we are going to allow that, we had better test it. Can you explain why you made this change?

@TedThemistokleous (Collaborator, Author), Mar 14, 2024:
It was due to using add_instruction and replace_instruction as part of the pass. I've used inserts and moves instead to get the desired result for the 1-d and 2-d cases now. It also turned out that our quantizelinear op was incorrect: only the first two args (data, scale) need to be the same type, and the zero point must be assigned the desired output type, since zero point and output match.

My recent changes should have fixed this as well as the quantizelinear op.

const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
47 changes: 47 additions & 0 deletions src/include/migraphx/simplify_dynamicquantizelinear.hpp
@@ -0,0 +1,47 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYNAMICQUANTIZELINEAR_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYNAMICQUANTIZELINEAR_HPP

#include <string>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

/**
* Replaces the uint8 min/max/convert ops generated by DynamicQuantizeLinear with int8
* equivalents so quantization proceeds in int8 rather than the default uint8.
*/
struct MIGRAPHX_EXPORT simplify_dynamicquantizelinear
{
std::string name() const { return "simplify_dynamicquantizelinear"; }
void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
45 changes: 22 additions & 23 deletions src/onnx/parse_dynamicquantizelinear.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -98,41 +98,40 @@ struct parse_dynamicquantizelinear : op_parser<parse_dynamicquantizelinear>
if(x_shape.dynamic())
MIGRAPHX_THROW("DYNAMICQUANTIZELINEAR: dynamic shapes are not supported");

auto x_reshaped =
(x_shape.lens().size() == 1)
? x
: info.add_instruction(
migraphx::make_op("reshape", {{"dims", {x_shape.elements()}}}), x);

auto lit_0 = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0}});
x_reshaped =
info.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, lit_0);

// 1. Computing y_scale
// Note: currently, DynamicQuantizeLinear only has uint8 quantization:
const auto x_max = std::numeric_limits<uint8_t>::max();
const auto x_min = std::numeric_limits<uint8_t>::min();

auto q_range =
info.add_literal(migraphx::literal{migraphx::shape{x_type}, {x_max - x_min}});
const auto type_max = std::numeric_limits<uint8_t>::max();
const auto type_min = std::numeric_limits<uint8_t>::min();
std::vector<size_t> axes(x_shape.lens().size());
std::iota(axes.begin(), axes.end(), 0);

// maximum(0, max(x))
auto max_x =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped);
auto reduce_max_x =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", axes}}), x);
auto max_x = info.add_common_op("max", lit_0, reduce_max_x);

// minimum(0, min(x))
auto min_x =
info.add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped);
auto reduce_min_x =
info.add_instruction(migraphx::make_op("reduce_min", {{"axes", axes}}), x);
auto min_x = info.add_common_op("min", lit_0, reduce_min_x);

auto q_range = info.add_literal(migraphx::literal{
migraphx::shape{x_type, max_x->get_shape().lens()}, {type_max - type_min}});

auto q_min = info.add_literal(
migraphx::literal{migraphx::shape{x_type, min_x->get_shape().lens()}, {type_min}});
auto q_max = info.add_literal(
migraphx::literal{migraphx::shape{x_type, max_x->get_shape().lens()}, {type_max}});

// y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
auto sub0 = info.add_common_op("sub", max_x, min_x);
auto y_scale = info.add_common_op("div", sub0, q_range);

// 2. Computing y_zero_point
// intermediate_zero_point = qmin - min(x) / y_scale
auto q_min = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {x_min}});
auto q_max = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {x_max}});
auto sub1 = info.add_common_op("sub", q_min, min_x);
auto interm_zp = info.add_common_op("div", sub1, y_scale);
auto div1 = info.add_common_op("div", min_x, y_scale);
auto interm_zp = info.add_common_op("sub", q_min, div1);
// y_zero_point = cast(round(saturate(intermediate_zero_point)))
auto saturate = info.add_instruction(migraphx::make_op("clip"), interm_zp, q_min, q_max);
auto round = info.add_instruction(migraphx::make_op("nearbyint"), saturate);
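
The rewritten parser follows the ONNX DynamicQuantizeLinear reference formulas. A scalar sketch of the same arithmetic (plain C++, independent of MIGraphX) is useful for checking the order of operations — in particular that the fix computes `qmin - min(x) / y_scale` rather than `(qmin - min(x)) / y_scale`:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference arithmetic for ONNX DynamicQuantizeLinear (uint8 output),
// mirroring the steps in parse_dynamicquantizelinear.cpp.
struct dql_params
{
    float y_scale;
    std::uint8_t y_zero_point;
};

dql_params dynamic_quantize_params(const std::vector<float>& x)
{
    const float qmin = 0.0f;   // numeric_limits<uint8_t>::min()
    const float qmax = 255.0f; // numeric_limits<uint8_t>::max()

    // maximum(0, max(x)) and minimum(0, min(x)): zero must stay representable.
    float max_x = std::max(0.0f, *std::max_element(x.begin(), x.end()));
    float min_x = std::min(0.0f, *std::min_element(x.begin(), x.end()));

    // y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
    float y_scale = (max_x - min_x) / (qmax - qmin);

    // intermediate_zero_point = qmin - min(x) / y_scale
    // (the division binds before the subtraction -- the bug fixed in this PR)
    float interm_zp = qmin - min_x / y_scale;

    // y_zero_point = cast(round(saturate(intermediate_zero_point)))
    float saturated = std::min(qmax, std::max(qmin, interm_zp));
    return {y_scale, static_cast<std::uint8_t>(std::nearbyint(saturated))};
}
```

For example, with `x = {-1, 0, 1, 3}` the span is 4, so `y_scale = 4/255` and the zero point lands at `round(1 / (4/255)) = round(63.75) = 64`.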
105 changes: 105 additions & 0 deletions src/simplify_dynamicquantizelinear.cpp
@@ -0,0 +1,105 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/simplify_dynamicquantizelinear.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

static std::unordered_set<std::string> get_quantizable_op_names()
{
static std::unordered_set<std::string> s = {"convolution", "dot"};
return s;
}

/*
* DynamicQuantizeLinear by default feeds a uint8_t-typed zero point into a quantizelinear,
* which needs to be converted to int8 in order to avoid uint8 x int8 or uint8-only
* operations occurring on the backend, as these aren't supported by MLIR nor by how we
* simplify our quantizable ops.
*/
struct match_find_dynamicquantizelinear_convert_int8_zp
Member:
@nives-vukovic this could be helpful for your problem.

Collaborator (Author):
It's not a complete solution. I've got changes built on top of this PR right now to test this out. It looks like the prequantized encoder models also have quant_dot/quant_convolution added after the fact, so I'll need to add a check for them as well.

This will be in a separate PR linked to @nives-vukovic's issue. I've got one of the encoders reading in correctly; it's just failing on compile now after I've added this updated pass.

{
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::any_arg(0, 1)(skip_broadcast_squeeze(match::name("quantizelinear")(
match::arg(0)(skip_broadcasts(match::any())),
match::arg(2)(skip_broadcasts(
match::name("convert")(
match::has_type(migraphx::shape::uint8_type),
match::arg(0)(match::name("nearbyint")(
match::arg(0)(match::name("clip").bind("saturate")))
.bind("round")))
.bind("convert")))))));
}

void apply(module& m, const match::matcher_result& r) const
{
/* Need to modify the uint8 min/max range as well as the final convert so everything is int8 */
auto convert_op = r.instructions["convert"];
// Ops to get q_min/q_max quickly
auto round_op = r.instructions["round"];
auto saturate_op = r.instructions["saturate"];
auto q_min = saturate_op->inputs().at(1);
auto q_max = saturate_op->inputs().at(2);

// get new desired range defined by int8_t
const auto x_min = std::numeric_limits<int8_t>::min();
const auto x_max = std::numeric_limits<int8_t>::max();

// Replace min/max of uint8 with min/max of int8 - q_range is identical so doesn't need to
// be modified. Need to replace other ops which also take uint8 values first.
auto x_type = q_min->get_shape().type();
auto q_min_int8 = m.add_literal(
migraphx::literal{migraphx::shape{x_type, q_min->get_shape().lens()}, {x_min}});
auto q_max_int8 = m.add_literal(
migraphx::literal{migraphx::shape{x_type, q_max->get_shape().lens()}, {x_max}});

m.replace_instruction(q_min, q_min_int8);
m.replace_instruction(q_max, q_max_int8);
m.replace_instruction(
convert_op,
migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}),
round_op);
}
};

void simplify_dynamicquantizelinear::apply(module& m) const
{
match::find_matches(m, match_find_dynamicquantizelinear_convert_int8_zp{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
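
Note that the pass only swaps the q_min/q_max literals and the final convert's target type; it can leave q_range alone because the uint8 and int8 ranges are both 255 wide, so the switch merely shifts the zero point by 128. A small illustrative sketch of that equivalence (helper names are ours, not MIGraphX APIs):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Quantize with saturation to a target integer range [qmin, qmax].
int quantize(float x, int zero_point, float scale, int qmin, int qmax)
{
    int q = static_cast<int>(std::nearbyint(x / scale)) + zero_point;
    return std::min(qmax, std::max(qmin, q)); // saturate to the type's range
}

// Recover the real value; identical for uint8 and int8 representations
// because the zero point shifts by the same 128 as the quantized value.
float dequantize(int q, int zero_point, float scale)
{
    return (q - zero_point) * scale;
}
```

With scale 0.1, a uint8 zero point of 70 becomes an int8 zero point of -58, every quantized value drops by 128, and dequantization returns the same result — which is why only the literals and the convert need rewriting.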
2 changes: 1 addition & 1 deletion src/simplify_qdq.cpp
@@ -46,7 +46,7 @@ auto skip_post_dq_ops(Ms... ms)
match::name("broadcast", "multibroadcast", "contiguous", "transpose", "reshape"))(ms...);
}

std::unordered_set<std::string> get_quantizable_op_names()
static std::unordered_set<std::string> get_quantizable_op_names()
{
static std::unordered_set<std::string> s = {"convolution", "dot"};
return s;
2 changes: 2 additions & 0 deletions src/targets/cpu/target.cpp
@@ -44,6 +44,7 @@
#include <migraphx/schedule.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_dynamicquantizelinear.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/cpu/fuse_ops.hpp>
#include <migraphx/cpu/write_literals.hpp>
@@ -68,6 +69,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type);
return {normalize_ops{},
simplify_dynamicquantizelinear{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
2 changes: 2 additions & 0 deletions src/targets/gpu/target.cpp
Expand Up @@ -52,6 +52,7 @@
#include <migraphx/schedule.hpp>
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_dynamicquantizelinear.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/gpu/allocation_model.hpp>
@@ -127,6 +128,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
normalize_ops{},
dead_code_elimination{},
simplify_dynamicquantizelinear{},
simplify_qdq{},
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},