Skip to content

Commit

Permalink
Add more validation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Nov 21, 2023
1 parent 1da0b78 commit 49b477a
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 14 deletions.
20 changes: 8 additions & 12 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,21 @@ def _get_and_check_tensor_quant_overrides(self):
if tensor_quant_overrides:
initializer_names = self.model.get_initializer_name_set()
value_info_names = set(self.value_infos.keys())
tensor_names = initializer_names.union(value_info_names)
keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"}
keys_unsupported_for_initializers = {"rmax", "rmin", "scale", "zero_point"}

for tensor_name, quant_overrides in tensor_quant_overrides.items():
if tensor_name not in tensor_names:
if tensor_name not in initializer_names and tensor_name not in value_info_names:
raise ValueError(f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model")

has_scale = "scale" in quant_overrides
has_zero_point = "zero_point" in quant_overrides
is_initializer = tensor_name in initializer_names

if is_initializer:
for key in keys_unsupported_for_initializers:
if key in quant_overrides:
raise ValueError(f"Tensor override option '{key}' is invalid for initializers")

if (has_scale and not has_zero_point) or (has_zero_point and not has_scale):
raise ValueError("Must provide both 'scale' and 'zero_point' if one of the overrides is provided")
Expand Down Expand Up @@ -1024,11 +1030,6 @@ def quantize_initializer(self, weight, quant_type, reduce_range=False, keep_floa
symmetric = quant_overrides.get("symmetric", self.is_weight_symmetric)
reduce_range = quant_overrides.get("reduce_range", self.reduce_range and reduce_range)

unsupported_overrides = {"rmax", "rmin", "scale", "zero_point"}
for key in unsupported_overrides:
if key in quant_overrides:
raise ValueError(f"Tensor quantization override '{key}' is not supported for initializers")

q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX
zp_name = weight.name + "_zero_point"
scale_name = weight.name + "_scale"
Expand Down Expand Up @@ -1130,11 +1131,6 @@ def quantize_weight_per_channel(
)
reduce_range = quant_overrides.get("reduce_range", self.reduce_range and reduce_range)

unsupported_overrides = {"rmax", "rmin", "scale", "zero_point"}
for key in unsupported_overrides:
if key in quant_overrides:
raise ValueError(f"Tensor quantization override '{key}' is not supported for per-channel weights")

weights = tensor_proto_to_array(initializer)
channel_count = weights.shape[channel_axis]
rmin_list = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def test_qdq_overrides1(self):
new_bias_zp, new_bias_sc = compute_scale_zp(bias_rmin, bias_rmax, bias_qmin, bias_qmax, symmetric=True)
self.assertEqual(bias_zp.int32_data[0], new_bias_zp)
self.assertEqual(bias_sc.float_data[0], np.float32(new_bias_sc))
print(new_bias_zp, new_bias_sc)

def test_qdq_overrides2(self):
"""
Expand Down Expand Up @@ -236,7 +235,7 @@ def test_override_validation_scale_missing_zp(self):

def test_override_validation_bad_combination(self):
"""
Test that specifying a scale/zero_point with rmax should fail.
Test that specifying a scale/zero_point with rmax/rmin/symmetric/reduce_range should fail.
"""
with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
Expand All @@ -246,6 +245,66 @@ def test_override_validation_bad_combination(self):

self.assertTrue("option 'rmax' is invalid with 'scale' and 'zero_point'" in str(context.exception))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.

with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "rmin": 10.0}},
)

self.assertTrue("option 'rmin' is invalid with 'scale' and 'zero_point'" in str(context.exception))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.

with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "symmetric": True}},
)

self.assertTrue("option 'symmetric' is invalid with 'scale' and 'zero_point'" in str(context.exception))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.

with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "reduce_range": True}},
)

self.assertTrue("option 'reduce_range' is invalid with 'scale' and 'zero_point'" in str(context.exception))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.

def test_override_invalid_for_initializer(self):
"""
Test that specifying a scale, zero_point, rmin, rmax for initializers should fail.
"""
with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"WGT": {"scale": 0.0}},
)

self.assertTrue("option 'scale' is invalid for initializers" in str(context.exception))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.

with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"WGT": {"zero_point": 0}},
)

self.assertTrue("option 'zero_point' is invalid for initializers" in str(context.exception))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.

with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"WGT": {"rmin": 0.0}},
)

self.assertTrue("option 'rmin' is invalid for initializers" in str(context.exception))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.

with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"WGT": {"rmax": 0.0}},
)

self.assertTrue("option 'rmax' is invalid for initializers" in str(context.exception))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a in b) cannot provide an informative message. Using assertIn(a, b) instead will give more informative messages.


if __name__ == "__main__":
unittest.main()

0 comments on commit 49b477a

Please sign in to comment.