[Quantization] Fix scale/zero-point for 16-bit QDQ Softmax (#18589)
### Description
Sets the appropriate scale and zero-point values for 16-bit QDQ Softmax.
Previously, the scale and zero-point were hard-coded to fixed values specific to 8-bit quantization.

### Motivation and Context
Generate more accurate 16-bit QDQ models that contain Softmax.
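
For intuition, here is the arithmetic (an illustrative sketch only; the quantized ranges below assume ONNX's standard 8-bit and 16-bit unsigned integer types):

# Softmax outputs always lie in [0.0, 1.0], so rmin = 0.0 and rmax = 1.0.
# The scale must map that real range onto the full quantized range.
rmin, rmax = 0.0, 1.0

scale_u8 = (rmax - rmin) / (255 - 0)  # ~1/255, close to the old hard-coded 1/256
scale_u16 = (rmax - rmin) / (65535 - 0)  # ~1/65535, far smaller

# Reusing the 8-bit scale for a 16-bit output is badly miscalibrated:
# the quantized value 65535 would dequantize to 65535/256 ≈ 256.0 instead of 1.0.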
adrianlizarraga authored Nov 28, 2023
1 parent 0b7048e commit 8d5ecc4
Showing 3 changed files with 93 additions and 34 deletions.
28 changes: 16 additions & 12 deletions onnxruntime/python/tools/quantization/operators/softmax.py
@@ -1,6 +1,14 @@
 import onnx
 
-from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
+from ..quant_utils import (
+    TENSOR_NAME_QUANT_SUFFIX,
+    QuantizedValue,
+    QuantizedValueType,
+    attribute_to_kwarg,
+    compute_scale_zp,
+    get_qmin_qmax_for_qType,
+    ms_domain,
+)
 from .base_operator import QuantOperatorBase
 from .qdq_base_operator import QDQOperatorBase

@@ -77,15 +85,11 @@ def quantize(self):
 class QDQSoftmax(QDQOperatorBase):
     def quantize(self):
         super().quantize()
-        if self.quantizer.activation_qType == onnx.onnx_pb.TensorProto.UINT8:
-            out_scale = 1 / 256.0
-            out_zero_point = 0
-        elif self.quantizer.is_activation_symmetric:
-            # results are all greater or equal to 0, so we can only use
-            # half of the range
-            out_scale = 1 / 127.0
-            out_zero_point = 0
-        else:
-            out_scale = 1 / 256.0
-            out_zero_point = -128
+        symmetric = self.quantizer.is_activation_symmetric
+
+        # Enforce Softmax range: 0.0 to 1.0
+        rmin, rmax = 0.0, 1.0
+        qmin, qmax = get_qmin_qmax_for_qType(self.quantizer.activation_qType, symmetric=symmetric)
+        out_zero_point, out_scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=symmetric)
 
         self.quantizer.set_quant_scale_zp(self.node.output[0], (out_scale, out_zero_point))
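
For reference, a standalone sketch of the values the new code ends up with. This mirrors the observable behavior of get_qmin_qmax_for_qType and compute_scale_zp for the Softmax range; the qmin/qmax pairs below are assumptions about ORT's conventions (e.g. a symmetric int16 range narrowed to ±32767), not the helpers' actual internals.

def softmax_scale_zp(qmin, qmax, symmetric):
    # Sketch only: reproduce the scale/zero-point math for rmin=0.0, rmax=1.0.
    rmin, rmax = 0.0, 1.0
    if symmetric:
        absmax = max(abs(rmin), abs(rmax))
        rmin, rmax = -absmax, absmax  # widen to a symmetric range around zero
    scale = (rmax - rmin) / float(qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    return zero_point, scale

print(softmax_scale_zp(0, 255, False))  # uint8: zp=0, scale=1/255
print(softmax_scale_zp(0, 65535, False))  # uint16: zp=0, scale=1/65535
print(softmax_scale_zp(-32767, 32767, True))  # int16 symmetric: zp=0, scale=1/32767

In each case the zero-point lands on a representable value and the scale spans the full (or full symmetric) quantized range, which is what the old 8-bit constants failed to do for 16-bit types.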
3 changes: 3 additions & 0 deletions onnxruntime/test/python/quantization/op_test_utils.py
@@ -393,6 +393,9 @@ def check_qtype_by_node_type(testcase, model_to_check, check_list):
         model = onnx.load(model_to_check)
     elif isinstance(model_to_check, onnx.ModelProto):
         model = model_to_check
+    # NOTE: ONNX shape inference does not work on MS domain nodes.
+    # Therefore, this function cannot currently be used for graphs that contain ops such as
+    # com.microsoft.QuantizeLinear, which support 16-bit quantization.
     model = onnx.shape_inference.infer_shapes(model)
     value_infos = {vi.name: vi for vi in model.graph.value_info}
     value_infos.update({ot.name: ot for ot in model.graph.output})
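
The limitation noted above can be reproduced in isolation. A hypothetical snippet (not part of the test suite) showing that onnx.shape_inference.infer_shapes silently skips nodes from domains it has no inference functions for, leaving no value_info for their outputs:

import onnx
from onnx import TensorProto, helper

# A com.microsoft.QuantizeLinear feeding an Identity; "q" is an intermediate tensor.
qnode = helper.make_node("QuantizeLinear", ["x", "scale", "zp"], ["q"], domain="com.microsoft")
idnode = helper.make_node("Identity", ["q"], ["y"])
graph = helper.make_graph(
    [qnode, idnode],
    "g",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3]),
        helper.make_tensor_value_info("scale", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("zp", TensorProto.UINT16, []),
    ],
    [helper.make_tensor_value_info("y", TensorProto.UINT16, None)],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 13), helper.make_opsetid("com.microsoft", 1)],
)
inferred = onnx.shape_inference.infer_shapes(model)
print([vi.name for vi in inferred.graph.value_info])  # "q" is absent: no type was inferred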
96 changes: 74 additions & 22 deletions onnxruntime/test/python/quantization/test_op_softmax.py
@@ -43,6 +43,7 @@ def construct_model_conv_softmax(
         softmax_input_shape,
         softmax_attributes,
         output_shape,
+        add_ms_domain_opset=False,
     ):
         #    (input)
         #       \
@@ -74,11 +75,16 @@ def construct_model_conv_softmax(
             [identity_out, output_tensor],
             initializer=initializers,
         )
-        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+
+        opset_imports = [helper.make_opsetid("", 13)]
+        if add_ms_domain_opset:
+            opset_imports.append(helper.make_opsetid("com.microsoft", 1))
+
+        model = helper.make_model(graph, opset_imports=opset_imports)
         model.ir_version = 7  # use stable onnx ir version
         onnx.save(model, output_model_path)
 
-    def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):  # noqa: B006
+    def quantize_softmax_test_qop(self, activation_type, weight_type, extra_options={}):  # noqa: B006
         np.random.seed(1)
         model_fp32_path = "softmax_fp32.onnx"
         self.construct_model_conv_softmax(
@@ -91,11 +97,10 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):
         )
         data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]})
 
-        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
-        activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8"
-        weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8"
+        activation_proto_qtype = activation_type.tensor_type
+        activation_type_str = str(activation_type)
+        weight_type_str = str(weight_type)
         model_q8_path = f"softmax_{activation_type_str}{weight_type_str}.onnx"
-        model_q8_qdq_path = f"softmax_qdq_{activation_type_str}{weight_type_str}.onnx"
 
         # Verify QOperator mode
         data_reader.rewind()
@@ -138,19 +143,38 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):
         data_reader.rewind()
         check_model_correctness(self, model_fp32_path, model_q8_path, data_reader.get_next())
 
+    def quantize_softmax_test_qdq(self, activation_type, weight_type, extra_options={}):  # noqa: B006
+        np.random.seed(1)
+        model_fp32_path = "softmax_fp32.onnx"
+        self.construct_model_conv_softmax(
+            model_fp32_path,
+            [1, 2, 26, 42],
+            [3, 2, 3, 3],
+            [1, 3, 24, 40],
+            {"axis": -2},
+            [1, 3, 24, 40],
+            add_ms_domain_opset=extra_options.get("UseQDQContribOps", False),
+        )
+        data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]})
+
+        activation_proto_qtype = activation_type.tensor_type
+        activation_type_str = str(activation_type)
+        weight_type_str = str(weight_type)
+        model_qdq_path = f"softmax_qdq_{activation_type_str}{weight_type_str}.onnx"
+
         # Verify QDQ mode
         data_reader.rewind()
         quantize_static(
             model_fp32_path,
-            model_q8_qdq_path,
+            model_qdq_path,
             data_reader,
             quant_format=QuantFormat.QDQ,
             activation_type=activation_type,
             weight_type=weight_type,
             extra_options=extra_options,
         )
 
-        result_model = onnx.load(Path(model_q8_qdq_path))
+        result_model = onnx.load(Path(model_qdq_path))
         qnode_cnt = 0
         dqnode_cnt = 0
         softmax_cnt = 0
Expand All @@ -166,40 +190,68 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}):
         self.assertEqual(3, qnode_cnt, f"Expected 3 QuantizeLinear nodes, found {qnode_cnt}")
         self.assertEqual(4, dqnode_cnt, f"Expected 4 DequantizeLinear nodes, found {dqnode_cnt}")
         self.assertEqual(1, softmax_cnt, f"Expected 1 Softmax node, found {softmax_cnt}")
-        if extra_options.get("ActivationSymmetric", False):
-            for tensor in result_model.graph.initializer:
-                if tensor.name in qnode_zeropoints:
+        for tensor in result_model.graph.initializer:
+            if tensor.name in qnode_zeropoints:
+                self.assertEqual(
+                    tensor.data_type,
+                    activation_proto_qtype,
+                    f"QuantizeLinear zero-point must be of proto type {activation_proto_qtype}, "
+                    f"but found {tensor.data_type} instead.",
+                )
+                if extra_options.get("ActivationSymmetric", False):
                     np_value = numpy_helper.to_array(tensor)
                     self.assertEqual(
                         0,
                         np_value,
                         f"QuantizeLinear node zero point value must be 0, found {np_value} instead!",
                     )
 
         qnode_io_qtypes = {
             "QuantizeLinear": [
                 ["i", 2, activation_proto_qtype],
                 ["o", 0, activation_proto_qtype],
             ]
         }
-        check_qtype_by_node_type(self, model_q8_qdq_path, qnode_io_qtypes)
+        check_qtype_by_node_type(self, model_qdq_path, qnode_io_qtypes)
         data_reader.rewind()
-        check_model_correctness(self, model_fp32_path, model_q8_qdq_path, data_reader.get_next())
+        check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next())
 
     def test_quantize_softmax(self):
-        self.quantize_softmax_test(QuantType.QUInt8, QuantType.QUInt8)
+        self.quantize_softmax_test_qop(QuantType.QUInt8, QuantType.QUInt8)
+        self.quantize_softmax_test_qdq(QuantType.QUInt8, QuantType.QUInt8)
 
     def test_quantize_softmax_s8s8(self):
-        self.quantize_softmax_test(
+        self.quantize_softmax_test_qop(
             QuantType.QInt8,
             QuantType.QInt8,
         )
+        self.quantize_softmax_test_qdq(
+            QuantType.QInt8,
+            QuantType.QInt8,
+        )
+        self.quantize_softmax_test_qop(
+            QuantType.QInt8,
+            QuantType.QInt8,
+            extra_options={"ActivationSymmetric": True},
+        )
-        self.quantize_softmax_test(
+        self.quantize_softmax_test_qdq(
             QuantType.QInt8,
             QuantType.QInt8,
             extra_options={"ActivationSymmetric": True},
         )
+
+    def test_quantize_softmax_qdq_u16u16(self):
+        self.quantize_softmax_test_qdq(
+            QuantType.QUInt16,
+            QuantType.QUInt16,
+            extra_options={"UseQDQContribOps": True},
+        )
+
+    def test_quantize_softmax_qdq_s16s16(self):
+        self.quantize_softmax_test_qdq(
+            QuantType.QInt16,
+            QuantType.QInt16,
+            extra_options={"UseQDQContribOps": True},
+        )
+        self.quantize_softmax_test_qdq(
+            QuantType.QInt16,
+            QuantType.QInt16,
+            extra_options={"UseQDQContribOps": True, "ActivationSymmetric": True},
+        )
 
 
 if __name__ == "__main__":
     unittest.main()
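
Taken together, the new tests document the end-to-end recipe. A usage sketch mirroring them (the file names and the calibration data reader are placeholders) for producing a 16-bit QDQ model:

from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

quantize_static(
    "softmax_fp32.onnx",  # float32 source model (placeholder path)
    "softmax_qdq_u16u16.onnx",  # quantized output model (placeholder path)
    data_reader,  # a CalibrationDataReader yielding representative inputs
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt16,
    # 16-bit QuantizeLinear/DequantizeLinear currently requires the
    # com.microsoft contrib ops, hence this flag in the tests above.
    extra_options={"UseQDQContribOps": True},
)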
