Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep QDQ nodes w/ nonpositive scale around MaxPool #21182

Merged
merged 13 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,41 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint(
return true;
}

bool IsQOrDQScalePositiveConstantScalar(
const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer,
const std::filesystem::path& model_path) {
auto q_or_dq_input_defs = q_or_dq_node.InputDefs();

ORT_ENFORCE(q_or_dq_input_defs.size() >= 2);

if (!optimizer_utils::IsScalar(*q_or_dq_input_defs[InputIndex::SCALE_ID])) {
return false;
}

const ONNX_NAMESPACE::TensorProto* q_or_dq_scale_tensor_proto =
get_const_initializer(q_or_dq_input_defs[InputIndex::SCALE_ID]->Name());
if (nullptr == q_or_dq_scale_tensor_proto) {
return false;
}

Initializer q_or_dq_scale(*q_or_dq_scale_tensor_proto, model_path);

switch (q_or_dq_scale.data_type()) {
case ONNX_NAMESPACE::TensorProto::FLOAT:
return q_or_dq_scale.data<float>()[0] > 0;

case ONNX_NAMESPACE::TensorProto::FLOAT16:
return q_or_dq_scale.data<MLFloat16>()[0] > 0;

case ONNX_NAMESPACE::TensorProto::BFLOAT16:
return q_or_dq_scale.data<BFloat16>()[0] > 0;

default:
assert(false);
return false;
}
}

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

