[QNN Quant] Ensure 16bit tensor quant overrides set MS domain #19684

Merged

Changes from 1 commit
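To illustrate the behavior this PR fixes, here is a minimal usage sketch (the model paths, my_data_reader, and the "INP" tensor name are hypothetical placeholders): with 8-bit defaults, a single 16-bit entry in TensorQuantOverrides must now force the Q/DQ ops into the 'com.microsoft' domain, since the standard ONNX opset does not define 16-bit QuantizeLinear/DequantizeLinear.

import onnx
from onnxruntime.quantization import QuantType, quantize_static
from onnxruntime.quantization.quant_utils import ms_domain

# "model.onnx" and my_data_reader are placeholders for a real model and a
# quantization.CalibrationDataReader implementation.
quantize_static(
    "model.onnx",
    "model.qdq.onnx",
    calibration_data_reader=my_data_reader,
    activation_type=QuantType.QUInt8,  # 8-bit defaults ...
    weight_type=QuantType.QUInt8,
    extra_options={
        # ... but one tensor is overridden to 16-bit.
        "TensorQuantOverrides": {"INP": [{"quant_type": QuantType.QUInt16}]},
    },
)

# After this change, the override alone is enough to select the contrib domain:
model = onnx.load_model("model.qdq.onnx")
for node in model.graph.node:
    if node.op_type in ("QuantizeLinear", "DequantizeLinear"):
        assert node.domain == ms_domain  # "com.microsoft"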
11 changes: 7 additions & 4 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -154,7 +154,7 @@ def __init__(
if self.mode not in QuantizationMode:
raise ValueError(f"unsupported quantization mode {self.mode}")

self.tensor_quant_overrides = self._get_and_check_tensor_quant_overrides()
self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides()
self.quantization_params = self.calculate_quantization_params()

# QuantizeRange tensor name and zero tensor name for scale and zero point calculation.
@@ -177,8 +177,10 @@ def __init__(
def _get_and_check_tensor_quant_overrides(self):
"""
Get tensor quantization overrides and check correctness.
Also returns a set of quantization types (as TensorProto) specified across all overrides.
"""
tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {})
tensor_quant_override_types = set()

# Validate that compatible/valid overrides are provided.
if tensor_quant_overrides:
@@ -211,6 +213,8 @@ def _get_and_check_tensor_quant_overrides(self):
# other channels.
if index == 0:
quant_type = quant_overrides.get("quant_type")
if quant_type is not None:
tensor_quant_override_types.add(quant_type.tensor_type)
elif quant_type != quant_overrides.get("quant_type"):
raise ValueError(
"Channel quantization types for tensor '{tensor_name}' do not match at index {index}."
@@ -231,7 +235,7 @@
f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'"
)

return tensor_quant_overrides
return tensor_quant_overrides, tensor_quant_override_types

def get_per_tensor_quant_overrides(self, tensor_name):
quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}])
@@ -747,8 +751,7 @@ def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None):
raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}")
scale_values = np.array([params["scale"]])
assert scale_values.dtype != np.float64
# zero_point_type = params["quant_type"]
assert zero_point_type == params["quant_type"]
zero_point_type = params["quant_type"]
else:
zero_point_values = np.array([use_zeropoint])
scale_values = np.array([use_scale])
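Taken out of the class for illustration, the new type-collection logic amounts to the following sketch (collect_override_types is a hypothetical free-function rendering of the code above, not ORT API):

from onnx import TensorProto
from onnxruntime.quantization.quant_utils import QuantType

def collect_override_types(tensor_quant_overrides: dict) -> set:
    """Gather the TensorProto types named by 'quant_type' across all overrides."""
    override_types = set()
    for overrides_list in tensor_quant_overrides.values():  # one list per tensor
        for overrides in overrides_list:  # one dict per channel
            quant_type = overrides.get("quant_type")
            if quant_type is not None:
                override_types.add(quant_type.tensor_type)
    return override_types

# Collecting the types into a set lets the quantizer later ask "is any
# override 16-bit?" without re-walking the overrides structure.
assert collect_override_types(
    {"INP": [{"quant_type": QuantType.QUInt16}]}
) == {TensorProto.UINT16}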
5 changes: 4 additions & 1 deletion onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -116,7 +116,10 @@ def __init__(
# if the activation or weight types are 16-bit integers.
# TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support.
int16_types = (TensorProto.UINT16, TensorProto.INT16)
if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types):
overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types)
if not self.qdq_op_domain and (
self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16
):
logging.warning(
"ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types. "
f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to "
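Rewritten as a standalone sketch (needs_ms_domain is a hypothetical name, not the actual member check), the updated condition reads:

from onnx import TensorProto

INT16_TYPES = (TensorProto.UINT16, TensorProto.INT16)

def needs_ms_domain(activation_qtype, weight_qtype, override_types) -> bool:
    """True if any quantization type in use is 16-bit, since the standard ONNX
    Q/DQ ops cannot represent it and the 'com.microsoft' contrib ops are needed."""
    return (
        activation_qtype in INT16_TYPES
        or weight_qtype in INT16_TYPES
        or any(t in INT16_TYPES for t in override_types)
    )

# Before this PR only the first two conditions were checked, so a model with
# 8-bit defaults but a 16-bit per-tensor override kept the default domain.
assert needs_ms_domain(TensorProto.UINT8, TensorProto.UINT8, {TensorProto.UINT16})
assert not needs_ms_domain(TensorProto.UINT8, TensorProto.UINT8, set())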
@@ -13,7 +13,7 @@

from onnxruntime import quantization
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config
from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType
from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain


class DummyDataReader(quantization.CalibrationDataReader):
@@ -423,6 +423,36 @@ def test_qdq_overrides_per_channel2(self):
self.assertEqual(zp, expected_zp)
self.assertEqual(scale, np.float32(expected_scale))

def test_16bit_overrides_set_ms_domain(self):
"""
Test that overriding a tensor to 16bit (when default is 8bit) automatically sets the 'com.microsoft'
domain on DQ and Q ops.
"""
qdq_model_name = "model_quant_overrides_to_16bit.onnx"
inp_zp, _, sig_out_zp, _, _, _, _, _, out_zp, _ = self.perform_qdq_quantization(
qdq_model_name,
activation_type=onnx.TensorProto.UINT8, # Default to 8bit activations
extra_options={
"TensorQuantOverrides": {
"INP": [{"quant_type": quantization.QuantType.QUInt16}],
"SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}],
}
},
)

# Input and Sigmoid's output should be overridden to 16bit
self.assertEqual(inp_zp.data_type, onnx.TensorProto.UINT16)
self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16)

# Output should be the default uint8 type
self.assertEqual(out_zp.data_type, onnx.TensorProto.UINT8)

# Q/DQ ops should all have the 'com.microsoft' domain
qdq_model = onnx.load_model(qdq_model_name)
for node in qdq_model.graph.node:
if node.op_type in {"QuantizeLinear", "DequantizeLinear"}:
self.assertEqual(node.domain, ms_domain)

def test_override_validation_nonexisting_tensor(self):
"""
Test that specifying a non-existing tensor should fail.