Skip to content

Commit

Permalink
Keep QDQ nodes w/ nonpositive scale around MaxPool
Browse files Browse the repository at this point in the history
Currently, the DropQDQNodesRules optimization removes QuantizeLinear and
DequantizeLinear nodes from DequantizeLinear∘MaxPool∘QuantizeLinear.
However, if the x_scale/y_scale values are non-positive, this changes
the ordering of the elements in the input value, so this optimization is
changing the results.

This change adds a check for whether the scale in the QuantizeLinear (or
DequantizeLinear) is a positive scalar, and a new selector to disallow
removing the QDQ around MaxPool if it is not.

#21176
  • Loading branch information
mcollinswisc committed Jun 26, 2024
1 parent e2abba1 commit 2160e44
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 7 deletions.
35 changes: 35 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,41 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint(
return true;
}

bool IsQOrDQScalePositiveConstantScalar(
const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer,
const Path& model_path) {
auto q_or_dq_input_defs = q_or_dq_node.InputDefs();

ORT_ENFORCE(q_or_dq_input_defs.size() >= 2);

if (!optimizer_utils::IsScalar(*q_or_dq_input_defs[InputIndex::SCALE_ID])) {
return false;
}

const ONNX_NAMESPACE::TensorProto* q_or_dq_scale_tensor_proto =
get_const_initializer(q_or_dq_input_defs[InputIndex::SCALE_ID]->Name());
if (nullptr == q_or_dq_scale_tensor_proto) {
return false;
}

Initializer q_or_dq_scale(*q_or_dq_scale_tensor_proto, model_path);

switch (q_or_dq_scale.data_type()) {
case ONNX_NAMESPACE::TensorProto::FLOAT:
return q_or_dq_scale.data<float>()[0] > 0;

case ONNX_NAMESPACE::TensorProto::FLOAT16:
return q_or_dq_scale.data<MLFloat16>()[0] > 0;

case ONNX_NAMESPACE::TensorProto::BFLOAT16:
return q_or_dq_scale.data<BFloat16>()[0] > 0;

default:
assert(false);
return false;
}
}

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

bool MatchQNode(const Node& node) {
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint(
const GetConstantInitializerFn& get_const_initializer,
bool& zero_point_exists);

// Checks that the y_scale/x_scale input to the QuantizeLinear/DequantizeLinear node is a positive scalar.
bool IsQOrDQScalePositiveConstantScalar(const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer,
const Path& model_path);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// Check Q node op type, version, and domain.
bool MatchQNode(const Node& node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
// 3 nodes. DQ, target, Q. Merge into target and remove DQ and Q.
const std::string drop_action_name{"drop"};
const std::string drop_action_no_int16_name{"drop_no_int16_support"};
const std::string drop_action_no_int16_nor_nonpositive_scale_name{"drop_no_int16_support_no_nonpositive_scale"};
NTO::NodeLocation dq{NTO::NodeType::kInput, 0};
NTO::NodeLocation q{NTO::NodeType::kOutput, 0};

Expand All @@ -46,19 +47,32 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {

std::unique_ptr<Action> drop_action_no_int16 = std::make_unique<MergeIntoTargetFixed>(
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action_no_int16_nor_nonpositive_scale = std::make_unique<MergeIntoTargetFixed>(
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action = std::make_unique<MergeIntoTargetFixed>(std::move(moves));

#if !defined(ORT_MINIMAL_BUILD)
// Use a separate selector + action that disallows 16-bit types for MaxPool and Resize.
// Use a separate selectors & actions for MaxPool and Resize.
//
// They disallow 16-bit types for MaxPool and Resize:
// int16 MaxPool is not supported by the ONNX specification.
// int16 Resize is not supported by the ORT implementation (although allowed by ONNX).
//
// And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative scale will change the ordering
// of the elements between quantized & de-quantized values.
std::unique_ptr<NodeSelector> selector_disallow_16bit = std::make_unique<QDQ::DropQDQNodesSelector>(false);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name,
{{"MaxPool", {12}},
{"Resize", {}}},
{{"Resize", {}}},
std::move(selector_disallow_16bit),
std::move(drop_action_no_int16));

std::unique_ptr<NodeSelector> selector_disallow_16bit_and_nonpositive_scale = (
std::make_unique<QDQ::DropQDQNodesSelector>(false, true, false));
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_nor_nonpositive_scale_name,
{{"MaxPool", {12}}},
std::move(selector_disallow_16bit_and_nonpositive_scale),
std::move(drop_action_no_int16_nor_nonpositive_scale));

std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DropQDQNodesSelector>(true);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name,
{{"Gather", {}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return graph_viewer.GetConstantInitializer(initializer_name, true);
};

if (!allow_nonpositive_scale_) {
// IsQDQPairSupported will check that the scale is the same between q_node and dq_node.
if (!IsQOrDQScalePositiveConstantScalar(q_node, get_const_initializer, graph_viewer.ModelPath())) {
return false;
}
}

return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class NodeGroupSelector {
// Zero point and scale are constant scalars and must match
class DropQDQNodeGroupSelector : public NodeGroupSelector {
public:
explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
: allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true, bool allow_nonpositive_scale = true)
: allow_16bit_(allow_16bit), allow_4bit_(allow_4bit), allow_nonpositive_scale_(allow_nonpositive_scale) {}

private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
Expand All @@ -58,6 +58,7 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector {

bool allow_16bit_;
bool allow_4bit_;
bool allow_nonpositive_scale_;
};

// Single DQ -> node.
Expand Down Expand Up @@ -292,8 +293,8 @@ class BaseSelector : public NodeSelector {

class DropQDQNodesSelector : public BaseSelector {
public:
explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false)
: BaseSelector(std::make_unique<DropQDQNodeGroupSelector>(allow_16bit, allow_4bit)) {}
explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false, bool allow_nonpositive_scale = true)
: BaseSelector(std::make_unique<DropQDQNodeGroupSelector>(allow_16bit, allow_4bit, allow_nonpositive_scale)) {}
};

class DropDQNodesSelector : public BaseSelector {
Expand Down

0 comments on commit 2160e44

Please sign in to comment.