Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop QDQ around more nodes #21376

Merged
merged 45 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
7d38ea4
Remove QDQ nodes around Flatten
mcollinswisc Jun 26, 2024
b3b506b
Add more operators for which QDQ can be removed
mcollinswisc Jun 26, 2024
2160e44
Keep QDQ nodes w/ nonpositive scale around MaxPool
mcollinswisc Jun 26, 2024
61c5d84
Unit test on QDQ w/ nonpositive scale around MaxPool
mcollinswisc Jun 26, 2024
e3176e6
Merge branch 'main' into qdq_optim
mcollinswisc Jun 26, 2024
e84b3e6
Merge branch 'qdq_optim_nonpositive_scale' into qdq_optim
mcollinswisc Jun 26, 2024
ce1ee8a
Unit test on removing QDQ around Expand
mcollinswisc Jun 28, 2024
73cc212
Add selector to remove QDQ nodes around Min, Max, and Abs
mcollinswisc Jun 28, 2024
3b94f98
Change formatting according to clangformat
mcollinswisc Jul 8, 2024
daf808e
Unit test that QDQ nodes are remove around Tile
mcollinswisc Jul 8, 2024
641747f
Merge branch 'qdq_optim_nonpositive_scale' into qdq_optim
mcollinswisc Jul 8, 2024
7e5db77
Merge branch 'main' into qdq_optim
mcollinswisc Jul 8, 2024
58e525c
Merge branch 'main' into qdq_optim_nonpositive_scale
mcollinswisc Jul 8, 2024
68def1e
Switch to std::fileystem::path in IsQOrDQScalePositiveConstantScalar
mcollinswisc Jul 8, 2024
342034d
Merge branch 'qdq_optim_nonpositive_scale' into qdq_optim
mcollinswisc Jul 8, 2024
cf10b50
Remove SpaceToDepth and SpaceToDepth from Drop QDQ optimization
mcollinswisc Jul 8, 2024
939f240
Fix grammar in comment
mcollinswisc Jul 10, 2024
452bdd1
Unit test that QDQ is dropped around Slice
mcollinswisc Jul 10, 2024
a67d3fb
Unit test on removing QDQ around GatherElements
mcollinswisc Jul 11, 2024
c7a50da
Apply lintrunner
mcollinswisc Jul 11, 2024
82fdcc4
Drop QDQ around ReduceMin & ReduceMax, not Min & Max
mcollinswisc Jul 12, 2024
9b3bf09
Disallow 16bit for ReduceMin & ReduceMax
mcollinswisc Jul 12, 2024
872f983
Unit test on dropping QDQ from around ReduceMin/Max
mcollinswisc Jul 12, 2024
5bf7d84
Fix comment in ReduceExtremumDropQDQ test case
mcollinswisc Jul 16, 2024
28824d6
Remove selector to drop QDQ around Abs
mcollinswisc Jul 16, 2024
83d85f3
Merge branch 'main' into qdq_optim
mcollinswisc Jul 16, 2024
71253a8
Reformatting from lintrunner
mcollinswisc Jul 16, 2024
191dc15
Fix comment grammar according to review suggestion
mcollinswisc Jul 25, 2024
023bbaf
Register drop_action_no_int16_nor_nonpositive_scale in minimal build
mcollinswisc Jul 25, 2024
c287525
Continue line to keep it under 120 chars
mcollinswisc Jul 25, 2024
e34258d
Change name "no nonpositive scale" to "and positive scale"
mcollinswisc Jul 25, 2024
a880a21
Merge branch 'main' into qdq_optim_nonpositive_scale
mcollinswisc Jul 25, 2024
b0753e4
More reformatting to keep lines under 120 chars
mcollinswisc Jul 25, 2024
102e952
Merge branch 'qdq_optim_nonpositive_scale' into qdq_optim
mcollinswisc Jul 26, 2024
8144f66
Merge branch 'main' into qdq_optim
mcollinswisc Jul 30, 2024
610a6a9
Undo spacing change in Gemm test
mcollinswisc Jul 31, 2024
d63854b
Delete drop_action_no_nonpositive_scale_name
mcollinswisc Jul 31, 2024
848f701
Alphabetize operator names included in base selector
mcollinswisc Jul 31, 2024
b1de950
Alphabetize operators in no 16-bit & positive selector
mcollinswisc Jul 31, 2024
c193e0e
[Merge branch 'main' into qdq_optim
mcollinswisc Jul 31, 2024
fb91e41
Move some comments to line before
mcollinswisc Aug 1, 2024
49edace
Merge branch 'main' into qdq_optim
mcollinswisc Aug 1, 2024
85010fe
Change spacing around comments according to clang-format/lintrunner
mcollinswisc Aug 5, 2024
b7fd7b7
Merge branch 'main' into qdq_optim
mcollinswisc Aug 15, 2024
e5638e5
Set /bigobj for qdq_transformer_test.cc
mcollinswisc Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -887,10 +887,12 @@ if (MSVC)
target_compile_options(onnxruntime_test_all PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4244>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4244>")

# Avoid this compile error in graph_transform_test.cc:
# Avoid this compile error in graph_transform_test.cc and qdq_transformer_test.cc:
# fatal error C1128: number of sections exceeded object file format limit: compile with /bigobj
set_property(SOURCE "${TEST_SRC_DIR}/optimizer/graph_transform_test.cc"
APPEND PROPERTY COMPILE_OPTIONS "/bigobj")
set_property(SOURCE "${TEST_SRC_DIR}/optimizer/qdq_transformer_test.cc"
APPEND PROPERTY COMPILE_OPTIONS "/bigobj")
else()
target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses")
endif()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,25 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr<NodeSelector> selector_no_16bit_and_positive_scale =
std::make_unique<QDQ::DropQDQNodesSelector>(false, true, false);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_and_positive_scale_name,
{{"MaxPool", {12}}},
{{"MaxPool", {12}},
{"ReduceMax", {}},
{"ReduceMin", {}}},
std::move(selector_no_16bit_and_positive_scale),
std::move(drop_action_no_int16_and_positive_scale));

std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DropQDQNodesSelector>(true);
// DepthToSpace and SpaceToDepth not included because there are no integer implementations.
// https://github.com/microsoft/onnxruntime/issues/21287
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name,
{{"Gather", {}},
{{"Expand", {}},
{"Flatten", {}},
{"Gather", {}},
{"GatherElements", {}},
{"Reshape", {}},
{"Transpose", {}},
{"Slice", {}},
{"Squeeze", {}},
{"Tile", {}},
{"Transpose", {}},
{"Unsqueeze", {}}},
std::move(selector),
std::move(drop_action));
Expand Down
291 changes: 291 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,297 @@ TEST(QDQTransformerTests, UnsqueezeDropQDQ) {
RunSqueezeUnsqueezeDropQDQTestCase<uint16_t>("Unsqueeze", {1, 3, 2, 2}, {0}, false, 21);
}

// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> Flatten -> Q.
template <typename QuantType>
static void RunFlattenDropQDQTestCase(const std::vector<int64_t>& input_shape,
int64_t axis = 1,
bool use_contrib_qdq = false,
int opset = 21) {
auto build_test_case = [input_shape, axis, use_contrib_qdq](ModelTestBuilder& builder) {
constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
constexpr QuantType qmax = std::numeric_limits<QuantType>::max();

auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
auto* output_arg = builder.MakeOutput();
QuantType zero_point = 1 + (qmax + qmin) / 2;

auto* input_arg_dq = builder.MakeIntermediate();
auto* flatten_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<QuantType>(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq);
Node& flatten_node = builder.AddNode("Flatten", {input_arg_dq}, {flatten_output});
flatten_node.AddAttribute("axis", axis);

// add Q
builder.AddQuantizeLinearNode<QuantType>(flatten_output, .003f, zero_point, output_arg, use_contrib_qdq);
};

auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count["Flatten"], 1);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
};

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset);
}

