Skip to content

Commit

Permalink
fix wrong types
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Dec 25, 2023
1 parent aee75ff commit 63a8ea9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
39 changes: 21 additions & 18 deletions onnxruntime/python/tools/quantization/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,22 +124,22 @@ def from_string(format):
}

ONNX_INT_TYPE_RANGE = {
onnx_proto.TensorProto.UINT8: (0, 255),
onnx_proto.TensorProto.INT8: (-128, 127),
onnx_proto.TensorProto.UINT16: (0, 65535),
onnx_proto.TensorProto.INT16: (-32768, 32767),
onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(255, dtype=numpy.uint8)),
onnx_proto.TensorProto.INT8: (numpy.array(-128, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65535, dtype=numpy.uint16)),
onnx_proto.TensorProto.INT16: (numpy.array(-32768, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
}

ONNX_INT_TYPE_SYMMETRIC_RANGE = {
onnx_proto.TensorProto.INT8: (-127, 127),
onnx_proto.TensorProto.INT16: (-32767, 32767),
onnx_proto.TensorProto.INT8: (numpy.array(-127, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
onnx_proto.TensorProto.INT16: (numpy.array(-32767, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
}

ONNX_INT_TYPE_REDUCED_RANGE = {
onnx_proto.TensorProto.UINT8: (0, 127),
onnx_proto.TensorProto.INT8: (-64, 64),
onnx_proto.TensorProto.UINT16: (0, 32767),
onnx_proto.TensorProto.INT16: (-16384, 16384),
onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(127, dtype=numpy.uint8)),
onnx_proto.TensorProto.INT8: (numpy.array(-64, dtype=numpy.int8), numpy.array(64, dtype=numpy.int8)),
onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(32767, dtype=numpy.uint16)),
onnx_proto.TensorProto.INT16: (numpy.array(-16384, dtype=numpy.uint16), numpy.array(16384, dtype=numpy.uint16)),
}


Expand Down Expand Up @@ -226,6 +226,7 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=Non
:return: zero and scale [z, s]
"""
assert rmin <= rmax, f"rmin={rmin} > rmax={rmax}"
if qmin > 0 or qmax < 0:
raise ValueError(f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while qmin:{qmin}, qmmax:{qmax}")

Expand All @@ -244,12 +245,17 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=Non
rmin = -absmax
rmax = +absmax

scale = numpy.array(numpy.array(rmax - rmin, dtype=numpy.float64) / numpy.array(qmax - qmin, dtype=numpy.float64))
assert rmin <= rmax, f"rmin={rmin} > rmax={rmax}"
assert qmin <= qmax, f"qmin={rmin} > qmax={rmax}"
dr = numpy.array(rmax - rmin, dtype=numpy.float64)
dq = numpy.array(qmax, dtype=numpy.float64) - numpy.array(qmin, dtype=numpy.float64)
scale = numpy.array(dr / dq)
assert scale >= 0, "scale isse"
if scale < numpy.finfo(rmax.dtype).tiny:
scale = numpy.array(1.0, dtype=rmax.dtype)
zero_point = numpy.array(0, dtype=numpy.int32)
zero_point = numpy.array(0, dtype=qmin.dtype)
else:
zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=numpy.int32)
zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=qmin.dtype)
scale = scale.astype(rmax.dtype)

return [zero_point, scale]
Expand Down Expand Up @@ -325,20 +331,17 @@ def quantize_data(
if rmin_override is not None:
rmin = rmin_override
else:
rmin = 0.0
rmin = data.min() if len(data) else 0.0

if rmax_override is not None:
rmax = rmax_override
else:
rmax = 0.0
rmax = data.max() if len(data) else 0.0

rmin = numpy.array(rmin, dtype=data.dtype)
rmax = numpy.array(rmax, dtype=data.dtype)
zero_point = 0
scale = numpy.array(1.0, dtype=data.dtype)
if len(data):
rmin = data.min()
rmax = data.max()

if qType == TensorProto.FLOAT8E4M3FN:
if reduce_range:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,6 @@ def test_qdq_overrides3(self):

wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type)
new_wgt_zp, new_wgt_sc = compute_scale_zp(wgt_rmin, wgt_rmax, wgt_qmin, wgt_qmax)
print("****", [wgt_zp.data_type, wgt_rmin, wgt_rmax, wgt_qmin, wgt_qmax], [new_wgt_zp, new_wgt_sc])
# [2, array(0., dtype=float32), array(1., dtype=float32), 0, 255] [array(0, dtype=int32), array(0.00392157, dtype=float32)]
self.assertEqual(wgt_zp.int32_data[0], new_wgt_zp)
self.assertEqual(wgt_sc.float_data[0], np.float32(new_wgt_sc))

Expand Down Expand Up @@ -508,5 +506,5 @@ def test_override_validation_bad_combination(self):
if __name__ == "__main__":
t = TestTensorQuantOverridesOption()
t.setUp()
t.test_qdq_overrides_per_channel1()
t.test_qdq_default_per_channel()
unittest.main()

0 comments on commit 63a8ea9

Please sign in to comment.