diff --git a/olive/passes/onnx/vitis_ai/quant_utils.py b/olive/passes/onnx/vitis_ai/quant_utils.py
index 1058e96a5..79c1c24f2 100644
--- a/olive/passes/onnx/vitis_ai/quant_utils.py
+++ b/olive/passes/onnx/vitis_ai/quant_utils.py
@@ -235,34 +235,48 @@ def compute_scale_zp_pof2s(rmin, rmax, qmin, qmax, symmetric=False):
     if qmin > 0 or qmax < 0:
         raise ValueError(f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while qmin:{qmin}, qmmax:{qmax}")
 
+    # hasattr(rmin, "dtype") is True for onnxruntime>=1.17
+    rmin_dtype = rmin.dtype if hasattr(rmin, "dtype") else np.float32
+
     # Adjust rmin and rmax such that 0 is included in the range. This is
     # required to make sure zero can be represented by the quantization data
     # type (i.e. to make sure qmin <= zero_point <= qmax)
-    rmin = min(rmin, 0)
-    rmax = max(rmax, 0)
+    rmin = np.minimum(rmin, np.array(0, dtype=rmin_dtype))
+    rmax = np.maximum(rmax, np.array(0, dtype=rmin_dtype))
 
     # Ensure that rmax-rmin is less than or equal to sys.float_info.max
     if rmin == float("-inf"):
-        rmin = -sys.float_info.max / 2
+        rmin = -np.finfo(rmin_dtype).max / 2
     if rmax == float("inf"):
-        rmax = sys.float_info.max / 2
+        rmax = np.finfo(rmin_dtype).max / 2
 
     if symmetric:
-        absmax = max(abs(rmin), abs(rmax))
+        absmax = np.maximum(np.abs(rmin), np.abs(rmax))
         rmin = -absmax
         rmax = +absmax
 
-    scale = (rmax - rmin) / float(qmax - qmin)
+    assert qmin <= qmax, f"qmin={qmin} > qmax={qmax}"
+    dr = np.array(rmax - rmin, dtype=np.float64)
+    dq = np.array(qmax, dtype=np.float64) - np.array(qmin, dtype=np.float64)
+    scale = np.array(dr / dq)
+
     pos = scale2pos(scale)
     pof2_scale = pos2scale(pos)
 
-    if pof2_scale < np.finfo(np.float32).tiny:
-        pof2_scale = 1.0
-        zero_point = 0
+    has_dtype = hasattr(qmin, "dtype")  # True for onnxruntime>=1.17
+
+    if pof2_scale < np.finfo(rmax.dtype).tiny:
+        pof2_scale = np.array(1.0, dtype=rmax.dtype)
+        zero_point = np.array(0, dtype=qmin.dtype) if has_dtype else 0
     else:
-        zero_point = round(qmin - rmin / pof2_scale)
+        pof2_scale = np.array(pof2_scale, dtype=rmin.dtype)
+        zero_point = (
+            np.array(np.round(qmin - rmin / pof2_scale), dtype=qmin.dtype)
+            if has_dtype
+            else int(round(qmin - rmin / pof2_scale))
+        )
 
     if symmetric:
-        zero_point = 0
+        zero_point = np.array(0, dtype=qmin.dtype) if has_dtype else 0
 
     return [zero_point, pof2_scale]
 
@@ -273,15 +287,21 @@ def quantize_zero_point(rmin, qmin, qmax, symmetric, scale):
     rmin = min(rmin, 0)
 
     if symmetric:
-        return 0
+        return np.array(0, dtype=qmin.dtype) if hasattr(qmin, "dtype") else 0
 
-    pof2_scale = scale
+    pof2_scale = np.array(scale, dtype=rmin.dtype)
 
-    if pof2_scale < np.finfo(np.float32).tiny:
-        pof2_scale = 1.0
-        zero_point = 0
+    has_dtype = hasattr(qmin, "dtype")  # True for onnxruntime>=1.17
+
+    if pof2_scale < np.finfo(rmin.dtype).tiny:
+        pof2_scale = np.array(1.0, dtype=rmin.dtype)
+        zero_point = np.array(0, dtype=qmin.dtype) if has_dtype else 0
     else:
-        zero_point = round(qmin - rmin / pof2_scale)
+        zero_point = (
+            np.array(round(qmin - rmin / pof2_scale), dtype=qmin.dtype)
+            if has_dtype
+            else int(round(qmin - rmin / pof2_scale))
+        )
 
     return zero_point
 
@@ -289,7 +309,7 @@ def quantize_zero_point(rmin, qmin, qmax, symmetric, scale):
 def dequantize_data(data, scale, zero_point):
     data = data.astype(np.float32)
     deq_arr = (data - zero_point) * scale
-    return deq_arr.astype(np.float32)
+    return deq_arr.astype(scale.dtype)
 
 
 def quantize_data_pof2s(data, qType, symmetric, reduce_range=False, method=PowerOfTwoMethod.NonOverflow, pos_range=5):
@@ -314,23 +334,26 @@ def quantize_data_pof2s(data, qType, symmetric, reduce_range=False, method=Power
         - *S*: scale
         - *z*: zero point
     """
-
-    rmin = 0
-    rmax = 0
-    zero_point = 0
-    scale = 1.0
-    if isinstance(data, np.ndarray):
-        rmin = data.min()
-        rmax = data.max()
-
-    elif isinstance(data, list) and len(data):
-        rmin = min(data)
-        rmax = max(data)
+    assert data.dtype in {
+        np.float16,
+        np.float32,
+        np.dtype("float16"),
+        np.dtype("float32"),
+    }, f"Unexpected dtype {data.dtype!r}"
+    rmin = data.min()
+    rmax = data.max()
+    assert rmin.dtype in {
+        np.float16,
+        np.float32,
+        np.dtype("float16"),
+        np.dtype("float32"),
+    }, f"Unexpected dtype {rmin.dtype!r}"
+    assert rmin.dtype == rmax.dtype
 
     qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
     zero_point, scale = compute_scale_zp_pof2s(rmin, rmax, qmin, qmax, symmetric)
 
-    quantized_data = quantize_nparray(qType, np.asarray(data), scale, zero_point)
+    quantized_data = quantize_nparray(qType, data, scale, zero_point)
 
     if method == PowerOfTwoMethod.NonOverflow:
         return rmin, rmax, zero_point, scale, quantized_data
@@ -341,20 +364,20 @@ def quantize_data_pof2s(data, qType, symmetric, reduce_range=False, method=Power
     quantized_data_mse = quantized_data
     diff_min = float("inf")
     for i in range(pos_range):
-        new_scale = pos2scale(scale2pos(scale) + i)
-        rmin = min((qmin - zero_point) * new_scale, 0)
+        new_scale = np.array(pos2scale(scale2pos(scale) + i), dtype=rmin.dtype)
+        rmin = np.array(min((qmin - zero_point) * new_scale, 0), dtype=rmin.dtype)
         new_zero_point = quantize_zero_point(rmin, qmin, qmax, symmetric, new_scale)
-        new_quantized_data = quantize_nparray(qType, np.asarray(data), new_scale, new_zero_point)
-        diff = np.sum((dequantize_data(new_quantized_data, new_scale, new_zero_point) - np.asarray(data)) ** 2)
+        new_quantized_data = quantize_nparray(qType, data, new_scale, new_zero_point)
+        diff = np.sum((dequantize_data(new_quantized_data, new_scale, new_zero_point) - data) ** 2)
         if diff < diff_min:
             diff_min = diff
             scale_mse = new_scale
             zp_mse = new_zero_point
             quantized_data_mse = new_quantized_data
 
-    rmin_mse = (qmin - zp_mse) * scale_mse
-    rmax_mse = (qmax - zp_mse) * scale_mse
+    rmin_mse = np.array((qmin - zp_mse) * scale_mse, dtype=rmin.dtype)
+    rmax_mse = np.array((qmax - zp_mse) * scale_mse, dtype=rmax.dtype)
 
     return rmin_mse, rmax_mse, zp_mse, scale_mse, quantized_data_mse
diff --git a/olive/passes/onnx/vitis_ai/quantizer.py b/olive/passes/onnx/vitis_ai/quantizer.py
index 35894c11f..7fb8f6dd5 100644
--- a/olive/passes/onnx/vitis_ai/quantizer.py
+++ b/olive/passes/onnx/vitis_ai/quantizer.py
@@ -43,6 +43,7 @@
     get_annotate_output_name,
     get_qdq_to_remove,
     get_relu_name,
+    is_ort_version_below_1_17,
     quantize_data_pof2s,
     remove_nodes,
     vitis_quantize_data,
@@ -310,18 +311,21 @@ def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_wei
         # Update packed weight, zero point, and scale initializers
         weight_data = tensor_proto_to_array(weight)
         _, _, zero_point, scale, q_weight_data = quantize_data_pof2s(
-            weight_data.flatten().tolist(),
+            weight_data.flatten(),
             qType,
             self.is_weight_symmetric,
             self.reduce_range and reduce_range,
             method=PowerOfTwoMethod.NonOverflow,
         )
-        scale_initializer = onnx.helper.make_tensor(scale_name, onnx_proto.TensorProto.FLOAT, [], [scale])
-        zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], [zero_point])
+
+        scale_initializer = onnx.helper.make_tensor(
+            scale_name, onnx.helper.np_dtype_to_tensor_dtype(scale.dtype), [], [float(scale)]
+        )
+        zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], [int(zero_point)])
         self.model.initializer().extend([scale_initializer, zero_initializer])
 
         if not keep_float_weight:
-            q_weight_data = np.asarray(q_weight_data, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[qType]).reshape(
+            q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
                 weight.dims
             )
             q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
@@ -462,7 +466,7 @@ def calculate_quantization_params(self):
             zero, scale = compute_scale_zp_pof2s(rmin, rmax, qmin, qmax, self.is_activation_symmetric)
 
             if is_ort_version_below_1_17():
-                quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale)
+                quantization_params[tensor_name] = QuantizationParams(zero_point=int(zero), scale=float(scale))
             else:
                 quantization_params[tensor_name] = QuantizationParams(
                     zero_point=zero, scale=scale, quant_type=self.activation_qType
@@ -545,14 +549,23 @@ def _is_tensor_quantizable(self, tensor_name):
         """
         Check if tensor can be quantized
        """
+        return self._tensor_quantizable_data_type(tensor_name) is not None
+
+    def _tensor_quantizable_data_type(self, tensor_name):
+        """
+        Return the tensor type if it is quantizable.
+        """
         weight = find_by_name(tensor_name, self.model.initializer())
         if weight is not None:
-            if weight.data_type == onnx_proto.TensorProto.FLOAT:
-                return True
+            if weight.data_type in {onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16}:
+                return weight.data_type
         elif tensor_name in self.value_infos.keys():
             vi = self.value_infos[tensor_name]
-            if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type == TensorProto.FLOAT:
-                return True
+            if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in {
+                TensorProto.FLOAT,
+                TensorProto.FLOAT16,
+            }:
+                return vi.type.tensor_type.elem_type
         else:
             logger.warning(
                 "failed to infer the type of tensor: {}. Skip to quantize it. Please check if it is expected.".format(
@@ -560,7 +573,7 @@ def _is_tensor_quantizable(self, tensor_name):
                 )
             )
 
-        return False
+        return None
 
     def __quantize_tensor(self, tensor_name, quant_sharing_param=None, tensor_type=QDQQuantTensorType.ACTIVATION):
         """
@@ -571,13 +584,27 @@ def __quantize_tensor(self, tensor_name, quant_sharing_param=None, tensor_type=Q
             quant_sharing_param: name of the tensor that provides quantization parameter
             tensor_type: QDQQuantTensorType default ACTIVATION
         """
-        if self._is_tensor_quantizable(tensor_name):
+        data_type = self._tensor_quantizable_data_type(tensor_name)
+        if data_type is not None:
             if quant_sharing_param:
-                self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
-                    tensor_type=tensor_type, quant_para_provider=quant_sharing_param
-                )
+                try:
+                    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
+                        tensor_type=tensor_type, quant_para_provider=quant_sharing_param, data_type=data_type
+                    )
+                except TypeError:
+                    # onnxruntime<1.17 does not accept the data_type argument
+                    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
+                        tensor_type=tensor_type,
+                        quant_para_provider=quant_sharing_param,
+                    )
             elif tensor_name not in self.tensors_to_quantize:
-                self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type)
+                try:
+                    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
+                        tensor_type=tensor_type, data_type=data_type
+                    )
+                except TypeError:
+                    # onnxruntime<1.17
+                    self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type)
 
     def quantize_activation_tensor(self, tensor_name, quant_sharing_param=None):
         """
diff --git a/test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py b/test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py
index f3af7368d..0f7b5aa17 100644
--- a/test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py
+++ b/test/unit_test/passes/vitis_ai/test_vitis_ai_quantization.py
@@ -58,3 +58,31 @@ def test_vitis_ai_quantization_pass(tmp_path):
     assert quantized_model.model_path.endswith(".onnx")
     assert Path(quantized_model.model_path).exists()
     assert Path(quantized_model.model_path).is_file()
+
+
+def test_vitis_ai_quantization_pass_non_overflow(tmp_path):
+    # setup
+    input_model = get_onnx_model()
+    dummy_user_script = tmp_path / "dummy_user_script.py"
+    dummy_data: Path = tmp_path / "dummy_data"
+    with dummy_user_script.open("w") as f:
+        f.write(" ")
+    if not dummy_data.exists():
+        dummy_data.mkdir()
+
+    config = {
+        "user_script": str(dummy_user_script),
+        "data_dir": str(dummy_data),
+        "dataloader_func": dummy_calibration_reader,
+        "calibrate_method": "NonOverflow",
+    }
+    output_folder = str(tmp_path / "vitis_ai_quantized")
+
+    # create VitisAIQuantization pass
+    p = create_pass_from_dict(VitisAIQuantization, config, disable_search=True)
+    # execute
+    quantized_model = p.run(input_model, None, output_folder)
+    # assert
+    assert quantized_model.model_path.endswith(".onnx")
+    assert Path(quantized_model.model_path).exists()
+    assert Path(quantized_model.model_path).is_file()
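
Note: a minimal sketch of the power-of-two rounding that this patch makes dtype-aware. scale2pos and pos2scale are re-implemented below purely for illustration, assuming quant_utils keeps the convention scale == 2.0 ** -pos; the sample values are not taken from the patch.

```python
import numpy as np


def scale2pos(scale):
    # Exponent of the power of two nearest to `scale` (illustrative stand-in).
    return int(np.rint(-np.log2(scale)))


def pos2scale(pos):
    return float(np.power(2.0, -pos))


rmin = np.array(-0.7, dtype=np.float16)
rmax = np.array(1.3, dtype=np.float16)
qmin = np.array(-128, dtype=np.int8)
qmax = np.array(127, dtype=np.int8)

# Compute the raw scale in float64, as the patch does, so that rmax - rmin
# and qmax - qmin cannot overflow the narrow input dtypes.
scale = np.array(rmax - rmin, dtype=np.float64) / (np.float64(qmax) - np.float64(qmin))
pof2_scale = np.array(pos2scale(scale2pos(scale)), dtype=rmin.dtype)  # snap to 2**-pos
zero_point = np.array(np.round(np.float64(qmin) - rmin / pof2_scale), dtype=qmin.dtype)
print(pof2_scale, zero_point)  # 0.007812 (= 2**-7) and -38
```

Snapping the scale to a power of two keeps it exactly representable in float16, which is why the patch can carry the calibrated tensor's dtype through scale and zero_point instead of forcing float32.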
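A usage sketch of the new float16 path (an assumption-laden example, not part of the patch: it presumes the import below resolves as in this repo and onnxruntime>=1.17, where get_qmin_qmax_for_qType returns numpy scalars that carry a dtype):

```python
import numpy as np
from onnx import TensorProto

from olive.passes.onnx.vitis_ai.quant_utils import quantize_data_pof2s

# float16 ndarrays are now passed directly; no .tolist() round-trip to a Python list.
data = np.linspace(-1.0, 1.0, 16, dtype=np.float16)
rmin, rmax, zero_point, scale, qdata = quantize_data_pof2s(data, TensorProto.INT8, symmetric=True)
assert scale.dtype == data.dtype  # the scale keeps the calibrated tensor's dtype
```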