// Checks that Q/DQ nodes are dropped from DQ -> Flatten -> Q for several Flatten axes.
// Uses 8-bit and 16-bit Q/DQ ops.
TEST(QDQTransformerTests, FlattenDropQDQ) {
  const std::vector<int64_t> input_shape{1, 3, 2, 2};
  constexpr int64_t kAxes[] = {0, 1, 3};

  for (const int64_t axis : kAxes) {
    // int8 with the helper's default Q/DQ ops and opset.
    RunFlattenDropQDQTestCase<int8_t>(input_shape, axis);
    // com.microsoft QDQ ops (opset 13): int8, then both 16-bit types.
    RunFlattenDropQDQTestCase<int8_t>(input_shape, axis, true, 13);
    RunFlattenDropQDQTestCase<int16_t>(input_shape, axis, true, 13);
    RunFlattenDropQDQTestCase<uint16_t>(input_shape, axis, true, 13);
    // 16-bit ONNX QDQ ops at the helper's default opset.
    RunFlattenDropQDQTestCase<int16_t>(input_shape, axis, false);
    RunFlattenDropQDQTestCase<uint16_t>(input_shape, axis, false);
  }
}

// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> Expand -> Q.
template <typename QuantType>
static void RunExpandDropQDQTestCase(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& expanded_shape,
bool use_contrib_qdq = false,
int opset = 21) {
auto build_test_case = [input_shape, expanded_shape, use_contrib_qdq](ModelTestBuilder& builder) {
constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
constexpr QuantType qmax = std::numeric_limits<QuantType>::max();

auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
auto* output_arg = builder.MakeOutput();
QuantType zero_point = 1 + (qmax + qmin) / 2;

auto* input_arg_dq = builder.MakeIntermediate();
auto* expanded_shape_arg = builder.Make1DInitializer(expanded_shape);
auto* expand_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<QuantType>(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq);
builder.AddNode("Expand", {input_arg_dq, expanded_shape_arg}, {expand_output});

// add Q
builder.AddQuantizeLinearNode<QuantType>(expand_output, .003f, zero_point, output_arg, use_contrib_qdq);
};

auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count["Expand"], 1);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
};

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset);
}

