Handle all float types in IsQDQPairSupported (#19085)
### Description
This makes detection of identical QDQ scales work with float16 and
bfloat16 rather than failing.
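
For context on why identical scales matter: a Quantize→Dequantize pair that shares a scale and zero point is an identity up to rounding, which is what lets a transformer such as the DoubleQDQPairsRemover exercised in the tests below drop the pair. A minimal standalone sketch of that round trip (hypothetical values, plain float, saturation omitted):

#include <cmath>
#include <cstdio>

int main() {
  const float scale = 0.02f;  // hypothetical shared Q/DQ scale
  const int zero_point = 12;  // hypothetical shared Q/DQ zero point
  const float x = 0.4999f;

  // QuantizeLinear:   q = saturate(round(x / scale) + zero_point)
  const int q = static_cast<int>(std::lround(x / scale)) + zero_point;
  // DequantizeLinear: y = (q - zero_point) * scale
  const float y = (q - zero_point) * scale;

  // |x - y| <= scale / 2, so with identical parameters the pair is a
  // near-identity and is safe for the optimizer to remove.
  std::printf("x = %f, round-tripped y = %f\n", x, y);
  return 0;
}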


### Motivation and Context
This addresses failures in customer models.
jeffbloo authored Jan 11, 2024
1 parent 8a0a972 commit 08cf4fb
Showing 4 changed files with 53 additions and 31 deletions.
23 changes: 20 additions & 3 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
@@ -54,9 +54,26 @@ bool IsQDQPairSupported(
   Initializer dq_zp(*dq_zp_tensor_proto, model_path);
   Initializer dq_scale(*dq_scale_tensor_proto, model_path);
 
-  return q_zp.data_type() == dq_zp.data_type() &&
-         SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan()) &&
-         *q_scale.data<float>() == *dq_scale.data<float>();
+  if (q_zp.data_type() != dq_zp.data_type() ||
+      q_scale.data_type() != dq_scale.data_type() ||
+      !SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan())) {
+    return false;
+  }
+
+  switch (q_scale.data_type()) {
+    case ONNX_NAMESPACE::TensorProto::FLOAT:
+      return *q_scale.data<float>() == *dq_scale.data<float>();
+
+    case ONNX_NAMESPACE::TensorProto::FLOAT16:
+      return *q_scale.data<MLFloat16>() == *dq_scale.data<MLFloat16>();
+
+    case ONNX_NAMESPACE::TensorProto::BFLOAT16:
+      return *q_scale.data<BFloat16>() == *dq_scale.data<BFloat16>();
+
+    default:
+      assert(false);
+      return false;
+  }
 }
 
 bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) {
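A note on the shape of the fix (a reading of the diff, not the commit message): once the types match, zero points are compared bytewise via SpanEq, while scales are compared through the element type's operator==, read with the matching MLFloat16/BFloat16 accessor rather than data<float>() (the read that previously failed for 16-bit float scales). One consequence of value comparison over a bytewise one for scales: NaN never compares equal to itself. A standalone illustration with plain float:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Compare two floats by bit pattern, the way SpanEq compares raw byte spans.
static bool BitwiseEqual(float a, float b) {
  uint32_t ua, ub;
  std::memcpy(&ua, &a, sizeof ua);
  std::memcpy(&ub, &b, sizeof ub);
  return ua == ub;
}

int main() {
  const float nan = std::nanf("");
  // Value comparison follows IEEE 754: NaN != NaN. Bitwise comparison of
  // the same payload reports "equal".
  std::printf("value: %d, bitwise: %d\n", static_cast<int>(nan == nan),
              static_cast<int>(BitwiseEqual(nan, nan)));  // value: 0, bitwise: 1
  return 0;
}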
61 changes: 33 additions & 28 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4602,38 +4602,43 @@ TEST_F(GraphTransformationTests, GeluApproximation_SessionOptionConfig) {
 }
 
 // Test DoubleQDQPairsRemover to remove unnecessary DQ->Q nodes in the middle
-TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ) {
-  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "qdq_optimization/dup_qdq.onnx";
-  std::shared_ptr<Model> p_model;
-  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
-  Graph& graph = p_model->MainGraph();
+TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ_Float16) {
+  auto RunTest = [this](const ORTCHAR_T* model_uri) {
+    std::shared_ptr<Model> p_model;
+    ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+    Graph& graph = p_model->MainGraph();
 
-  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<DoubleQDQPairsRemover>(), TransformerLevel::Level1));
-  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+    onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+    ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<DoubleQDQPairsRemover>(), TransformerLevel::Level1));
+    ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
-  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
-  EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
-  EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
+    EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
 
-  std::string dq_scale_name_before_reshape_node;
-  std::string zp_name_before_reshape_node;
-  std::string dq_scale_name_after_reshape_node;
-  std::string zp_name_after_reshape_node;
-  for (auto& node : graph.Nodes()) {
-    if (node.Name() == "dq_2") {
-      dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
-      zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
-    }
-    if (node.Name() == "q_3") {
-      dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
-      zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+    std::string dq_scale_name_before_reshape_node;
+    std::string zp_name_before_reshape_node;
+    std::string dq_scale_name_after_reshape_node;
+    std::string zp_name_after_reshape_node;
+    for (auto& node : graph.Nodes()) {
+      if (node.Name() == "dq_2") {
+        dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
+        zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+      }
+      if (node.Name() == "q_3") {
+        dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
+        zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+      }
     }
-  }
-  EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
-  EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
-  EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
-  EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
+    EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
+    EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
+    EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
+    EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
+  };
+
+  RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq.onnx");
+  RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_float16.onnx");
+  RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_bfloat16.onnx");
 }
 
 // Test Gelu -> FastGelu
The two remaining changed files are binary test models (not shown): qdq_optimization/dup_qdq_float16.onnx and qdq_optimization/dup_qdq_bfloat16.onnx, the new models referenced by the test above.
