improve robustness
xadupre committed Dec 25, 2023
1 parent 1c8ae86 commit 1185de0
Showing 3 changed files with 65 additions and 10 deletions.
37 changes: 35 additions & 2 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -20,6 +20,7 @@
 from .calibrate import TensorData
 from .onnx_model import ONNXModel
 from .quant_utils import (
+    ONNX_TYPE_TO_NP_TYPE,
     TENSOR_NAME_QUANT_SUFFIX,
     QuantizationMode,
     QuantizedValue,
@@ -1129,8 +1130,15 @@ def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_wei
             qType = quant_overrides["quant_type"].tensor_type  # noqa: N806

         if "scale" in quant_overrides and "zero_point" in quant_overrides:
-            zero_point, scale = quant_overrides["zero_point"], quant_overrides["scale"]
+            zero_point = np.array(quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[qType])
+            scale = np.array(quant_overrides["scale"])
             q_weight_data = quantize_nparray(qType, weight_data.flatten(), scale, zero_point)
+            assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+            assert (
+                zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+            ), f"Unexpected dtype {zero_point.dtype}"
+            assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+
         else:
             _, _, zero_point, scale, q_weight_data = quantize_data(
                 weight_data.flatten(),
@@ -1142,6 +1150,12 @@
                 rmax_override=quant_overrides.get("rmax"),
             )

+            assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+            assert (
+                zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+            ), f"Unexpected dtype {zero_point.dtype}"
+            assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+
         scale_dtype = weight.data_type
         scale_initializer = onnx.helper.make_tensor(scale_name, scale_dtype, [], scale.reshape((-1,)).tolist())
         zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], zero_point.reshape((-1,)).tolist())
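For context, a minimal sketch of the override cast this hunk introduces. The ONNX_TYPE_TO_NP_TYPE mapping below is a reduced stand-in for the real table in quant_utils.py, and the override values are illustrative:

import numpy as np
from onnx import TensorProto

# Reduced stand-in for quant_utils.ONNX_TYPE_TO_NP_TYPE (the real table covers more types).
ONNX_TYPE_TO_NP_TYPE = {
    TensorProto.INT8: np.dtype("int8"),
    TensorProto.UINT8: np.dtype("uint8"),
}

qType = TensorProto.UINT8
quant_overrides = {"zero_point": 128, "scale": 0.02}  # illustrative user-supplied overrides

# The cast guarantees an integer numpy array even when the override was a Python scalar.
zero_point = np.array(quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[qType])
scale = np.array(quant_overrides["scale"])

assert isinstance(zero_point, np.ndarray) and zero_point.dtype == np.uint8
assert zero_point.dtype != np.float32 and zero_point.dtype != np.float16  # the new dtype check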
@@ -1222,10 +1236,20 @@ def quantize_weight_per_channel(
             channel_quant_overrides = quant_overrides_for_channels[i]

             if "scale" in channel_quant_overrides and "zero_point" in channel_quant_overrides:
-                zero_point, scale = channel_quant_overrides["zero_point"], channel_quant_overrides["scale"]
+                zero_point = np.array(channel_quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[weight_qType])
+                scale = np.array(channel_quant_overrides["scale"])
                 quantized_per_channel_data = quantize_nparray(
                     weight_qType, per_channel_data.flatten(), scale, zero_point
                 )
+                assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+                assert (
+                    zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+                ), f"Unexpected dtype {zero_point.dtype}"
+                assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+                assert isinstance(
+                    quantized_per_channel_data, np.ndarray
+                ), f"Unexpected type {type(quantized_per_channel_data)}"
+
             else:
                 symmetric = channel_quant_overrides.get(
                     "symmetric",
@@ -1244,6 +1268,15 @@
                     rmax_override=channel_quant_overrides.get("rmax"),
                 )

+                assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+                assert (
+                    zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+                ), f"Unexpected dtype {zero_point.dtype}"
+                assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+                assert isinstance(
+                    quantized_per_channel_data, np.ndarray
+                ), f"Unexpected type {type(quantized_per_channel_data)}"
+
             zero_point_list.append(zero_point)
             scale_list.append(scale)
             quantized_per_channel_data_list.append(quantized_per_channel_data)
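A rough, self-contained sketch of the per-channel loop these assertions guard (symmetric int8 along axis 0; the function name and simplifications are illustrative, not the quantizer's internals):

import numpy as np

def toy_per_channel_quant(weight: np.ndarray, qmin: int = -127, qmax: int = 127):
    """Symmetric int8 per-channel quantization along axis 0 (toy version)."""
    zero_point_list, scale_list, quantized_per_channel_data_list = [], [], []
    for per_channel_data in weight:
        amax = max(float(np.abs(per_channel_data).max()), 1e-12)  # avoid division by zero
        scale = np.array(amax / qmax, dtype=weight.dtype)
        zero_point = np.array(0, dtype=np.int8)
        quantized = np.clip(np.rint(per_channel_data / scale), qmin, qmax).astype(np.int8)
        # Same invariants the new asserts enforce: numpy arrays with an integer zero point.
        assert isinstance(zero_point, np.ndarray) and isinstance(scale, np.ndarray)
        zero_point_list.append(zero_point)
        scale_list.append(scale)
        quantized_per_channel_data_list.append(quantized)
    return np.stack(zero_point_list), np.stack(scale_list), np.stack(quantized_per_channel_data_list)

zps, scales, q = toy_per_channel_quant(np.random.randn(4, 8).astype(np.float32))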
28 changes: 21 additions & 7 deletions onnxruntime/python/tools/quantization/quant_utils.py
@@ -143,6 +143,18 @@ def from_string(format):
 }


+def _check_type(*args):
+    new_args = []
+    for i, a in enumerate(args):
+        if numpy.issubdtype(type(a), numpy.number):
+            new_args.append(numpy.array(a))
+        elif isinstance(a, numpy.ndarray):
+            new_args.append(a)
+        else:
+            raise TypeError(f"arg {i} is not an array: {a}")
+    return tuple(new_args) if len(new_args) > 1 else new_args[0]
+
+
 def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
     assert (
         qType in ONNX_TYPE_TO_NP_TYPE
@@ -178,7 +190,7 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
             )
         )
         ref = ReferenceEvaluator(onnx_model)
-        return ref.run(None, {"X": arr, "scale": scale})[0]
+        return _check_type(ref.run(None, {"X": arr, "scale": scale})[0])
     else:
         dtype = ONNX_TYPE_TO_NP_TYPE[qType]
         (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True)
@@ -187,7 +199,7 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
         cliphigh = min(qmax, high) if high is not None else qmax
         arr_fp32 = numpy.asarray((arr.astype(numpy.float32) / scale).round() + zero_point)
         numpy.clip(arr_fp32, cliplow, cliphigh, out=arr_fp32)
-        return arr_fp32.astype(dtype)
+        return _check_type(arr_fp32.astype(dtype))


 def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=None):
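The contract of the new _check_type helper, restated as a runnable snippet (the helper body is copied from the hunk above; the example values are arbitrary):

import numpy


def _check_type(*args):
    # Copied from the hunk above: promote numeric scalars to 0-d arrays, pass arrays through.
    new_args = []
    for i, a in enumerate(args):
        if numpy.issubdtype(type(a), numpy.number):
            new_args.append(numpy.array(a))
        elif isinstance(a, numpy.ndarray):
            new_args.append(a)
        else:
            raise TypeError(f"arg {i} is not an array: {a}")
    return tuple(new_args) if len(new_args) > 1 else new_args[0]


rmin, scale = _check_type(0.0, numpy.array(1.0, dtype=numpy.float32))
assert isinstance(rmin, numpy.ndarray) and rmin.ndim == 0  # Python float promoted to a 0-d array
assert scale.dtype == numpy.float32                        # existing arrays pass through untouched
single = _check_type(numpy.float16(2.5))                   # a single argument comes back unwrapped
assert isinstance(single, numpy.ndarray)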
@@ -300,15 +312,17 @@ def quantize_data(
     if rmin_override is not None:
         rmin = rmin_override
     else:
-        rmin = 1.0
+        rmin = 0.0

     if rmax_override is not None:
         rmax = rmax_override
     else:
-        rmax = 1.0
+        rmax = 0.0

+    rmin = numpy.array(rmin, dtype=data.dtype)
+    rmax = numpy.array(rmax, dtype=data.dtype)
     zero_point = 0
-    scale = 1.0
+    scale = numpy.array(1.0, dtype=data.dtype)
     if len(data):
         rmin = data.min()
         rmax = data.max()
@@ -325,14 +339,14 @@
                 f"One of the quantized value is NaN data in [{np_data.min()}, {np_data.max()}], "
                 f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]."
             )
-        return rmin, rmax, zero_point, scale, quantized_data
+        return _check_type(rmin, rmax, zero_point, scale, quantized_data)

     if qType in (TensorProto.INT8, TensorProto.UINT8, TensorProto.INT16, TensorProto.UINT16):
         if len(data):
             qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
             zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range)
             quantized_data = quantize_nparray(qType, data, scale, zero_point)
-        return rmin, rmax, zero_point, scale, quantized_data
+        return _check_type(rmin, rmax, zero_point, scale, quantized_data)

     raise ValueError(f"Unexpected value for qType={qType}.")
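Taken together, quantize_data now hands back numpy arrays for every element of its return tuple. A hedged usage sketch (the import path and keyword usage are assumed from this repo's layout and the signature shown above):

import numpy
from onnx import TensorProto
from onnxruntime.quantization.quant_utils import quantize_data  # path assumed

data = numpy.array([-1.5, -0.2, 0.0, 0.7, 2.0], dtype=numpy.float32)
rmin, rmax, zero_point, scale, quantized = quantize_data(data, TensorProto.UINT8, symmetric=False)

# After this commit, even the scalar-like results are numpy arrays, never Python floats.
for name, value in [("rmin", rmin), ("rmax", rmax), ("zero_point", zero_point), ("scale", scale)]:
    assert isinstance(value, numpy.ndarray), f"Unexpected type for {name}: {type(value)}"
assert quantized.dtype == numpy.uint8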
10 changes: 9 additions & 1 deletion onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
@@ -399,7 +399,12 @@ def test_qdq_overrides_per_channel2(self):
         self.assertEqual(wgt_zp.data_type, quant_type.tensor_type)
         for index, (zp, scale) in enumerate(zip(wgt_zp.int32_data, wgt_sc.float_data)):
             wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type, reduce_range=reduce_ranges[index])
-            expected_zp, expected_scale = compute_scale_zp(rmin_vals[index], rmax_vals[index], wgt_qmin, wgt_qmax)
+            expected_zp, expected_scale = compute_scale_zp(
+                np.array(rmin_vals[index], dtype=np.float32),
+                np.array(rmax_vals[index], dtype=np.float32),
+                wgt_qmin,
+                wgt_qmax,
+            )
             self.assertEqual(zp, expected_zp)
             self.assertEqual(scale, np.float32(expected_scale))

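The test fix above reflects the tightened input contract: compute_scale_zp now expects rmin/rmax as numpy arrays. A minimal call sketch (import paths assumed; values illustrative):

import numpy as np
from onnx import TensorProto
from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType  # paths assumed

qmin, qmax = get_qmin_qmax_for_qType(TensorProto.INT8, reduce_range=False, symmetric=True)
zero_point, scale = compute_scale_zp(
    np.array(-1.0, dtype=np.float32),  # rmin as a numpy array, per the new contract
    np.array(1.0, dtype=np.float32),   # rmax likewise
    qmin,
    qmax,
    symmetric=True,
)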
@@ -499,4 +504,7 @@ def test_override_validation_bad_combination(self):


 if __name__ == "__main__":
+    t = TestTensorQuantOverridesOption()
+    t.setUp()
+    t.test_qdq_overrides_per_channel1()
     unittest.main()
