Skip to content

Commit

Permalink
Fix Pad's quantization (microsoft#17807)
Browse files Browse the repository at this point in the history
Fix microsoft#17760. The upstream exporter emits an empty string as Pad's 3rd input,
and the quantization tool 1) treats that as a valid tensor name and
2) adds corresponding invalid quantization nodes. This PR adds a
condition check so that the quantization tool works on such models.
  • Loading branch information
wschin authored and kleiti committed Mar 22, 2024
1 parent 0f23af5 commit 7f3c99e
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 2 deletions.
1 change: 1 addition & 0 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
:return: List of newly created nodes in NodeProto format.
"""
input_name = node.input[input_index]
assert input_name != "", "Cannot access undefined variable in graph."
output_name = input_name + TENSOR_NAME_QUANT_SUFFIX
ql_node_name = input_name + "_QuantizeLinear"

Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/python/tools/quantization/operators/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def quantize(self):
kwargs.update(kv)

if "mode" not in kwargs or kwargs["mode"] == b"constant":
if len(node.input) > 2: # There is 3rd input 'constant_value'
if len(node.input) > 2 and node.input[2] != "": # There is 3rd input 'constant_value'
zp_tensor = self.quantizer.model.get_initializer(quantized_input_value.zp_name)
scale_tensor = self.quantizer.model.get_initializer(quantized_input_value.scale_name)
if zp_tensor is None or scale_tensor is None:
Expand Down Expand Up @@ -72,7 +72,17 @@ def quantize(self):
self.quantizer.new_nodes.extend(pad_value_qnodes)
node.input[2] = pad_value_qnodes[0].output[0]
else:
node.input.extend([quantized_input_value.zp_name]) # pad zero_point for original zero
# In quantized format, the `zero` before quantization is mapped
# to quantized_input_value.zp_name. Thus, padding 0 to
# original tensor should become padding zero point to quantized
# tensor.
if len(node.input) == 2:
# Feed quantization's zero point to padding node.
node.input.append(quantized_input_value.zp_name)
else:
# Assign quantization's zero point to padding node.
assert node.input[2] == ""
node.input[2] = quantized_input_value.zp_name

# Create an entry for output quantized value
quantized_output_value = QuantizedValue(
Expand Down
118 changes: 118 additions & 0 deletions onnxruntime/test/python/quantization/test_op_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# license information.
# --------------------------------------------------------------------------

import itertools
import unittest

import numpy as np
Expand Down Expand Up @@ -404,6 +405,123 @@ def test_static_mode_constant_value_edge_case(self):
"constant", constant_value=0.1, quantize_mode="static", extra_options={"dual_feed": True}
)

@classmethod
def construct_model_add_pad_add(cls, name, final_name, shape):
    """Build a float Add -> Pad -> Add model whose Pad has "" as its 3rd input.

    The graph is::

        `name`, `name`             -> Add -> "first_add_output"
        "first_add_output", "pads" -> Pad -> "pad_output"
        "pad_output", "pad_output" -> Add -> `final_name`

    The Pad node runs in "constant" mode and lists an empty string as its
    third input name.

    Args:
        name: Name of the model input tensor.
        final_name: Name of the model output tensor.
        shape: Shape of the model input.

    Returns:
        The model built by ``helper.make_model`` with opset 13 on the
        default domain and ``ir_version`` set to 7.
    """
    add1_out = "first_add_output"
    pads_tensor_name = "pads"
    pad_out = "pad_output"

    rank = len(shape)

    # 1-D INT64 initializer with 2 * rank entries.
    pad_amounts = [1, 2] * rank
    pads_init = helper.make_tensor(pads_tensor_name, TensorProto.INT64, [len(pad_amounts)], pad_amounts)

    # Axis `i` grows by pad_amounts[i] + pad_amounts[i + rank].
    padded_shape = tuple(dim + pad_amounts[axis] + pad_amounts[axis + rank] for axis, dim in enumerate(shape))

    graph_inputs = [helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)]
    graph_outputs = [helper.make_tensor_value_info(final_name, TensorProto.FLOAT, padded_shape)]

    nodes = [
        helper.make_node("Add", [name, name], [add1_out], name="FirstAdd"),
        # The trailing "" is the empty-string third input this model exists to exercise.
        helper.make_node("Pad", [add1_out, pads_tensor_name, ""], [pad_out], name="PadNode", mode="constant"),
        helper.make_node("Add", [pad_out, pad_out], [final_name], name="SecondAdd"),
    ]

    graph = helper.make_graph(
        nodes,
        "TestPadWithEmptyStringInput",
        graph_inputs,
        graph_outputs,
        initializer=[pads_init],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    model.ir_version = 7  # use stable onnx ir version
    return model

def test_pad_with_empty_string_input_name(self):
    """Quantize an Add -> Pad -> Add model whose Pad has "" as its 3rd input.

    Checks that the quantized graph starts with the expected op sequence and
    that no node input or output is named "" or "_quantized".
    """
    np.random.seed(108)
    fp32_path = "pad_with_empty_string_input_name_fp32.onnx"
    i8_path = "pad_with_empty_string_input_name_i8.onnx"

    input_name = "input"
    input_shape = [3]
    reader = self.input_feeds(1, {input_name: input_shape})

    fp32_model = TestOpQuatizerPad.construct_model_add_pad_add(
        name=input_name, shape=input_shape, final_name="output"
    )
    onnx.save(fp32_model, fp32_path)

    self.quantize_model(fp32_path, i8_path, data_reader=reader)

    quantized_nodes = onnx.load(i8_path).graph.node

    # Assert quantization really happens.
    expected_op_types = ("QuantizeLinear", "QLinearAdd", "Pad", "QLinearAdd", "DequantizeLinear")
    for position, op_type in enumerate(expected_op_types):
        self.assertEqual(quantized_nodes[position].op_type, op_type)

    # No empty string may flow into the quantization process. Previously, an
    # optional input specified by `""` in NodeProto.input could cause a
    # phantom node producing `"_quantized"`.
    for node in quantized_nodes:
        for tensor_name in itertools.chain(node.input, node.output):
            self.assertNotEqual(tensor_name, "")
            self.assertNotEqual(tensor_name, "_quantized")


# Run this module's unittest test cases when the file is executed directly.
if __name__ == "__main__":
unittest.main()

0 comments on commit 7f3c99e

Please sign in to comment.