Skip to content

Commit

Permalink
fix type issue
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Dec 25, 2023
1 parent 1185de0 commit aee75ff
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
21 changes: 17 additions & 4 deletions onnxruntime/python/tools/quantization/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def from_string(format):
}


def _check_type(*args):
def _check_type(*args, zero_point_index=-1):
new_args = []
for i, a in enumerate(args):
if numpy.issubdtype(type(a), numpy.number):
Expand All @@ -152,6 +152,10 @@ def _check_type(*args):
new_args.append(a)
else:
raise TypeError(f"arg {i} is not an array: {a}")
if i == zero_point_index:
v = new_args[-1]
if v.dtype == numpy.float32 or v.dtype == numpy.float16:
raise TypeError(f"zero_point cannot be {v.dtype}")
return tuple(new_args) if len(new_args) > 1 else new_args[0]


Expand Down Expand Up @@ -261,20 +265,29 @@ def compute_scale_zp_float8(element_type, std):
More details in notebook `quantization_fp8.ipynb
<https://github.com/microsoft/onnxruntime/blob/main/docs/python/notebooks/quantization_fp8.ipynb>`_.
"""
zp_dtype = None
if element_type not in FLOAT8_DISTRIBUTIONS:
if element_type == TensorProto.FLOAT8E4M3FN:
from onnx.numpy_helper import float8e4m3_to_float32
from onnx.reference.custom_element_types import float8e4m3fn

zp_dtype = float8e4m3fn
all_values = [float8e4m3_to_float32(i) for i in range(0, 256)]
values = numpy.array(
[f for f in all_values if not numpy.isnan(f) and not numpy.isinf(f)], dtype=numpy.float32
)
else:
raise ValueError(f"Quantization to element_type={element_type} not implemented.")
FLOAT8_DISTRIBUTIONS[element_type] = values
elif element_type == TensorProto.FLOAT8E4M3FN:
from onnx.reference.custom_element_types import float8e4m3fn

zp_dtype = float8e4m3fn

if zp_dtype is None:
raise TypeError(f"Unexpected element_type {element_type}.")
std_f8 = numpy.std(FLOAT8_DISTRIBUTIONS[element_type])
zero = numpy.array(0, dtype=std.dtype)
zero = numpy.array(0, dtype=zp_dtype)
scale = numpy.array(std / std_f8, dtype=std.dtype)
return [zero, scale]

Expand Down Expand Up @@ -339,14 +352,14 @@ def quantize_data(
f"One of the quantized value is NaN data in [{np_data.min()}, {np_data.max()}], "
f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]."
)
return _check_type(rmin, rmax, zero_point, scale, quantized_data)
return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2)

if qType in (TensorProto.INT8, TensorProto.UINT8, TensorProto.INT16, TensorProto.UINT16):
if len(data):
qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range)
quantized_data = quantize_nparray(qType, data, scale, zero_point)
return _check_type(rmin, rmax, zero_point, scale, quantized_data)
return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2)

raise ValueError(f"Unexpected value for qType={qType}.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ def test_qdq_overrides3(self):

wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type)
new_wgt_zp, new_wgt_sc = compute_scale_zp(wgt_rmin, wgt_rmax, wgt_qmin, wgt_qmax)
print("****", [wgt_zp.data_type, wgt_rmin, wgt_rmax, wgt_qmin, wgt_qmax], [new_wgt_zp, new_wgt_sc])
# [2, array(0., dtype=float32), array(1., dtype=float32), 0, 255] [array(0, dtype=int32), array(0.00392157, dtype=float32)]
self.assertEqual(wgt_zp.int32_data[0], new_wgt_zp)
self.assertEqual(wgt_sc.float_data[0], np.float32(new_wgt_sc))

Expand Down

0 comments on commit aee75ff

Please sign in to comment.