bool MatchQNode(const Node& node) {
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint(
const GetConstantInitializerFn& get_const_initializer,
bool& zero_point_exists);

// Checks that the y_scale/x_scale input to the QuantizeLinear/DequantizeLinear node is a
// positive constant scalar (i.e., a scalar constant initializer whose value is > 0).
// Returns false if the scale is not a scalar, is not a constant initializer, or is <= 0.
bool IsQOrDQScalePositiveConstantScalar(const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer,
const std::filesystem::path& model_path);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// Check Q node op type, version, and domain.
bool MatchQNode(const Node& node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
// 3 nodes. DQ, target, Q. Merge into target and remove DQ and Q.
const std::string drop_action_name{"drop"};
const std::string drop_action_no_int16_name{"drop_no_int16_support"};
const std::string drop_action_no_int16_and_positive_scale_name{"drop_no_int16_support_and_positive_scale"};
NTO::NodeLocation dq{NTO::NodeType::kInput, 0};

Check warning on line 39 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "NTO" is a misspelling of "NOT" Raw Output: ./onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:39:2: "NTO" is a misspelling of "NOT"

Check warning on line 39 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "NTO" is a misspelling of "NOT" Raw Output: ./onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:39:23: "NTO" is a misspelling of "NOT"
NTO::NodeLocation q{NTO::NodeType::kOutput, 0};

Check warning on line 40 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "NTO" is a misspelling of "NOT" Raw Output: ./onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:40:2: "NTO" is a misspelling of "NOT"

Check warning on line 40 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "NTO" is a misspelling of "NOT" Raw Output: ./onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:40:22: "NTO" is a misspelling of "NOT"

// Move DQ input 0 to target input 0.
// Move Q output 0 to target output 0.
Expand All @@ -46,19 +47,32 @@

std::unique_ptr<Action> drop_action_no_int16 = std::make_unique<MergeIntoTargetFixed>(
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action_no_int16_and_positive_scale = std::make_unique<MergeIntoTargetFixed>(
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action = std::make_unique<MergeIntoTargetFixed>(std::move(moves));

#if !defined(ORT_MINIMAL_BUILD)
// Use a separate selector + action that disallows 16-bit types for MaxPool and Resize.
// Use separate selectors & actions for MaxPool and Resize.
//
// They disallow 16-bit types for MaxPool and Resize:
// int16 MaxPool is not supported by the ONNX specification.
// int16 Resize is not supported by the ORT implementation (although allowed by ONNX).
std::unique_ptr<NodeSelector> selector_disallow_16bit = std::make_unique<QDQ::DropQDQNodesSelector>(false);
//
// And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative
// scale will change the ordering of the elements between quantized & de-quantized values.
std::unique_ptr<NodeSelector> selector_no_16bit = std::make_unique<QDQ::DropQDQNodesSelector>(false);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name,
{{"MaxPool", {12}},
{"Resize", {}}},
std::move(selector_disallow_16bit),
{{"Resize", {}}},
std::move(selector_no_16bit),
std::move(drop_action_no_int16));

std::unique_ptr<NodeSelector> selector_no_16bit_and_positive_scale =
std::make_unique<QDQ::DropQDQNodesSelector>(false, true, false);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_and_positive_scale_name,
{{"MaxPool", {12}}},
std::move(selector_no_16bit_and_positive_scale),
std::move(drop_action_no_int16_and_positive_scale));

std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DropQDQNodesSelector>(true);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name,
{{"Gather", {}},
Expand All @@ -70,6 +84,9 @@
std::move(drop_action));
#else
qdq_selector_action_registry.RegisterAction(drop_action_no_int16_name, std::move(drop_action_no_int16));
qdq_selector_action_registry.RegisterAction(
drop_action_no_int16_and_positive_scale_name,
std::move(drop_action_no_int16_and_positive_scale));
qdq_selector_action_registry.RegisterAction(drop_action_name, std::move(drop_action));
#endif
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return graph_viewer.GetConstantInitializer(initializer_name, true);
};

if (!allow_nonpositive_scale_) {
// IsQDQPairSupported will check that the scale is the same between q_node and dq_node.
if (!IsQOrDQScalePositiveConstantScalar(q_node, get_const_initializer, graph_viewer.ModelPath())) {
return false;
}
}

return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ class NodeGroupSelector {
// Zero point and scale are constant scalars and must match
class DropQDQNodeGroupSelector : public NodeGroupSelector {
public:
explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
: allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true,
bool allow_nonpositive_scale = true)
: allow_16bit_(allow_16bit), allow_4bit_(allow_4bit), allow_nonpositive_scale_(allow_nonpositive_scale) {}

private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
Expand All @@ -58,6 +59,7 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector {

bool allow_16bit_;
bool allow_4bit_;
bool allow_nonpositive_scale_;
};

// Single DQ -> node.
Expand Down Expand Up @@ -300,8 +302,8 @@ class BaseSelector : public NodeSelector {

class DropQDQNodesSelector : public BaseSelector {
public:
explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false)
: BaseSelector(std::make_unique<DropQDQNodeGroupSelector>(allow_16bit, allow_4bit)) {}
explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false, bool allow_nonpositive_scale = true)
: BaseSelector(std::make_unique<DropQDQNodeGroupSelector>(allow_16bit, allow_4bit, allow_nonpositive_scale)) {}
};

class DropDQNodesSelector : public BaseSelector {
Expand Down
46 changes: 46 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,52 @@ TEST(QDQTransformerTests, ReshapeDropQDQ) {
RunReshapeDropQDQTestCase<uint16_t>({1, 3, 2, 2}, {1, 12}, false, 21); // Use int16 ONNX QDQ ops
}

// Runs a test case that checks if Q/DQ nodes are *not* dropped from DQ -> MaxPool -> Q if the quantization scale is
// negative.
template <typename QuantType>
static void RunMaxPoolNegativeScaleDropQDQTestCase() {
auto build_test_case = [](ModelTestBuilder& builder) {
constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
constexpr QuantType qmax = std::numeric_limits<QuantType>::max();

const std::vector<int64_t> input_shape = {1, 17, 17, 3};
auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
auto* output_arg = builder.MakeOutput();

constexpr float scale = -0.003f;
QuantType zero_point = 1 + (qmax + qmin) / 2;

auto* input_arg_dq = builder.MakeIntermediate();
auto* maxpool_output = builder.MakeIntermediate();

builder.AddDequantizeLinearNode<QuantType>(input_arg, scale, zero_point, input_arg_dq);

Node& maxpool_node = builder.AddNode("MaxPool", {input_arg_dq}, {maxpool_output});
maxpool_node.AddAttribute("auto_pad", "VALID");
maxpool_node.AddAttribute("kernel_shape", std::vector<int64_t>({2, 2}));

builder.AddQuantizeLinearNode<QuantType>(maxpool_output, scale, zero_point, output_arg);
};

auto check_graph = [](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
EXPECT_EQ(op_to_count["MaxPool"], 1);
EXPECT_EQ(op_to_count["QuantizeLinear"], 1);
EXPECT_EQ(op_to_count["DequantizeLinear"], 1);
};

constexpr int opset = 21;
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset);
}

// Checks that Q/DQ nodes are *not* dropped from DQ -> MaxPool -> Q for negative scale. Uses 8-bit and 16-bit Q/DQ ops.
TEST(QDQTransformerTests, MaxpoolDontDropQDQForNegativeScale) {
RunMaxPoolNegativeScaleDropQDQTestCase<int8_t>();    // signed 8-bit
RunMaxPoolNegativeScaleDropQDQTestCase<uint8_t>();   // unsigned 8-bit
RunMaxPoolNegativeScaleDropQDQTestCase<int16_t>();   // signed 16-bit
RunMaxPoolNegativeScaleDropQDQTestCase<uint16_t>();  // unsigned 16-bit
}

// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> (Un)Squeeze -> Q.
template <typename QuantType>
static void RunSqueezeUnsqueezeDropQDQTestCase(const std::string& squeeze_type,
Expand Down
Loading