Skip to content

Commit

Permalink
Fix the duplicated QDQ attributes setup issue (#18039)
Browse files Browse the repository at this point in the history
### Description
The copied QDQ node should have exactly the same attributes as the
original QDQ node. Otherwise, it might cause errors when the original
node has attributes that use non default values (such as axis != 1
case).

An example user case is like:
A DequantizeLinear node has more than 1 consumer in the graph, and its
attributes axis is 0.

### Motivation and Context
I see the errors like 
#16188 
and this fix could solve the issue.
  • Loading branch information
vera121 authored Jan 11, 2024
1 parent fd6bab4 commit 5678317
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Status DuplicateDQForOutputEdge(const graph_utils::GraphEdge& original_dq_output
MakeString("Added by ", kTransformerName),
dq_inputs,
{&new_dq_output_nodearg},
nullptr, // attributes
&original_dq_node.GetAttributes(),
original_dq_node.Domain());

// set up edges
Expand Down
40 changes: 40 additions & 0 deletions onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,44 @@ TEST(EnsureUniqueDQForNodeUnitTests, QDQWithMultiConsumerDQNodes) {
EXPECT_EQ(OpCount(op_count_before, "DequantizeLinear") + 4, OpCount(op_count_after, "DequantizeLinear"));
}

TEST(EnsureUniqueDQForNodeUnitTests, QDQWithMultiConsumerDQNodesPreservingAttributes) {
constexpr auto model_uri = ORT_TSTR("testdata/qdq_with_multi_consumer_q_dq_axis.onnx");

SessionOptions session_options{};
// test interaction with level 1 transformers
session_options.graph_optimization_level = TransformerLevel::Level1;

InferenceSessionWrapper session{session_options, GetEnvironment()};

ASSERT_STATUS_OK(session.Load(model_uri));

const auto op_count_before = CountOpsInGraph(session.GetGraph());

ASSERT_STATUS_OK(session.Initialize());

const auto op_count_after = CountOpsInGraph(session.GetGraph());

EXPECT_EQ(OpCount(op_count_before, "DequantizeLinear") + 8, OpCount(op_count_after, "DequantizeLinear"));

int64_t given_axis = 0; // all the following 4 DQ nodes and their duplicated one should have axis = 0
std::string axis_dq_name0 = "Convolution28_Output_0/fusedmuladd_B/DequantizeLinear";
std::string axis_dq_name1 = "Parameter5/DequantizeLinear";
std::string axis_dq_name2 = "Convolution110_Output_0/fusedmuladd_B/DequantizeLinear";
std::string axis_dq_name3 = "Parameter87/DequantizeLinear";
for (const auto& node : session.GetGraph().Nodes()) {
if (node.OpType() == "DequantizeLinear") {
if (node.Name().find(axis_dq_name0) == 0 ||
node.Name().find(axis_dq_name1) == 0 ||
node.Name().find(axis_dq_name2) == 0 ||
node.Name().find(axis_dq_name3) == 0) {
const auto& attrs = node.GetAttributes();
ASSERT_TRUE(attrs.find("axis") != attrs.end());
const auto& axis_attr = attrs.at("axis");
int64_t axis = axis_attr.i();
EXPECT_EQ(axis, given_axis);
}
}
}
}

} // namespace onnxruntime::test
Binary file not shown.

0 comments on commit 5678317

Please sign in to comment.