[QDQ Optimizer] Fix logic that drops Q/DQ ops from QDQ split node groups (#18394)

### Description
- Fix the QDQ optimizer logic that drops Q/DQ ops from Split node groups so that the ops are dropped only when all input/output quantization parameters are equal (see the sketch below).
- Previously, the selector used for this optimization did not ensure that all quantization parameters are equal.
- Support dropping Q/DQ ops from Split node groups that use the optional 'split' input (introduced in opset 13). This did not work previously.
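
A minimal, self-contained sketch of the check, assuming hypothetical names (`QuantParams`, `CanDropQDQ`) that are not ONNX Runtime's actual API: the optimizer may drop the Q/DQ ops only when the DQ feeding Split and every Q consuming its outputs use the same element type, scale, and zero point (the real selector compares the constant scale/zero-point initializers via `IsQDQPairSupported`).

```cpp
// Sketch only: illustrates the condition the new SplitNodeGroupSelector enforces
// when req_equal_quant_params is true. QuantParams and CanDropQDQ are hypothetical.
#include <cstdint>
#include <iostream>
#include <vector>

struct QuantParams {
  int32_t elem_type;  // tensor element type, e.g. int8/uint8/int16/uint16
  float scale;
  int32_t zero_point;
};

// Returns true only if every Q output's parameters match the DQ input's parameters.
// If they differ, re-quantization is not a no-op, so the Q/DQ ops must be kept.
bool CanDropQDQ(const QuantParams& dq_input, const std::vector<QuantParams>& q_outputs) {
  for (const QuantParams& q : q_outputs) {
    if (q.elem_type != dq_input.elem_type ||
        q.scale != dq_input.scale ||
        q.zero_point != dq_input.zero_point) {
      return false;
    }
  }
  return true;
}

int main() {
  QuantParams dq{3 /*INT8*/, 0.003f, 0};
  std::cout << CanDropQDQ(dq, {dq, dq, dq}) << "\n";              // 1: group collapses to a plain Split
  std::cout << CanDropQDQ(dq, {dq, {3, 0.004f, 0}, dq}) << "\n";  // 0: Q/DQ ops are kept
}
```

The equality requirement matters because quantize(dequantize(x)) is the identity only when both sides share the same scale and zero point; with mismatched parameters, removing the Q/DQ pair would silently change values.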


### Motivation and Context
Fix bugs in handling of QDQ Split node groups.

---------

Signed-off-by: adrianlizarraga <[email protected]>
adrianlizarraga authored Nov 22, 2023
1 parent 62da3b1 commit 7c57305
Showing 7 changed files with 147 additions and 32 deletions.

@@ -87,12 +87,19 @@ std::vector<NodeAndMoveInfo> WhereMoves() {
MoveAll(q, ArgType::kOutput)};
return moves;
}
QDQReplaceWithNew SplitReplacer() {
QDQReplaceWithNew SplitReplacer(bool has_split_as_input) {
NTO::NodeLocation dq{NTO::NodeType::kInput, 0};
NTO::NodeLocation target{NTO::NodeType::kTarget, 0};
NTO::NodeLocation q{NTO::NodeType::kOutput, 0};
std::vector<NodeAndMoveInfo> moves{
MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput),
MoveAll(q, ArgType::kOutput)};
std::vector<NodeAndMoveInfo> moves{MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput)};

if (has_split_as_input) {
// Move the optional split input to the new node.
moves.push_back(MoveAndAppend(target, ArgType::kInput, 1, ArgType::kInput, true));
}

moves.push_back(MoveAll(q, ArgType::kOutput));

return QDQReplaceWithNew(kOnnxDomain, "Split", std::move(moves));
}

@@ -247,7 +254,12 @@ MatMulReplaceWithQLinear::MatMulReplaceWithQLinear()
}

Status SplitReplaceWithQuant::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
return SplitReplacer().Run(graph, selected_nodes);
const auto& target_node = selected_nodes.Target();
const auto& input_defs = target_node.InputDefs();

// The 'split' attribute became an optional input at opset 13.
bool has_split_as_input = target_node.SinceVersion() >= 13 && input_defs.size() == 2;
return SplitReplacer(has_split_as_input).Run(graph, selected_nodes);
}

Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {

@@ -20,7 +20,7 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
const std::string action_name{"dropSplitQDQ"};
std::unique_ptr<Action> action = std::make_unique<QDQ::SplitReplaceWithQuant>();
#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::OutputVariadicSelector>();
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::SplitSelector>(true /*req_equal_quant_params*/);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Split", {}}},
std::move(selector),

@@ -253,7 +253,39 @@ void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder
builder.num_input_defs = 1; // set to 1 as the first input is variadic
}

void OutputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) {
return false;
}

auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
return graph_viewer.GetConstantInitializer(initializer_name, true);
};

const Node& dq_node = *dq_nodes.front();
int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();

// All Q outputs should have the same data type and (optionally) equal quantization parameters as the input.
for (size_t q_idx = 0; q_idx < q_nodes.size(); q_idx++) {
const Node& q_node = *q_nodes[q_idx];

if (dt_input != q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) {
return false;
}

if (req_equal_quant_params_ &&
!IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) {
return false;
}
}

return true;
}

void SplitSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
builder.num_output_defs = 1; // set to 1 as the first output is variadic
}


@@ -115,6 +115,24 @@ class VariadicNodeGroupSelector : public NodeGroupSelector {
bool allow_16bit_;
};

// DQ node -> Split -> multiple Q nodes with equal quantization types.
// Optionally, the selector can require all input and output quantization parameters to be
// equal and constant.
class SplitNodeGroupSelector : public NodeGroupSelector {
public:
explicit SplitNodeGroupSelector(bool req_equal_quant_params = false)
: req_equal_quant_params_(req_equal_quant_params) {}

private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;

bool req_equal_quant_params_; // If true, only selects a node group if the input and output
// quantization parameters are all equal/constant, which enables the
// optimizer to drop the Q/DQ ops if the group is assigned to the CPU EP.
};

// DQ nodes for X, W and optionally B -> node -> Q
class ConvNodeGroupSelector : public NodeGroupSelector {
public:
@@ -288,10 +306,11 @@ class InputVariadicSelector : public BaseSelector {
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};

// DQ -> node -> Variadic Q nodes
class OutputVariadicSelector : public BaseSelector {
// DQ -> Split -> variadic Q nodes
class SplitSelector : public BaseSelector {
public:
OutputVariadicSelector() : BaseSelector(std::make_unique<VariadicNodeGroupSelector>()) {}
SplitSelector(bool req_equal_quant_params = false)
: BaseSelector(std::make_unique<SplitNodeGroupSelector>(req_equal_quant_params)) {}

void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};

@@ -27,6 +27,9 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops
}

/* static methods to return different operator's OpVersionMap */

// These are operators that do not change the data and therefore the input DQ and
// output Q have the same scale and zero_point.
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
return {{"Gather", {}},
{"Reshape", {}},
@@ -35,7 +38,6 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
{"Transpose", {}},
{"MaxPool", {12}},
{"Resize", {}},
{"Split", {}},
{"Squeeze", {}},
{"Unsqueeze", {}},
{"Tile", {}}};
@@ -97,6 +99,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() {
{"Max", {}},
{"Min", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetSplitOpVersionsMap() {
return {{"Split", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() {
return {{"Conv", {}}};
}
@@ -170,6 +175,13 @@ void RegisterVariadicSelectors(Selectors& qdq_selectors) {
std::move(selector));
}

void RegisterSplitSelector(Selectors& qdq_selectors) {
/* register selectors for Split op */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<SplitNodeGroupSelector>();
qdq_selectors.RegisterSelector(GetSplitOpVersionsMap(),
std::move(selector));
}

void RegisterConvSelector(Selectors& qdq_selectors) {
/* register selector for conv op */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<ConvNodeGroupSelector>();
@@ -247,6 +259,7 @@ void SelectorManager::CreateSelectors() {
RegisterUnarySelectors(qdq_selectors_);
RegisterBinarySelectors(qdq_selectors_);
RegisterVariadicSelectors(qdq_selectors_);
RegisterSplitSelector(qdq_selectors_);
RegisterConvSelector(qdq_selectors_);
RegisterConvTransposeSelector(qdq_selectors_);
RegisterMatMulSelector(qdq_selectors_);

37 changes: 26 additions & 11 deletions onnxruntime/test/optimizer/qdq_test_utils.h
@@ -466,40 +466,55 @@ GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index, bool use_cont
}

template <typename InputType, typename OutputType>
GetQDQTestCaseFn BuildQDQSplitTestCase(
const std::vector<int64_t>& input_shape,
const int64_t& axis,
bool use_contrib_qdq = false) {
return [input_shape, axis, use_contrib_qdq](ModelTestBuilder& builder) {
GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector<int64_t>& input_shape,
const int64_t& axis,
bool use_diff_output_scale,
bool use_contrib_qdq = false) {
return [input_shape, axis, use_diff_output_scale, use_contrib_qdq](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<InputType>(input_shape,
std::numeric_limits<InputType>::min(),
std::numeric_limits<InputType>::max());

InputType dq_zp = std::numeric_limits<InputType>::max() / 2;
OutputType q_zp = std::numeric_limits<OutputType>::max() / 2;
auto* dq_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<InputType>(input_arg, .003f, dq_zp, dq_output, use_contrib_qdq);
constexpr float input_scale = 0.003f;
builder.AddDequantizeLinearNode<InputType>(input_arg, input_scale, dq_zp, dq_output, use_contrib_qdq);

// add Split
std::vector<NodeArg*> split_inputs;
split_inputs.push_back(dq_output);

// Use the optional 'split' input when testing Split 13
int opset = builder.DomainToVersionMap().find(kOnnxDomain)->second;
if (opset >= 13 && opset < 18) {
int64_t dim = input_shape[axis];
int64_t split_size = dim / 3;
split_inputs.push_back(builder.Make1DInitializer(std::vector<int64_t>{split_size,
split_size, dim - (2 * split_size)}));
}

auto* split_output_1 = builder.MakeIntermediate();
auto* split_output_2 = builder.MakeIntermediate();
auto* split_output_3 = builder.MakeIntermediate();
Node& split_node = builder.AddNode("Split", {dq_output}, {split_output_1, split_output_2, split_output_3});
Node& split_node = builder.AddNode("Split", split_inputs, {split_output_1, split_output_2, split_output_3});
split_node.AddAttribute("axis", axis);
if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) {

// Use the 'num_outputs' attribute when testing Split >= 18
if (opset >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(3));
}

// add Q
auto* q_split_output_1 = builder.MakeOutput();
auto* q_split_output_2 = builder.MakeOutput();
auto* q_split_output_3 = builder.MakeOutput();
builder.AddQuantizeLinearNode<OutputType>(split_output_1, .003f, q_zp, q_split_output_1,
float output_scale = use_diff_output_scale ? input_scale + 0.001f : input_scale;
builder.AddQuantizeLinearNode<OutputType>(split_output_1, output_scale, q_zp, q_split_output_1,
use_contrib_qdq); // Model input (node_token_1)
builder.AddQuantizeLinearNode<OutputType>(split_output_2, .003f, q_zp, q_split_output_2,
builder.AddQuantizeLinearNode<OutputType>(split_output_2, output_scale, q_zp, q_split_output_2,
use_contrib_qdq); // Model input (node_token_2)
builder.AddQuantizeLinearNode<OutputType>(split_output_3, .003f, q_zp, q_split_output_3,
builder.AddQuantizeLinearNode<OutputType>(split_output_3, output_scale, q_zp, q_split_output_3,
use_contrib_qdq);
};
}

44 changes: 34 additions & 10 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -1210,27 +1210,51 @@ TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) {
// Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split.
template <typename InputQType, typename OutputQType>
static void RunDropSplitQDQTestCase(const std::vector<int64_t>& input_shape, int64_t axis,
bool use_contrib_qdq = false) {
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
bool all_same_quant_params, bool use_contrib_qdq = false) {
auto check_graph = [all_same_quant_params, use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
int expected_q_ops = all_same_quant_params ? 0 : 3;
int expected_dq_ops = all_same_quant_params ? 0 : 1;
EXPECT_EQ(op_to_count["Split"], 1);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_q_ops);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_dq_ops);
};
TransformerTester(BuildQDQSplitTestCase<InputQType, OutputQType>(input_shape, axis, use_contrib_qdq),
TransformerTester(BuildQDQSplitTestCase<InputQType, OutputQType>(input_shape, axis, !all_same_quant_params,
use_contrib_qdq),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
{12, 18, 19});
{12, 13, 18, 19}); // Test different ways to specify the split in each opset:
// 12 - split into equal parts without explicit 'split' attribute
// 13 - use optional 'split' input to split into 3 parts
// 18 - use 'num_outputs' attribute to split into 3 parts
// 19 - use 'num_outputs' attribute to split into 3 parts
}

// Test that DQ -> Split -> Q (many) is replaced with just Split for various quantization types.
TEST(QDQTransformerTests, Split) {
RunDropSplitQDQTestCase<int8_t, int8_t>({6, 18, 54}, 0);
RunDropSplitQDQTestCase<int8_t, int8_t>({6, 18, 54}, 0, true); // Use com.microsoft int8 QDQ ops
RunDropSplitQDQTestCase<int16_t, int16_t>({6, 18, 54}, 0, true); // Use com.microsoft int16 QDQ ops
RunDropSplitQDQTestCase<uint16_t, uint16_t>({6, 18, 54}, 0, true); // Use com.microsoft uint16 QDQ ops
// Test cases that drop Q/DQ ops from DQ -> Split -> Q (many).
// This happens when all the Q/DQ ops have equal and constant quantization parameters.
{
constexpr bool ALL_SAME_QUANT_PARAMS = true;
constexpr bool USE_CONTRIB_QDQ_OPS = true;
RunDropSplitQDQTestCase<int8_t, int8_t>({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS);
RunDropSplitQDQTestCase<int8_t, int8_t>({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS);
RunDropSplitQDQTestCase<int16_t, int16_t>({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS);
RunDropSplitQDQTestCase<uint16_t, uint16_t>({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS);
}

// Test cases that DO NOT drop Q/DQ ops from DQ -> Split -> Q (many)
// This happens when the Q/DQ ops do not have equal and constant quantization parameters.
{
constexpr bool DIFF_QUANT_PARAMS = false;
constexpr bool USE_CONTRIB_QDQ_OPS = true;
RunDropSplitQDQTestCase<int8_t, int8_t>({6, 18, 54}, 0, DIFF_QUANT_PARAMS);
RunDropSplitQDQTestCase<int8_t, int8_t>({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS);
RunDropSplitQDQTestCase<int16_t, int16_t>({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS);
RunDropSplitQDQTestCase<uint16_t, uint16_t>({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS);
}
}

// Because Split isn't one of the supported ops, this will stay the same