// Checks that Q/DQ nodes are dropped from DQ -> Expand -> Q when expanding {1, 3, 1, 1} to {1, 3, 7, 13}.
// Uses 8-bit and 16-bit Q/DQ ops.
TEST(QDQTransformerTests, ExpandDropQDQ) {
  const std::vector<int64_t> input_shape{1, 3, 1, 1};
  const std::vector<int64_t> expanded_shape{1, 3, 7, 13};

  // int8 with the helper's default Q/DQ ops and opset.
  RunExpandDropQDQTestCase<int8_t>(input_shape, expanded_shape);
  // com.microsoft QDQ ops (opset 13): int8, then both 16-bit types.
  RunExpandDropQDQTestCase<int8_t>(input_shape, expanded_shape, true, 13);
  RunExpandDropQDQTestCase<int16_t>(input_shape, expanded_shape, true, 13);
  RunExpandDropQDQTestCase<uint16_t>(input_shape, expanded_shape, true, 13);
  // 16-bit ONNX QDQ ops at the helper's default opset.
  RunExpandDropQDQTestCase<int16_t>(input_shape, expanded_shape, false);
  RunExpandDropQDQTestCase<uint16_t>(input_shape, expanded_shape, false);
}

// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> Tile -> Q.
template <typename QuantType>
static void RunTileDropQDQTestCase(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& repeats,
bool use_contrib_qdq = false,
int opset = 21) {
auto build_test_case = [input_shape, repeats, use_contrib_qdq](ModelTestBuilder& builder) {
constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
constexpr QuantType qmax = std::numeric_limits<QuantType>::max();

auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
auto* output_arg = builder.MakeOutput();
QuantType zero_point = 1 + (qmax + qmin) / 2;

auto* input_arg_dq = builder.MakeIntermediate();
auto* repeats_arg = builder.Make1DInitializer(repeats);
auto* tile_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<QuantType>(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq);
builder.AddNode("Tile", {input_arg_dq, repeats_arg}, {tile_output});

// add Q
builder.AddQuantizeLinearNode<QuantType>(tile_output, .003f, zero_point, output_arg, use_contrib_qdq);
};

auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count["Tile"], 1);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
};

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset);
}

