Fix quantization dtypes after ORT PR #18043 (#881)
## Describe your changes
PR microsoft/onnxruntime#18043 extends the onnxruntime quantization
tools to support float16 weights. To do so, it enforces that scale and
zero_point are strongly typed (as `numpy.array(single_value, dtype=dtype)`):
the scale dtype should always be the weight dtype, and the zero_point
dtype the quantized weight dtype. That convention is checked throughout
the quantization tools to make sure there is no loss of information; it
was adopted to avoid threading this information through many functions
as extra arguments.
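
A minimal sketch of the convention with hypothetical values (illustration only, not code from this PR):

```python
import numpy as np

# A float16 weight quantized to int8 under the new convention:
# the scale carries the weight dtype, the zero point the quantized dtype.
weight_dtype = np.float16
quant_dtype = np.int8

scale = np.array(0.0078125, dtype=weight_dtype)  # same dtype as the weights
zero_point = np.array(0, dtype=quant_dtype)      # same dtype as the quantized weights

assert scale.dtype == weight_dtype and zero_point.dtype == quant_dtype
```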

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
xadupre authored Jan 17, 2024
1 parent 4cfce3e commit aacc65c
Showing 3 changed files with 130 additions and 52 deletions.
97 changes: 60 additions & 37 deletions olive/passes/onnx/vitis_ai/quant_utils.py
@@ -235,34 +235,48 @@ def compute_scale_zp_pof2s(rmin, rmax, qmin, qmax, symmetric=False):
     if qmin > 0 or qmax < 0:
         raise ValueError(f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while qmin:{qmin}, qmax:{qmax}")

+    # hasattr(rmin, "dtype") is True for onnxruntime>=1.17
+    rmin_dtype = rmin.dtype if hasattr(rmin, "dtype") else np.float32
+
     # Adjust rmin and rmax such that 0 is included in the range. This is
     # required to make sure zero can be represented by the quantization data
     # type (i.e. to make sure qmin <= zero_point <= qmax)
-    rmin = min(rmin, 0)
-    rmax = max(rmax, 0)
+    rmin = np.minimum(rmin, np.array(0, dtype=rmin_dtype))
+    rmax = np.maximum(rmax, np.array(0, dtype=rmin_dtype))

     # Ensure that rmax-rmin is less than or equal to sys.float_info.max
     if rmin == float("-inf"):
-        rmin = -sys.float_info.max / 2
+        rmin = -np.finfo(rmin_dtype).max / 2
     if rmax == float("inf"):
-        rmax = sys.float_info.max / 2
+        rmax = np.finfo(rmin_dtype).max / 2

     if symmetric:
-        absmax = max(abs(rmin), abs(rmax))
+        absmax = np.maximum(np.abs(rmin), np.abs(rmax))
         rmin = -absmax
         rmax = +absmax

-    scale = (rmax - rmin) / float(qmax - qmin)
+    assert qmin <= qmax, f"qmin={qmin} > qmax={qmax}"
+    dr = np.array(rmax - rmin, dtype=np.float64)
+    dq = np.array(qmax, dtype=np.float64) - np.array(qmin, dtype=np.float64)
+    scale = np.array(dr / dq)

     pos = scale2pos(scale)
     pof2_scale = pos2scale(pos)

-    if pof2_scale < np.finfo(np.float32).tiny:
-        pof2_scale = 1.0
-        zero_point = 0
+    has_dtype = hasattr(qmin, "dtype")  # True for onnxruntime>=1.17
+
+    if pof2_scale < np.finfo(rmax.dtype).tiny:
+        pof2_scale = np.array(1.0, dtype=rmax.dtype)
+        zero_point = np.array(0, dtype=qmin.dtype) if has_dtype else 0
     else:
-        zero_point = round(qmin - rmin / pof2_scale)
+        pof2_scale = np.array(pof2_scale, dtype=rmin.dtype)
+        zero_point = (
+            np.array(np.round(qmin - rmin / pof2_scale), dtype=qmin.dtype)
+            if has_dtype
+            else int(round(qmin - rmin / pof2_scale))
+        )
     if symmetric:
-        zero_point = 0
+        zero_point = np.array(0, dtype=qmin.dtype) if has_dtype else 0
     return [zero_point, pof2_scale]


@@ -273,23 +287,29 @@ def quantize_zero_point(rmin, qmin, qmax, symmetric, scale):
     rmin = min(rmin, 0)

     if symmetric:
-        return 0
+        return np.array(0, dtype=qmin.dtype)

-    pof2_scale = scale
+    pof2_scale = np.array(scale, dtype=rmin.dtype)

-    if pof2_scale < np.finfo(np.float32).tiny:
-        pof2_scale = 1.0
-        zero_point = 0
+    has_dtype = hasattr(qmin, "dtype")  # True for onnxruntime>=1.17
+
+    if pof2_scale < np.finfo(rmin.dtype).tiny:
+        pof2_scale = np.array(1.0, dtype=rmin.dtype)
+        zero_point = np.array(0, dtype=qmin.dtype) if has_dtype else 0
     else:
-        zero_point = round(qmin - rmin / pof2_scale)
+        zero_point = (
+            np.array(round(qmin - rmin / pof2_scale), dtype=qmin.dtype)
+            if has_dtype
+            else int(round(qmin - rmin / pof2_scale))
+        )

     return zero_point


 def dequantize_data(data, scale, zero_point):
     data = data.astype(np.float32)
     deq_arr = (data - zero_point) * scale
-    return deq_arr.astype(np.float32)
+    return deq_arr.astype(scale.dtype)


 def quantize_data_pof2s(data, qType, symmetric, reduce_range=False, method=PowerOfTwoMethod.NonOverflow, pos_range=5):
@@ -314,23 +334,26 @@ def quantize_data_pof2s(data, qType, symmetric, reduce_range=False, method=Power
         - *S*: scale
         - *z*: zero point
     """
-
-    rmin = 0
-    rmax = 0
-    zero_point = 0
-    scale = 1.0
-    if isinstance(data, np.ndarray):
-        rmin = data.min()
-        rmax = data.max()
-
-    elif isinstance(data, list) and len(data):
-        rmin = min(data)
-        rmax = max(data)
+    assert data.dtype in {
+        np.float16,
+        np.float32,
+        np.dtype("float16"),
+        np.dtype("float32"),
+    }, f"Unexpected dtype {data.dtype!r}"
+    rmin = data.min()
+    rmax = data.max()
+    assert rmin.dtype in {
+        np.float16,
+        np.float32,
+        np.dtype("float16"),
+        np.dtype("float32"),
+    }, f"Unexpected dtype {rmin.dtype!r}"
+    assert rmin.dtype == rmax.dtype

     qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
     zero_point, scale = compute_scale_zp_pof2s(rmin, rmax, qmin, qmax, symmetric)

-    quantized_data = quantize_nparray(qType, np.asarray(data), scale, zero_point)
+    quantized_data = quantize_nparray(qType, data, scale, zero_point)

     if method == PowerOfTwoMethod.NonOverflow:
         return rmin, rmax, zero_point, scale, quantized_data
@@ -341,20 +364,20 @@ def quantize_data_pof2s(data, qType, symmetric, reduce_range=False, method=Power
         quantized_data_mse = quantized_data
         diff_min = float("inf")
         for i in range(pos_range):
-            new_scale = pos2scale(scale2pos(scale) + i)
-            rmin = min((qmin - zero_point) * new_scale, 0)
+            new_scale = np.array(pos2scale(scale2pos(scale) + i), dtype=rmin.dtype)
+            rmin = np.array(min((qmin - zero_point) * new_scale, 0), dtype=rmin.dtype)
             new_zero_point = quantize_zero_point(rmin, qmin, qmax, symmetric, new_scale)

-            new_quantized_data = quantize_nparray(qType, np.asarray(data), new_scale, new_zero_point)
-            diff = np.sum((dequantize_data(new_quantized_data, new_scale, new_zero_point) - np.asarray(data)) ** 2)
+            new_quantized_data = quantize_nparray(qType, data, new_scale, new_zero_point)
+            diff = np.sum((dequantize_data(new_quantized_data, new_scale, new_zero_point) - data) ** 2)
             if diff < diff_min:
                 diff_min = diff
                 scale_mse = new_scale
                 zp_mse = new_zero_point
                 quantized_data_mse = new_quantized_data

-        rmin_mse = (qmin - zp_mse) * scale_mse
-        rmax_mse = (qmax - zp_mse) * scale_mse
+        rmin_mse = np.array((qmin - zp_mse) * scale_mse, dtype=rmin.dtype)
+        rmax_mse = np.array((qmax - zp_mse) * scale_mse, dtype=rmax.dtype)

         return rmin_mse, rmax_mse, zp_mse, scale_mse, quantized_data_mse
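
The dtype-preserving flow above can be illustrated with a small, self-contained sketch. `scale2pos`/`pos2scale` are not part of this diff, so the power-of-two snapping below is an assumed stand-in for them:

```python
import numpy as np

# Assumed stand-ins: snap a scale to the nearest power of two.
def scale2pos(scale):
    return int(round(-np.log2(float(scale))))

def pos2scale(pos):
    return float(2.0**-pos)

rmin, rmax = np.float16(-0.9), np.float16(1.2)
qmin, qmax = np.int8(-128), np.int8(127)

# The raw scale is computed in float64, as in the patched code, then snapped.
scale = (np.float64(rmax) - np.float64(rmin)) / (np.float64(qmax) - np.float64(qmin))
pof2_scale = np.array(pos2scale(scale2pos(scale)), dtype=rmin.dtype)
zero_point = np.array(np.round(qmin - rmin / pof2_scale), dtype=qmin.dtype)

print(pof2_scale, pof2_scale.dtype)  # 0.0078125 float16
print(zero_point, zero_point.dtype)  # -13 int8
```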

57 changes: 42 additions & 15 deletions olive/passes/onnx/vitis_ai/quantizer.py
@@ -43,6 +43,7 @@
     get_annotate_output_name,
     get_qdq_to_remove,
     get_relu_name,
+    is_ort_version_below_1_17,
     quantize_data_pof2s,
     remove_nodes,
     vitis_quantize_data,
@@ -310,18 +311,21 @@ def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_wei
         # Update packed weight, zero point, and scale initializers
         weight_data = tensor_proto_to_array(weight)
         _, _, zero_point, scale, q_weight_data = quantize_data_pof2s(
-            weight_data.flatten().tolist(),
+            weight_data.flatten(),
             qType,
             self.is_weight_symmetric,
             self.reduce_range and reduce_range,
             method=PowerOfTwoMethod.NonOverflow,
         )
-        scale_initializer = onnx.helper.make_tensor(scale_name, onnx_proto.TensorProto.FLOAT, [], [scale])
-        zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], [zero_point])
-
+        scale_initializer = onnx.helper.make_tensor(
+            scale_name, onnx.helper.np_dtype_to_tensor_dtype(scale.dtype), [], [float(scale)]
+        )
+        zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], [int(zero_point)])
         self.model.initializer().extend([scale_initializer, zero_initializer])

         if not keep_float_weight:
-            q_weight_data = np.asarray(q_weight_data, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[qType]).reshape(
+            q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
                 weight.dims
             )
             q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
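
This hunk also replaces the deprecated `onnx.mapping.TENSOR_TYPE_TO_NP_TYPE` table with the `onnx.helper` conversion functions. A quick sketch of the round trip (assuming onnx>=1.13, where these helpers exist):

```python
import numpy as np
import onnx

# numpy dtype -> ONNX tensor dtype, and back.
tensor_dtype = onnx.helper.np_dtype_to_tensor_dtype(np.dtype("float16"))
assert tensor_dtype == onnx.TensorProto.FLOAT16
assert onnx.helper.tensor_dtype_to_np_dtype(tensor_dtype) == np.dtype("float16")
```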
@@ -462,7 +466,7 @@ def calculate_quantization_params(self):

             zero, scale = compute_scale_zp_pof2s(rmin, rmax, qmin, qmax, self.is_activation_symmetric)
             if is_ort_version_below_1_17():
-                quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale)
+                quantization_params[tensor_name] = QuantizationParams(zero_point=int(zero), scale=float(scale))
             else:
                 quantization_params[tensor_name] = QuantizationParams(
                     zero_point=zero, scale=scale, quant_type=self.activation_qType
@@ -545,22 +549,31 @@ def _is_tensor_quantizable(self, tensor_name):
         """
         Check if tensor can be quantized
         """
+        return self._tensor_quantizable_data_type(tensor_name) is not None
+
+    def _tensor_quantizable_data_type(self, tensor_name):
+        """
+        Return the tensor type if it is quantizable.
+        """
         weight = find_by_name(tensor_name, self.model.initializer())
         if weight is not None:
-            if weight.data_type == onnx_proto.TensorProto.FLOAT:
-                return True
+            if weight.data_type in {onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16}:
+                return weight.data_type
         elif tensor_name in self.value_infos.keys():
             vi = self.value_infos[tensor_name]
-            if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type == TensorProto.FLOAT:
-                return True
+            if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in {
+                TensorProto.FLOAT,
+                TensorProto.FLOAT16,
+            }:
+                return vi.type.tensor_type.elem_type
         else:
             logger.warning(
                 "failed to infer the type of tensor: {}. Skip to quantize it. Please check if it is expected.".format(
                     tensor_name
                 )
             )

-        return False
+        return None

     def __quantize_tensor(self, tensor_name, quant_sharing_param=None, tensor_type=QDQQuantTensorType.ACTIVATION):
         """
@@ -571,13 +584,27 @@ def __quantize_tensor(self, tensor_name, quant_sharing_param=None, tensor_type=Q
             quant_sharing_param: name of the tensor that provides quantization parameter
             tensor_type: QDQQuantTensorType default ACTIVATION
         """
-        if self._is_tensor_quantizable(tensor_name):
+        data_type = self._tensor_quantizable_data_type(tensor_name)
+        if data_type is not None:
             if quant_sharing_param:
-                self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
-                    tensor_type=tensor_type, quant_para_provider=quant_sharing_param
-                )
+                try:
+                    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
+                        tensor_type=tensor_type, quant_para_provider=quant_sharing_param, data_type=data_type
+                    )
+                except TypeError:
+                    # onnxruntime<1.17
+                    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
+                        tensor_type=tensor_type,
+                        quant_para_provider=quant_sharing_param,
+                    )
             elif tensor_name not in self.tensors_to_quantize:
-                self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type)
+                try:
+                    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
+                        tensor_type=tensor_type, data_type=data_type
+                    )
+                except TypeError:
+                    # onnxruntime<1.17
+                    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type)

     def quantize_activation_tensor(self, tensor_name, quant_sharing_param=None):
         """
28 changes: 28 additions & 0 deletions test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py
@@ -58,3 +58,31 @@ def test_vitis_ai_quantization_pass(tmp_path):
     assert quantized_model.model_path.endswith(".onnx")
     assert Path(quantized_model.model_path).exists()
     assert Path(quantized_model.model_path).is_file()
+
+
+def test_vitis_ai_quantization_pass_overflow(tmp_path):
+    # setup
+    input_model = get_onnx_model()
+    dummy_user_script = tmp_path / "dummy_user_script.py"
+    dummy_data: Path = tmp_path / "dummy_data"
+    with dummy_user_script.open("w") as f:
+        f.write(" ")
+    if not dummy_data.exists():
+        dummy_data.mkdir()
+
+    config = {
+        "user_script": str(dummy_user_script),
+        "data_dir": str(dummy_data),
+        "dataloader_func": dummy_calibration_reader,
+        "calibrate_method": "NonOverflow",
+    }
+    output_folder = str(tmp_path / "vitis_ai_quantized")
+
+    # create VitisAIQuantization pass
+    p = create_pass_from_dict(VitisAIQuantization, config, disable_search=True)
+    # execute
+    quantized_model = p.run(input_model, None, output_folder)
+    # assert
+    assert quantized_model.model_path.endswith(".onnx")
+    assert Path(quantized_model.model_path).exists()
+    assert Path(quantized_model.model_path).is_file()
