From 1128882bfd2a97c20f8a2a5ddb26cb0d42d9ebba Mon Sep 17 00:00:00 2001
From: Vincent Wang
Date: Thu, 28 Nov 2024 10:10:24 +0800
Subject: [PATCH] Quantize Bias for Conv/Gemm on Quantized Model (#22889)

Some quantized models don't have Conv/Gemm node's bias quantized but still leave them in float.
This PR is to create a sub-graph to quantize the bias for Conv/Gemm nodes with
scale = scale_input_0 * scale_input_1 and zp = 0. We only do this for bias initializer so that
ConstantFolding will fold the sub-graph to a real quantized int32 bias initializer during the
graph optimization next round.
---
 .../core/optimizer/graph_transformer_utils.cc |   2 +
 .../qdq_transformer/bias_quantization.cc      | 149 ++++++++++++++++++
 .../qdq_transformer/bias_quantization.h       |  27 ++++
 .../test/optimizer/qdq_transformer_test.cc    |  91 +++++++++++
 4 files changed, 269 insertions(+)
 create mode 100644 onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc
 create mode 100644 onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h

diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index f769d31092d19..2f2524420dc44 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -63,6 +63,7 @@
 #ifdef MLAS_TARGET_AMD64_IX86
 #include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h"
 #endif
+#include "core/optimizer/qdq_transformer/bias_quantization.h"
 #include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
 #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h"
 #include "core/optimizer/qdq_transformer/qdq_propagation.h"
@@ -243,6 +244,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
 
       if (!disable_quant_qdq) {
         transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());
+        transformers.emplace_back(std::make_unique<BiasQuantization>());
 
         // EnsureUniqueDQForNodeUnit is actually a required graph transformation. The unique DQ per QDQ node unit input
         // condition that it ensures is important for the partitioning that happens after Level1 optimizers are run.
diff --git a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc
new file mode 100644
index 0000000000000..9e9665e14ede4
--- /dev/null
+++ b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc
@@ -0,0 +1,149 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/qdq_transformer/bias_quantization.h"
+
+#include "core/common/common.h"
+#include "core/graph/graph_utils.h"
+#include "core/graph/graph_viewer.h"
+#include "core/optimizer/utils.h"
+#include "core/optimizer/qdq_transformer/qdq_util.h"
+
+namespace onnxruntime {
+
+Status BiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  const GraphViewer graph_viewer{graph};
+  const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
+  for (const auto node_idx : node_indices) {
+    auto* node_ptr = graph.GetNode(node_idx);
+    if (!node_ptr) {
+      continue;
+    }
+
+    Node& node = *node_ptr;
+    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+    const auto& input_defs = node.InputDefs();
+
+    // It's Conv/Gemm node with an initializer bias.
+    if ((node.OpType() != "Conv" && node.OpType() != "Gemm") || input_defs.size() < 3 || !input_defs[2]->Exists() ||
+        !graph_utils::IsInitializer(graph, input_defs[2]->Name(), true)) {
+      continue;
+    }
+
+    auto bias_shape = input_defs[2]->Shape();
+    if (!bias_shape || bias_shape->dim_size() != 1) {
+      continue;
+    }
+    int64_t bias_size = bias_shape->dim(0).dim_value();
+
+    // input_0 and input_1 are outputs of DequantizeLinear nodes.
+    const Node* parent_node_0 = graph.GetProducerNode(input_defs[0]->Name());
+    const Node* parent_node_1 = graph.GetProducerNode(input_defs[1]->Name());
+    if (!parent_node_0 || !parent_node_1 || parent_node_0->OpType() != QDQ::DQOpName ||
+        parent_node_1->OpType() != QDQ::DQOpName) {
+      continue;
+    }
+
+    Node& dq_0 = *graph.GetNode(parent_node_0->Index());
+    Node& dq_1 = *graph.GetNode(parent_node_1->Index());
+
+    // Currently we require input_0 is per-tensor scale.
+    if (!optimizer_utils::IsScalar(*dq_0.InputDefs()[1])) {
+      continue;
+    }
+
+    // For input_1, it's either per-tensor scale or per-channel scale on specific axis (0 for Conv and 1 for Gemm).
+    bool is_per_tensor_scale = true;
+    if (!optimizer_utils::IsScalar(*dq_1.InputDefs()[1])) {
+      is_per_tensor_scale = false;
+      auto weight_scale_shape = dq_1.InputDefs()[1]->Shape();
+      if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() ||
+          weight_scale_shape->dim(0).dim_value() != bias_size) {
+        continue;
+      }
+
+      const auto& dq_attrs = dq_1.GetAttributes();
+      if (dq_attrs.find("block_size") != dq_attrs.end()) {
+        continue;
+      }
+
+      int64_t axis = 1;
+      if (dq_attrs.find("axis") != dq_attrs.end()) {
+        axis = dq_attrs.at("axis").i();
+      }
+
+      int64_t expected_axis = 0;
+      if (node.OpType() == "Gemm") {
+        int64_t transB = 0;
+        if (const auto& attr = node.GetAttributes().find("transB"); attr != node.GetAttributes().end()) {
+          transB = attr->second.i();
+        }
+        expected_axis = transB == 0 ? 1 : 0;
+      }
+
+      if (axis != expected_axis) {
+        continue;
+      }
+    }
+
+    // Bias is quantized to int32.
+    ONNX_NAMESPACE::TypeProto int32_type_proto;
+    int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
+    auto scale_type = dq_1.InputDefs()[1]->TypeAsProto();  // Maybe per-tensor (scalar) or per-channel (1D) scale.
+    ONNX_NAMESPACE::TypeProto bias_dq_type;
+    bias_dq_type.mutable_tensor_type()->set_elem_type(scale_type->tensor_type().elem_type());
+    bias_dq_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(bias_size);
+
+    // scale = input_scale_0 * input_scale_1.
+    NodeArg& scale_node_arg =
+        graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), scale_type);
+    Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Scale node",
+                                   {dq_0.MutableInputDefs()[1], dq_1.MutableInputDefs()[1]}, {&scale_node_arg}, nullptr,
+                                   node.Domain());
+
+    // fp_bias / scale.
+    NodeArg& bias_div_node_arg =
+        graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), &bias_dq_type);
+    Node& div_node =
+        graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node",
+                      {node.MutableInputDefs()[2], &scale_node_arg}, {&bias_div_node_arg}, nullptr, node.Domain());
+    graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1);
+
+    // Round(fp_bias / scale).
+    NodeArg& bias_div_round_node_arg =
+        graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), &bias_dq_type);
+    Node& round_node =
+        graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node",
+                      {&bias_div_node_arg}, {&bias_div_round_node_arg}, nullptr, node.Domain());
+    graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0);
+
+    // Cast(round(fp_bias / scale)) to int32.
+    NodeArg& bias_int32_node_arg =
+        graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &int32_type_proto);
+    Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias int32 node",
+                                    {&bias_div_round_node_arg}, {&bias_int32_node_arg}, nullptr, node.Domain());
+    cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
+    graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0);
+
+    // Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0.
+    NodeArg& bias_dq_node_arg =
+        graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), &bias_dq_type);
+    Node& dq_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node",
+                                  {&bias_int32_node_arg, &scale_node_arg}, {&bias_dq_node_arg}, nullptr, node.Domain());
+    if (!is_per_tensor_scale) {
+      dq_node.AddAttribute("axis", static_cast<int64_t>(0));
+    }
+
+    graph.AddEdge(cast_node.Index(), dq_node.Index(), 0, 0);
+    graph.AddEdge(mul_node.Index(), dq_node.Index(), 0, 1);
+    node.MutableInputDefs()[2] = &bias_dq_node_arg;
+    graph.AddEdge(dq_node.Index(), node.Index(), 0, 2);
+
+    modified = true;
+  }
+
+  return Status::OK();
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h
new file mode 100644
index 0000000000000..0297def260fd9
--- /dev/null
+++ b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+ * @class BiasQuantization
+ *
+ * Some quantized models do not have Gemm/Conv's bias quantized. This optimization adds a subgraph to quantize the bias
+ * with scale = scale_input_0 * scale_input_1 and zero_point = 0.
+ *
+ * Normally the ConstantFolding optimizer would fold the bias initializer into an int32_t initializer, which is consumed
+ * by a DequantizeLinear node.
+ */
+class BiasQuantization : public GraphTransformer {
+ public:
+  BiasQuantization() noexcept : GraphTransformer("BiasQuantization") {}
+
+ private:
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index d07977d4b97b8..cfee4a83a4292 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -11,6 +11,7 @@
 #include "core/graph/onnx_protobuf.h"
 #include "core/mlas/inc/mlas.h"
 #include "core/optimizer/double_qdq_pairs_remover.h"
+#include "core/optimizer/qdq_transformer/bias_quantization.h"
 #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
 #include "core/optimizer/qdq_transformer/qdq_propagation.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
@@ -4846,5 +4847,95 @@ TEST(QDQTransformerTests, DropDQSelectorWithDQProducingGraphOutput) {
 }
 #endif  // !defined(DISABLE_CONTRIB_OPS)
 
+TEST(QDQTransformerTests, BiasQuantization_Conv) {
+  auto test_case = [](bool use_contrib_qdq) {
+    auto build_test_case = [&](ModelTestBuilder& builder) {
+      NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 24, 128, 128}, std::numeric_limits<uint8_t>::min(),
+                                                      std::numeric_limits<uint8_t>::max());
+      NodeArg* weight_arg = builder.MakeInitializer<uint8_t>({24, 1, 3, 3}, std::numeric_limits<uint8_t>::min(),
+                                                             std::numeric_limits<uint8_t>::max());
+      NodeArg* bias_arg = builder.MakeInitializer<float>({24}, -0.1f, 0.1f);
+      NodeArg* input_dq_arg = builder.MakeIntermediate();
+      NodeArg* weight_dq_arg = builder.MakeIntermediate();
+      NodeArg* conv_dq_arg = builder.MakeIntermediate();
+      NodeArg* output_arg = builder.MakeOutput();
+
+      builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.07f, static_cast<uint8_t>(0), input_dq_arg,
+                                               use_contrib_qdq);
+      auto& weight_dq_node = builder.AddDequantizeLinearNode<uint8_t>(weight_arg, std::vector<float>(24, 0.05f),
+                                                                      std::vector<uint8_t>(24, static_cast<uint8_t>(0)),
+                                                                      weight_dq_arg, nullptr, use_contrib_qdq);
+      weight_dq_node.AddAttribute("axis", static_cast<int64_t>(0));
+      auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_dq_arg, bias_arg}, {conv_dq_arg});
+      conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
+      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
+      conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
+      conv_node.AddAttribute("group", static_cast<int64_t>(24));
+      conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
+      builder.AddQuantizeLinearNode<uint8_t>(conv_dq_arg, 0.14f, static_cast<uint8_t>(127), output_arg,
+                                             use_contrib_qdq);
+    };
+
+    auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
+      auto op_to_count = CountOpsInGraph(session.GetGraph());
+      const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
+      EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
+      EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
+      EXPECT_EQ(op_to_count["QLinearConv"], 1);
+    };
+
+    TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
+
+    TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
+  };
+
+  test_case(false);
+#if !defined(DISABLE_CONTRIB_OPS)
+  test_case(true);
+#endif
+}
+
+TEST(QDQTransformerTests, BiasQuantization_Gemm) {
+  auto test_case = [](bool use_contrib_qdq) {
+    auto build_test_case = [&](ModelTestBuilder& builder) {
+      NodeArg* input_arg =
+          builder.MakeInput<uint8_t>({1, 32}, std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
+      NodeArg* weight_arg = builder.MakeInitializer<uint8_t>({16, 32}, std::numeric_limits<uint8_t>::min(),
+                                                             std::numeric_limits<uint8_t>::max());
+      NodeArg* bias_arg = builder.MakeInitializer<float>({16}, -0.1f, 0.1f);
+      NodeArg* input_dq_arg = builder.MakeIntermediate();
+      NodeArg* weight_dq_arg = builder.MakeIntermediate();
+      NodeArg* gemm_dq_arg = builder.MakeIntermediate();
+      NodeArg* output_arg = builder.MakeOutput();
+
+      builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.001f, static_cast<uint8_t>(0), input_dq_arg,
+                                               use_contrib_qdq);
+      builder.AddDequantizeLinearNode<uint8_t>(weight_arg, 0.26f, static_cast<uint8_t>(0), weight_dq_arg,
+                                               use_contrib_qdq);
+      auto& gemm_node = builder.AddNode("Gemm", {input_dq_arg, weight_dq_arg, bias_arg}, {gemm_dq_arg});
+      gemm_node.AddAttribute("transB", static_cast<int64_t>(1));
+      builder.AddQuantizeLinearNode<uint8_t>(gemm_dq_arg, 0.144f, static_cast<uint8_t>(69), output_arg,
+                                             use_contrib_qdq);
+    };
+
+    auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
+      auto op_to_count = CountOpsInGraph(session.GetGraph());
+      const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
+      EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
+      EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
+      EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
+    };
+
+    TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
+
+    TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
+  };
+
+  test_case(false);
+#if !defined(DISABLE_CONTRIB_OPS)
+  test_case(true);
+#endif
+}
+
 }  // namespace test
 }  // namespace onnxruntime