// Checks that Q/DQ nodes are dropped from DQ -> Tile -> Q for a {1, 3, 2, 2} input with repeats {1, 1, 3, 3}.
// Uses 8-bit and 16-bit Q/DQ ops.
TEST(QDQTransformerTests, TileDropQDQ) {
  const std::vector<int64_t> input_shape{1, 3, 2, 2};
  const std::vector<int64_t> repeats{1, 1, 3, 3};

  // int8 with the helper's default Q/DQ ops and opset.
  RunTileDropQDQTestCase<int8_t>(input_shape, repeats);
  // com.microsoft QDQ ops (opset 13): int8, then both 16-bit types.
  RunTileDropQDQTestCase<int8_t>(input_shape, repeats, true, 13);
  RunTileDropQDQTestCase<int16_t>(input_shape, repeats, true, 13);
  RunTileDropQDQTestCase<uint16_t>(input_shape, repeats, true, 13);
  // 16-bit ONNX QDQ ops at the helper's default opset.
  RunTileDropQDQTestCase<int16_t>(input_shape, repeats, false);
  RunTileDropQDQTestCase<uint16_t>(input_shape, repeats, false);
}

// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> Slice -> Q.
template <typename QuantType>
static void RunSliceDropQDQTestCase(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
bool use_contrib_qdq = false,
int opset = 21) {
auto build_test_case = [input_shape, starts, ends, use_contrib_qdq](ModelTestBuilder& builder) {
constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
constexpr QuantType qmax = std::numeric_limits<QuantType>::max();

auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
auto* output_arg = builder.MakeOutput();
QuantType zero_point = 1 + (qmax + qmin) / 2;

auto* input_arg_dq = builder.MakeIntermediate();
auto* starts_arg = builder.Make1DInitializer(starts);
auto* ends_arg = builder.Make1DInitializer(ends);
auto* slice_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<QuantType>(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq);
builder.AddNode("Slice", {input_arg_dq, starts_arg, ends_arg}, {slice_output});

// add Q
builder.AddQuantizeLinearNode<QuantType>(slice_output, .003f, zero_point, output_arg, use_contrib_qdq);
};

auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count["Slice"], 1);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
};

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset);
}

// Checks that Q/DQ nodes are dropped from DQ -> Slice -> Q for a {1, 3, 5, 5} input sliced with
// starts {0, 1, 1, 1} and ends {1, 3, 4, 4}. Uses 8-bit and 16-bit Q/DQ ops.
TEST(QDQTransformerTests, SliceDropQDQ) {
  const std::vector<int64_t> input_shape{1, 3, 5, 5};
  const std::vector<int64_t> starts{0, 1, 1, 1};
  const std::vector<int64_t> ends{1, 3, 4, 4};

  // int8 with the helper's default Q/DQ ops and opset.
  RunSliceDropQDQTestCase<int8_t>(input_shape, starts, ends);
  // com.microsoft QDQ ops (opset 13): int8, then both 16-bit types.
  RunSliceDropQDQTestCase<int8_t>(input_shape, starts, ends, true, 13);
  RunSliceDropQDQTestCase<int16_t>(input_shape, starts, ends, true, 13);
  RunSliceDropQDQTestCase<uint16_t>(input_shape, starts, ends, true, 13);
  // 16-bit ONNX QDQ ops at the helper's default opset.
  RunSliceDropQDQTestCase<int16_t>(input_shape, starts, ends, false);
  RunSliceDropQDQTestCase<uint16_t>(input_shape, starts, ends, false);
}

// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> GatherElements -> Q.
template <typename QuantType>
static void RunGatherElementsDropQDQTestCase(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& indices_shape,
const std::vector<int64_t>& indices_data,
bool use_contrib_qdq = false,
int opset = 21) {
auto build_test_case = [input_shape, indices_shape, indices_data, use_contrib_qdq](ModelTestBuilder& builder) {
constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
constexpr QuantType qmax = std::numeric_limits<QuantType>::max();

auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
auto* indices_arg = builder.MakeInitializer<int64_t>(indices_shape, indices_data);
auto* output_arg = builder.MakeOutput();
QuantType zero_point = 1 + (qmax + qmin) / 2;

auto* input_arg_dq = builder.MakeIntermediate();
auto* gather_elements_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<QuantType>(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq);
builder.AddNode("GatherElements", {input_arg_dq, indices_arg}, {gather_elements_output});

// add Q
builder.AddQuantizeLinearNode<QuantType>(gather_elements_output, .003f, zero_point, output_arg, use_contrib_qdq);
};

auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count["GatherElements"], 1);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
};

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset);
}

