Quantize Bias for Conv/Gemm on Quantized Model (#22889)
Some quantized models leave the bias of Conv/Gemm nodes in float instead of quantizing it. This PR creates a
sub-graph that quantizes the bias for such Conv/Gemm nodes with scale = scale_input_0 * scale_input_1 and
zero_point = 0. We only do this when the bias is an initializer, so that ConstantFolding will fold the sub-graph
into a real quantized int32 bias initializer during the next round of graph optimization.
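
For reference, the per-element arithmetic performed by the inserted sub-graph (Mul -> Div -> Round -> Cast, feeding a
DequantizeLinear node) can be sketched in plain C++ as follows; the function name QuantizeBias and its parameters are
illustrative only and are not part of this change.

#include <cmath>
#include <cstdint>
#include <vector>

// Sketch of what the inserted ONNX sub-graph computes for each bias element.
// The result is consumed by DequantizeLinear with the same scale and zero_point = 0.
std::vector<int32_t> QuantizeBias(const std::vector<float>& fp_bias, float input_scale,
                                  const std::vector<float>& weight_scales) {
  std::vector<int32_t> q_bias(fp_bias.size());
  for (size_t i = 0; i < fp_bias.size(); ++i) {
    // Per-channel weight scale when one value is given per output channel, otherwise per-tensor.
    const float weight_scale = weight_scales.size() > 1 ? weight_scales[i] : weight_scales[0];
    const float bias_scale = input_scale * weight_scale;                    // Mul
    q_bias[i] = static_cast<int32_t>(std::round(fp_bias[i] / bias_scale));  // Div -> Round -> Cast
  }
  return q_bias;
}

Because every input to this computation is a constant initializer, ConstantFolding can evaluate the equivalent ONNX
nodes at optimization time, leaving only the int32 bias initializer and its DequantizeLinear node in the graph.
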
centwang authored Nov 28, 2024
1 parent 42ecb05 commit 1128882
Showing 4 changed files with 269 additions and 0 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -63,6 +63,7 @@
#ifdef MLAS_TARGET_AMD64_IX86
#include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h"
#endif
#include "core/optimizer/qdq_transformer/bias_quantization.h"
#include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
#include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h"
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
@@ -243,6 +244,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(

if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());
transformers.emplace_back(std::make_unique<BiasQuantization>());

// EnsureUniqueDQForNodeUnit is actually a required graph transformation. The unique DQ per QDQ node unit input
// condition that it ensures is important for the partitioning that happens after Level1 optimizers are run.
149 changes: 149 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc
@@ -0,0 +1,149 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/optimizer/qdq_transformer/bias_quantization.h"

#include "core/common/common.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/qdq_transformer/qdq_util.h"

namespace onnxruntime {

Status BiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const GraphViewer graph_viewer{graph};
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : node_indices) {
auto* node_ptr = graph.GetNode(node_idx);
if (!node_ptr) {
continue;
}

Node& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

const auto& input_defs = node.InputDefs();

// Only process Conv/Gemm nodes whose bias is an initializer.
if ((node.OpType() != "Conv" && node.OpType() != "Gemm") || input_defs.size() < 3 || !input_defs[2]->Exists() ||
!graph_utils::IsInitializer(graph, input_defs[2]->Name(), true)) {
continue;
}

auto bias_shape = input_defs[2]->Shape();
if (!bias_shape || bias_shape->dim_size() != 1) {
continue;
}
int64_t bias_size = bias_shape->dim(0).dim_value();

// input_0 and input_1 are outputs of DequantizeLinear nodes.
const Node* parent_node_0 = graph.GetProducerNode(input_defs[0]->Name());
const Node* parent_node_1 = graph.GetProducerNode(input_defs[1]->Name());
if (!parent_node_0 || !parent_node_1 || parent_node_0->OpType() != QDQ::DQOpName ||
parent_node_1->OpType() != QDQ::DQOpName) {
continue;
}

Node& dq_0 = *graph.GetNode(parent_node_0->Index());
Node& dq_1 = *graph.GetNode(parent_node_1->Index());

// Currently we require input_0 to have a per-tensor scale.
if (!optimizer_utils::IsScalar(*dq_0.InputDefs()[1])) {
continue;
}

// For input_1, the scale is either per-tensor or per-channel on a specific axis (0 for Conv; 1 for Gemm, or 0 when transB is set).
bool is_per_tensor_scale = true;
if (!optimizer_utils::IsScalar(*dq_1.InputDefs()[1])) {
is_per_tensor_scale = false;
auto weight_scale_shape = dq_1.InputDefs()[1]->Shape();
if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() ||
weight_scale_shape->dim(0).dim_value() != bias_size) {
continue;
}

const auto& dq_attrs = dq_1.GetAttributes();
if (dq_attrs.find("block_size") != dq_attrs.end()) {
continue;
}

int64_t axis = 1;
if (dq_attrs.find("axis") != dq_attrs.end()) {
axis = dq_attrs.at("axis").i();
}

int64_t expected_axis = 0;
if (node.OpType() == "Gemm") {
int64_t transB = 0;
if (const auto& attr = node.GetAttributes().find("transB"); attr != node.GetAttributes().end()) {
transB = attr->second.i();
}
expected_axis = transB == 0 ? 1 : 0;
}

if (axis != expected_axis) {
continue;
}
}

// Bias is quantized to int32.
ONNX_NAMESPACE::TypeProto int32_type_proto;
int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
auto scale_type = dq_1.InputDefs()[1]->TypeAsProto();  // May be a per-tensor (scalar) or per-channel (1D) scale.
ONNX_NAMESPACE::TypeProto bias_dq_type;
bias_dq_type.mutable_tensor_type()->set_elem_type(scale_type->tensor_type().elem_type());
bias_dq_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(bias_size);

// scale = input_scale_0 * input_scale_1.
NodeArg& scale_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), scale_type);
Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Scale node",
{dq_0.MutableInputDefs()[1], dq_1.MutableInputDefs()[1]}, {&scale_node_arg}, nullptr,
node.Domain());

// fp_bias / scale.
NodeArg& bias_div_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), &bias_dq_type);
Node& div_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node",
{node.MutableInputDefs()[2], &scale_node_arg}, {&bias_div_node_arg}, nullptr, node.Domain());
graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1);

// Round(fp_bias / scale).
NodeArg& bias_div_round_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), &bias_dq_type);
Node& round_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node",
{&bias_div_node_arg}, {&bias_div_round_node_arg}, nullptr, node.Domain());
graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0);

// Cast(round(fp_bias / scale)) to int32.
NodeArg& bias_int32_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &int32_type_proto);
Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias int32 node",
{&bias_div_round_node_arg}, {&bias_int32_node_arg}, nullptr, node.Domain());
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0);

// Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0.
NodeArg& bias_dq_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), &bias_dq_type);
Node& dq_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node",
{&bias_int32_node_arg, &scale_node_arg}, {&bias_dq_node_arg}, nullptr, node.Domain());
if (!is_per_tensor_scale) {
dq_node.AddAttribute("axis", static_cast<int64_t>(0));
}

graph.AddEdge(cast_node.Index(), dq_node.Index(), 0, 0);
graph.AddEdge(mul_node.Index(), dq_node.Index(), 0, 1);
node.MutableInputDefs()[2] = &bias_dq_node_arg;
graph.AddEdge(dq_node.Index(), node.Index(), 0, 2);

modified = true;
}

return Status::OK();
}

} // namespace onnxruntime
27 changes: 27 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

/**
* @class BiasQuantization
*
* Some quantized models do not have Gemm/Conv's bias quantized. This optimization adds a subgraph to quantize the bias
* with scale = scale_input_0 * scale_input_1 and zero_point = 0.
*
 * Normally the ConstantFolding optimizer will fold the inserted subgraph, turning the float bias initializer into an
 * int32 initializer that is consumed by a DequantizeLinear node.
*/
class BiasQuantization : public GraphTransformer {
public:
BiasQuantization() noexcept : GraphTransformer("BiasQuantization") {}

private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
91 changes: 91 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -11,6 +11,7 @@
#include "core/graph/onnx_protobuf.h"
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/double_qdq_pairs_remover.h"
#include "core/optimizer/qdq_transformer/bias_quantization.h"
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
@@ -4846,5 +4847,95 @@ TEST(QDQTransformerTests, DropDQSelectorWithDQProducingGraphOutput) {
}
#endif // !defined(DISABLE_CONTRIB_OPS)

TEST(QDQTransformerTests, BiasQuantization_Conv) {
auto test_case = [](bool use_contrib_qdq) {
auto build_test_case = [&](ModelTestBuilder& builder) {
NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 24, 128, 128}, std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
NodeArg* weight_arg = builder.MakeInitializer<uint8_t>({24, 1, 3, 3}, std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
NodeArg* bias_arg = builder.MakeInitializer<float>({24}, -0.1f, 0.1f);
NodeArg* input_dq_arg = builder.MakeIntermediate();
NodeArg* weight_dq_arg = builder.MakeIntermediate();
NodeArg* conv_dq_arg = builder.MakeIntermediate();
NodeArg* output_arg = builder.MakeOutput();

builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.07f, static_cast<uint8_t>(0), input_dq_arg,
use_contrib_qdq);
auto& weight_dq_node = builder.AddDequantizeLinearNode<uint8_t>(weight_arg, std::vector<float>(24, 0.05f),
std::vector<uint8_t>(24, static_cast<uint8_t>(0)),
weight_dq_arg, nullptr, use_contrib_qdq);
weight_dq_node.AddAttribute("axis", static_cast<int64_t>(0));
auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_dq_arg, bias_arg}, {conv_dq_arg});
conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
conv_node.AddAttribute("group", static_cast<int64_t>(24));
conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
builder.AddQuantizeLinearNode<uint8_t>(conv_dq_arg, 0.14f, static_cast<uint8_t>(127), output_arg,
use_contrib_qdq);
};

auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
EXPECT_EQ(op_to_count["QLinearConv"], 1);
};

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
};

test_case(false);
#if !defined(DISABLE_CONTRIB_OPS)
test_case(true);
#endif
}

TEST(QDQTransformerTests, BiasQuantization_Gemm) {
auto test_case = [](bool use_contrib_qdq) {
auto build_test_case = [&](ModelTestBuilder& builder) {
NodeArg* input_arg =
builder.MakeInput<uint8_t>({1, 32}, std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
NodeArg* weight_arg = builder.MakeInitializer<uint8_t>({16, 32}, std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
NodeArg* bias_arg = builder.MakeInitializer<float>({16}, -0.1f, 0.1f);
NodeArg* input_dq_arg = builder.MakeIntermediate();
NodeArg* weight_dq_arg = builder.MakeIntermediate();
NodeArg* gemm_dq_arg = builder.MakeIntermediate();
NodeArg* output_arg = builder.MakeOutput();

builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.001f, static_cast<uint8_t>(0), input_dq_arg,
use_contrib_qdq);
builder.AddDequantizeLinearNode<uint8_t>(weight_arg, 0.26f, static_cast<uint8_t>(0), weight_dq_arg,
use_contrib_qdq);
auto& gemm_node = builder.AddNode("Gemm", {input_dq_arg, weight_dq_arg, bias_arg}, {gemm_dq_arg});
gemm_node.AddAttribute("transB", static_cast<int64_t>(1));
builder.AddQuantizeLinearNode<uint8_t>(gemm_dq_arg, 0.144f, static_cast<uint8_t>(69), output_arg,
use_contrib_qdq);
};

auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
};

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
};

test_case(false);
#if !defined(DISABLE_CONTRIB_OPS)
test_case(true);
#endif
}

} // namespace test
} // namespace onnxruntime
