diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index ebbd1643d43d1..5b9b70b816041 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -175,15 +175,21 @@ def _get_and_check_tensor_quant_overrides(self): if tensor_quant_overrides: initializer_names = self.model.get_initializer_name_set() value_info_names = set(self.value_infos.keys()) - tensor_names = initializer_names.union(value_info_names) keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} + keys_unsupported_for_initializers = {"rmax", "rmin", "scale", "zero_point"} for tensor_name, quant_overrides in tensor_quant_overrides.items(): - if tensor_name not in tensor_names: + if tensor_name not in initializer_names and tensor_name not in value_info_names: raise ValueError(f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model") has_scale = "scale" in quant_overrides has_zero_point = "zero_point" in quant_overrides + is_initializer = tensor_name in initializer_names + + if is_initializer: + for key in keys_unsupported_for_initializers: + if key in quant_overrides: + raise ValueError(f"Tensor override option '{key}' is invalid for initializers") if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): raise ValueError("Must provide both 'scale' and 'zero_point' if one of the overrides is provided") @@ -1024,11 +1030,6 @@ def quantize_initializer(self, weight, quant_type, reduce_range=False, keep_floa symmetric = quant_overrides.get("symmetric", self.is_weight_symmetric) reduce_range = quant_overrides.get("reduce_range", self.reduce_range and reduce_range) - unsupported_overrides = {"rmax", "rmin", "scale", "zero_point"} - for key in unsupported_overrides: - if key in quant_overrides: - raise ValueError(f"Tensor quantization override '{key}' is not supported for initializers") - q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight.name + "_zero_point" scale_name = weight.name + "_scale" @@ -1130,11 +1131,6 @@ def quantize_weight_per_channel( ) reduce_range = quant_overrides.get("reduce_range", self.reduce_range and reduce_range) - unsupported_overrides = {"rmax", "rmin", "scale", "zero_point"} - for key in unsupported_overrides: - if key in quant_overrides: - raise ValueError(f"Tensor quantization override '{key}' is not supported for per-channel weights") - weights = tensor_proto_to_array(initializer) channel_count = weights.shape[channel_axis] rmin_list = [] diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 5640eb2d0d519..1411f251eb4b3 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -185,7 +185,6 @@ def test_qdq_overrides1(self): new_bias_zp, new_bias_sc = compute_scale_zp(bias_rmin, bias_rmax, bias_qmin, bias_qmax, symmetric=True) self.assertEqual(bias_zp.int32_data[0], new_bias_zp) self.assertEqual(bias_sc.float_data[0], np.float32(new_bias_sc)) - print(new_bias_zp, new_bias_sc) def test_qdq_overrides2(self): """ @@ -236,7 +235,7 @@ def test_override_validation_scale_missing_zp(self): def test_override_validation_bad_combination(self): """ - Test that specifying a scale/zero_point with rmax should fail. + Test that specifying a scale/zero_point with rmax/rmin/symmetric/reduce_range should fail. """ with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -246,6 +245,66 @@ def test_override_validation_bad_combination(self): self.assertTrue("option 'rmax' is invalid with 'scale' and 'zero_point'" in str(context.exception)) + with self.assertRaises(ValueError) as context: + self.perform_qdq_quantization( + "model_validation.onnx", + tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "rmin": 10.0}}, + ) + + self.assertTrue("option 'rmin' is invalid with 'scale' and 'zero_point'" in str(context.exception)) + + with self.assertRaises(ValueError) as context: + self.perform_qdq_quantization( + "model_validation.onnx", + tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "symmetric": True}}, + ) + + self.assertTrue("option 'symmetric' is invalid with 'scale' and 'zero_point'" in str(context.exception)) + + with self.assertRaises(ValueError) as context: + self.perform_qdq_quantization( + "model_validation.onnx", + tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "reduce_range": True}}, + ) + + self.assertTrue("option 'reduce_range' is invalid with 'scale' and 'zero_point'" in str(context.exception)) + + def test_override_invalid_for_initializer(self): + """ + Test that specifying a scale, zero_point, rmin, rmax for initializers should fail. + """ + with self.assertRaises(ValueError) as context: + self.perform_qdq_quantization( + "model_validation.onnx", + tensor_quant_overrides={"WGT": {"scale": 0.0}}, + ) + + self.assertTrue("option 'scale' is invalid for initializers" in str(context.exception)) + + with self.assertRaises(ValueError) as context: + self.perform_qdq_quantization( + "model_validation.onnx", + tensor_quant_overrides={"WGT": {"zero_point": 0}}, + ) + + self.assertTrue("option 'zero_point' is invalid for initializers" in str(context.exception)) + + with self.assertRaises(ValueError) as context: + self.perform_qdq_quantization( + "model_validation.onnx", + tensor_quant_overrides={"WGT": {"rmin": 0.0}}, + ) + + self.assertTrue("option 'rmin' is invalid for initializers" in str(context.exception)) + + with self.assertRaises(ValueError) as context: + self.perform_qdq_quantization( + "model_validation.onnx", + tensor_quant_overrides={"WGT": {"rmax": 0.0}}, + ) + + self.assertTrue("option 'rmax' is invalid for initializers" in str(context.exception)) + if __name__ == "__main__": unittest.main()