// Checks that Q/DQ nodes are dropped from DQ -> GatherElements -> Q for a {3, 3} input gathered
// with a {2, 3} indices initializer. Uses 8-bit and 16-bit Q/DQ ops.
TEST(QDQTransformerTests, GatherElementsDropQDQ) {
  const std::vector<int64_t> input_shape{3, 3};
  const std::vector<int64_t> indices_shape{2, 3};
  const std::vector<int64_t> indices_data{1, 2, 0, 2, 0, 0};

  // int8 with the helper's default Q/DQ ops and opset.
  RunGatherElementsDropQDQTestCase<int8_t>(input_shape, indices_shape, indices_data);
  // com.microsoft QDQ ops (opset 13): int8, then both 16-bit types.
  RunGatherElementsDropQDQTestCase<int8_t>(input_shape, indices_shape, indices_data, true, 13);
  RunGatherElementsDropQDQTestCase<int16_t>(input_shape, indices_shape, indices_data, true, 13);
  RunGatherElementsDropQDQTestCase<uint16_t>(input_shape, indices_shape, indices_data, true, 13);
  // 16-bit ONNX QDQ ops at the helper's default opset.
  RunGatherElementsDropQDQTestCase<int16_t>(input_shape, indices_shape, indices_data, false);
  RunGatherElementsDropQDQTestCase<uint16_t>(input_shape, indices_shape, indices_data, false);
}

// Runs a test case whether Q/DQ nodes are dropped from DQ -> Reduce(Min|Max) -> Q.
template <typename QuantType>
static void RunReduceExtremumDropQDQTestCase(const std::string& op_type,
const std::vector<int64_t>& input_shape,
float qscale,
bool expect_drop_qdq,
bool use_contrib_qdq = false,
int opset = 21) {
auto build_test_case = [op_type, input_shape, qscale, use_contrib_qdq](ModelTestBuilder& builder) {
constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
constexpr QuantType qmax = std::numeric_limits<QuantType>::max();

auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
auto* output_arg = builder.MakeOutput();
QuantType zero_point = 1 + (qmax + qmin) / 2;

auto* input_arg_dq = builder.MakeIntermediate();
auto* reduce_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<QuantType>(input_arg, qscale, zero_point, input_arg_dq, use_contrib_qdq);
builder.AddNode(op_type, {input_arg_dq}, {reduce_output});

// add Q
builder.AddQuantizeLinearNode<QuantType>(reduce_output, qscale, zero_point, output_arg, use_contrib_qdq);
};

auto check_graph = [op_type, expect_drop_qdq, use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count[op_type], 1);
if (expect_drop_qdq) {
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
} else {
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1);
}
};

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset);
}

// Checks whether Q/DQ nodes are dropped from DQ -> Reduce(Min|Max) -> Q depending on the sign of the scale.
// Only 8-bit (int8) Q/DQ ops are exercised here.
TEST(QDQTransformerTests, ReduceExtremumDropQDQ) {
  const char* const kReduceOps[] = {"ReduceMin", "ReduceMax"};

  // Runs ReduceMin then ReduceMax, each with the default QDQ ops and then with com.microsoft QDQ ops (opset 13).
  auto run_all_ops = [&kReduceOps](float qscale, bool expect_drop_qdq) {
    for (const char* op_type : kReduceOps) {
      RunReduceExtremumDropQDQTestCase<int8_t>(op_type, {3, 3}, qscale, expect_drop_qdq);
      RunReduceExtremumDropQDQTestCase<int8_t>(op_type, {3, 3}, qscale, expect_drop_qdq, true, 13);
    }
  };

  // Check that Q/DQ nodes are dropped for positive scale
  run_all_ops(0.003f, true);

  // Check that Q/DQ nodes are *not* dropped for negative scale
  run_all_ops(-0.003f, false);
}

TEST(QDQTransformerTests, DoubleQDQ) {
constexpr uint8_t good_u8_1 = 80;
constexpr uint8_t good_u8_2 = 40;
Expand Down
Loading