diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py
index fe37cf3c87880..3f5e4e660003f 100644
--- a/onnxruntime/python/tools/quantization/calibrate.py
+++ b/onnxruntime/python/tools/quantization/calibrate.py
@@ -164,13 +164,15 @@ def __init__(
         augmented_model_path="augmented_model.onnx",
         symmetric=False,
         use_external_data_format=False,
+        per_channel=False,
     ):
         """
         :param model_path: ONNX model to calibrate. It should be a model file path
         :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
         :param augmented_model_path: save augmented model to this path.
         :param symmetric: make range of tensor symmetric (central point is 0).
-        :param use_external_data_format: use external data format to store model which size is >= 2Gb
+        :param use_external_data_format: use external data format to store model which size is >= 2Gb.
+        :param per_channel: whether to compute ranges per each channel.
         """
         if isinstance(model_path, str):
             self.model = load_model_with_shape_infer(Path(model_path))
@@ -183,6 +185,7 @@ def __init__(
         self.augmented_model_path = augmented_model_path
         self.symmetric = symmetric
         self.use_external_data_format = use_external_data_format
+        self.per_channel = per_channel
 
         self.augment_model = None
         self.infer_session = None
@@ -274,6 +277,7 @@ def __init__(
         moving_average=False,
         averaging_constant=0.01,
         max_intermediate_outputs=None,
+        per_channel=False,
     ):
         """
         :param model_path: ONNX model to calibrate. It is a model path
@@ -284,6 +288,7 @@
         :param moving_average: compute the moving average of the minimum and maximum values instead of the global minimum and maximum.
         :param averaging_constant: constant smoothing factor to use when computing the moving average.
         :param max_intermediate_outputs: maximum number of intermediate outputs before an intermediate range is computed.
+        :param per_channel: whether to compute ranges per each channel.
         """
         super().__init__(
             model_path,
@@ -291,6 +296,7 @@
             augmented_model_path=augmented_model_path,
             symmetric=symmetric,
             use_external_data_format=use_external_data_format,
+            per_channel=per_channel,
         )
         self.intermediate_outputs = []
         self.calibrate_tensors_range = None
@@ -310,9 +316,15 @@ def augment_graph(self):
         """
         tensors, _ = self.select_tensors_to_calibrate(self.model)
         reshape_shape_name = str(uuid.uuid4())
-        reshape_shape = numpy_helper.from_array(np.array([1], dtype=np.int64), reshape_shape_name)
+        reshape_shape = numpy_helper.from_array(np.array([-1], dtype=np.int64), reshape_shape_name)
         self.model.graph.initializer.append(reshape_shape)
 
+        def get_op_version(op_type, model):
+            for opset_import in model.opset_import:
+                if onnx.defs.has(op_type, opset_import.domain):
+                    return opset_import.version
+            raise RuntimeError(f"Model does not contain a version for '{op_type}'.")
+
         def add_reduce_min_max(tensor_name, reduce_op_name):
             # When doing ReduceMax/ReduceMin, ORT can't reduce on dim with value of 0 if 'keepdims' is false.
             # To make the code simple, we always let keepdims to be 1.
@@ -332,7 +344,6 @@ def add_reduce_min_max(tensor_name, reduce_op_name):
                 name=intermediate_output,
             )
 
-            self.model.graph.node.extend([reduce_node, reshape_node])
             value_infos = {vi.name: vi for vi in self.model.graph.value_info}
             value_infos.update({o.name: o for o in self.model.graph.output})
             value_infos.update({i.name: i for i in self.model.graph.input})
@@ -343,7 +354,22 @@ def add_reduce_min_max(tensor_name, reduce_op_name):
                     f"Unable to guess tensor type for tensor {tensor_name!r}, "
                     f"running shape inference before quantization may resolve this issue."
                 )
-            self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, onnx_type, [1]))
+
+            # Include axes in reduce_op when per_channel, always keeping axis=1
+            if self.per_channel:
+                tensor_rank = len(value_infos[tensor_name].type.tensor_type.shape.dim)
+                reduced_axes = [0, *range(2, tensor_rank)]
+                # Depending on opset version, axes in ReduceMin/ReduceMax are in attribute or inputs
+                if get_op_version(reduce_op_name, self.model) < 18:
+                    reduce_node.attribute.append(helper.make_attribute("axes", reduced_axes))
+                else:
+                    reduce_axes_name = str(uuid.uuid4())
+                    reduce_axes = numpy_helper.from_array(np.array(reduced_axes, dtype=np.int64), reduce_axes_name)
+                    reduce_node.input.append(reduce_axes_name)
+                    self.model.graph.initializer.append(reduce_axes)
+
+            self.model.graph.node.extend([reduce_node, reshape_node])
+            self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, onnx_type, [None]))
 
         for tensor in tensors:
             add_reduce_min_max(tensor, "ReduceMin")
@@ -430,7 +456,7 @@ def compute_data(self) -> TensorsData:
             max_value_array = np.max(merged_added_output_dict[added_output_names[i + 1]], axis=0)
 
             if self.symmetric:
-                max_absolute_value = max(np.abs(min_value_array), np.abs(max_value_array))
+                max_absolute_value = np.max([np.abs(min_value_array), np.abs(max_value_array)], axis=0)
                 pairs.append(tuple([-max_absolute_value, max_absolute_value]))
             else:
                 pairs.append(tuple([min_value_array, max_value_array]))
@@ -1097,6 +1123,7 @@ def create_calibrator(
        moving_average = extra_options.get("moving_average", False)
        averaging_constant = extra_options.get("averaging_constant", 0.01)
        max_intermediate_outputs = extra_options.get("max_intermediate_outputs", None)
+       per_channel = extra_options.get("per_channel", False)
        calibrator = MinMaxCalibrater(
            model,
            op_types_to_calibrate,
@@ -1106,6 +1133,7 @@ def create_calibrator(
            moving_average=moving_average,
            averaging_constant=averaging_constant,
            max_intermediate_outputs=max_intermediate_outputs,
+           per_channel=per_channel,
        )
    elif calibrate_method == CalibrationMethod.Entropy:
        # default settings for entropy algorithm
diff --git a/onnxruntime/test/python/quantization/test_calibration.py b/onnxruntime/test/python/quantization/test_calibration.py
index 795447e8b79f6..b36ba45cdbb6c 100644
--- a/onnxruntime/test/python/quantization/test_calibration.py
+++ b/onnxruntime/test/python/quantization/test_calibration.py
@@ -275,7 +275,7 @@ def test_augment_graph_config_3(self):
         for output in added_outputs:
             self.assertTrue(output in augmented_model_outputs)
 
-    def construct_test_compute_data_model(self, test_model_path):
+    def construct_test_compute_data_model(self, test_model_path, opset_version=13):
         #    (input)
         #       |
         #      Relu
@@ -320,7 +320,7 @@ def construct_test_compute_data_model(self, test_model_path):
         graph.initializer.add().CopyFrom(b3)
         graph.initializer.add().CopyFrom(w5)
         graph.initializer.add().CopyFrom(b5)
-        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset_version)])
         onnx.save(model, test_model_path)
 
     def test_compute_data(self):
@@ -456,6 +456,42 @@ def test_augment_graph_with_zero_value_dimension(self):
         for output in added_outputs:
             self.assertTrue(output in augmented_model_outputs)
 
+    def test_compute_data_per_channel(self):
+        test_model_path = Path(self._tmp_model_dir.name).joinpath("./test_model_6.onnx")
+        self.construct_test_compute_data_model(test_model_path.as_posix(), opset_version=18)
+
+        augmented_model_path = Path(self._tmp_model_dir.name).joinpath("./augmented_test_model_6.onnx")
+        calibrater = create_calibrator(
+            test_model_path, augmented_model_path=augmented_model_path.as_posix(), extra_options={"per_channel": True}
+        )
+        data_reader = TestDataReader()
+        calibrater.collect_data(data_reader)
+        tensors_range = calibrater.compute_data()
+
+        sess_options = onnxruntime.SessionOptions()
+        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+        infer_session = onnxruntime.InferenceSession(
+            test_model_path.as_posix(),
+            sess_options=sess_options,
+            providers=["CPUExecutionProvider"],
+        )
+        data_reader.rewind()
+        rmin = np.array([np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], dtype=np.float32)[:, np.newaxis]
+        rmax = -1.0 * rmin
+        while True:
+            input = data_reader.get_next()
+            if not input:
+                break
+            output = np.asarray(infer_session.run(None, input)).reshape((6, 3, -1))
+            rmin = np.minimum(rmin, np.amin(output, axis=-1))
+            rmax = np.maximum(rmax, np.amax(output, axis=-1))
+
+        min_max_pairs = list(zip(rmin, rmax))
+        output_names = [infer_session.get_outputs()[i].name for i in range(len(infer_session.get_outputs()))]
+        output_min_max_dict = dict(zip(output_names, min_max_pairs))
+        for output_name in output_min_max_dict:
+            np.testing.assert_equal(output_min_max_dict[output_name], tensors_range[output_name].range_value)
+
 
 if __name__ == "__main__":
     unittest.main()
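
Usage sketch (reviewer note, not part of the patch): with this change, create_calibrator picks up "per_channel" from extra_options and forwards it to MinMaxCalibrater. The driver below is a minimal, hypothetical example; "model.onnx", the input name, tensor shape, and random batches are placeholders.

# Minimal sketch, not part of the diff: exercising the new per-channel MinMax
# calibration path. "model.onnx", the input name, shape, and random batches
# are illustrative placeholders.
import numpy as np

from onnxruntime.quantization.calibrate import (
    CalibrationDataReader,
    CalibrationMethod,
    create_calibrator,
)


class RandomDataReader(CalibrationDataReader):
    """Feeds a few random NCHW batches; swap in real calibration data."""

    def __init__(self, input_name="input", shape=(1, 3, 224, 224), num_batches=8):
        self._batches = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(num_batches)]
        )

    def get_next(self):
        # Return None when exhausted, as collect_data expects.
        return next(self._batches, None)


calibrator = create_calibrator(
    "model.onnx",
    augmented_model_path="augmented_model.onnx",
    calibrate_method=CalibrationMethod.MinMax,
    extra_options={"per_channel": True},  # new option introduced by this diff
)
calibrator.collect_data(RandomDataReader())
tensors_range = calibrator.compute_data()  # per-tensor (rmin, rmax), one value per channel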