diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 1411f251eb4b3..a8b2a406d0b5e 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -9,7 +9,6 @@ import numpy as np import onnx -from onnx import TensorProto, helper, numpy_helper from onnxruntime import quantization from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType @@ -23,9 +22,9 @@ def setUp(self): self.weight = np.array([[[-1.0, -2.0], [1.0, 2.0]], [[-0.5, -1.5], [0.5, 1.5]]], dtype=np.float32) self.bias = np.array([0.0, 1.0], dtype=np.float32) - self.default_act_qtype = TensorProto.UINT8 - self.default_wgt_qtype = TensorProto.UINT8 - self.default_bias_qtype = TensorProto.INT32 + self.default_act_qtype = onnx.TensorProto.UINT8 + self.default_wgt_qtype = onnx.TensorProto.UINT8 + self.default_bias_qtype = onnx.TensorProto.INT32 self.default_zp_scales = { "INP": (0, np.float32(0.0235294122248888)), @@ -44,16 +43,18 @@ def perform_qdq_quantization(self, output_model_name, tensor_quant_overrides=Non # | # (output) - inp = helper.make_tensor_value_info("INP", TensorProto.FLOAT, self.activations[0].shape) + inp = onnx.helper.make_tensor_value_info("INP", onnx.TensorProto.FLOAT, self.activations[0].shape) sigmoid_node = onnx.helper.make_node("Sigmoid", ["INP"], ["SIG_OUT"]) - out = helper.make_tensor_value_info("OUT", TensorProto.FLOAT, [None, None, None]) - wgt_init = numpy_helper.from_array(self.weight, "WGT") - bias_init = numpy_helper.from_array(self.bias, "BIAS") + out = onnx.helper.make_tensor_value_info("OUT", onnx.TensorProto.FLOAT, [None, None, None]) + wgt_init = onnx.numpy_helper.from_array(self.weight, "WGT") + bias_init = onnx.numpy_helper.from_array(self.bias, "BIAS") conv_node = onnx.helper.make_node("Conv", ["SIG_OUT", "WGT", "BIAS"], ["OUT"]) - graph = helper.make_graph([sigmoid_node, conv_node], "test", [inp], [out], initializer=[wgt_init, bias_init]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) + graph = onnx.helper.make_graph( + [sigmoid_node, conv_node], "test", [inp], [out], initializer=[wgt_init, bias_init] + ) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 11)]) onnx.save(model, "model.onnx") # Quantize model @@ -219,7 +220,7 @@ def test_override_validation_nonexisting_tensor(self): tensor_quant_overrides={"NON_EXISTING": {"rmin": 0.0, "rmax": 0.5}}, ) - self.assertTrue("is not present in the model" in str(context.exception)) + self.assertIn("is not present in the model", str(context.exception)) def test_override_validation_scale_missing_zp(self): """ @@ -231,7 +232,7 @@ def test_override_validation_scale_missing_zp(self): tensor_quant_overrides={"SIG_OUT": {"scale": 0.0}}, ) - self.assertTrue("Must provide both 'scale' and 'zero_point'" in str(context.exception)) + self.assertIn("Must provide both 'scale' and 'zero_point'", str(context.exception)) def test_override_validation_bad_combination(self): """ @@ -243,7 +244,7 @@ def test_override_validation_bad_combination(self): tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "rmax": 10.0}}, ) - self.assertTrue("option 'rmax' is invalid with 'scale' and 'zero_point'" in str(context.exception)) + self.assertIn("option 'rmax' is invalid with 'scale' and 'zero_point'", str(context.exception)) with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -251,7 +252,7 @@ def test_override_validation_bad_combination(self): tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "rmin": 10.0}}, ) - self.assertTrue("option 'rmin' is invalid with 'scale' and 'zero_point'" in str(context.exception)) + self.assertIn("option 'rmin' is invalid with 'scale' and 'zero_point'", str(context.exception)) with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -259,7 +260,7 @@ def test_override_validation_bad_combination(self): tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "symmetric": True}}, ) - self.assertTrue("option 'symmetric' is invalid with 'scale' and 'zero_point'" in str(context.exception)) + self.assertIn("option 'symmetric' is invalid with 'scale' and 'zero_point'", str(context.exception)) with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -267,7 +268,7 @@ def test_override_validation_bad_combination(self): tensor_quant_overrides={"SIG_OUT": {"scale": 0.0, "zero_point": 0, "reduce_range": True}}, ) - self.assertTrue("option 'reduce_range' is invalid with 'scale' and 'zero_point'" in str(context.exception)) + self.assertIn("option 'reduce_range' is invalid with 'scale' and 'zero_point'", str(context.exception)) def test_override_invalid_for_initializer(self): """ @@ -279,7 +280,7 @@ def test_override_invalid_for_initializer(self): tensor_quant_overrides={"WGT": {"scale": 0.0}}, ) - self.assertTrue("option 'scale' is invalid for initializers" in str(context.exception)) + self.assertIn("option 'scale' is invalid for initializers", str(context.exception)) with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -287,7 +288,7 @@ def test_override_invalid_for_initializer(self): tensor_quant_overrides={"WGT": {"zero_point": 0}}, ) - self.assertTrue("option 'zero_point' is invalid for initializers" in str(context.exception)) + self.assertIn("option 'zero_point' is invalid for initializers", str(context.exception)) with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -295,7 +296,7 @@ def test_override_invalid_for_initializer(self): tensor_quant_overrides={"WGT": {"rmin": 0.0}}, ) - self.assertTrue("option 'rmin' is invalid for initializers" in str(context.exception)) + self.assertIn("option 'rmin' is invalid for initializers", str(context.exception)) with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( @@ -303,7 +304,7 @@ def test_override_invalid_for_initializer(self): tensor_quant_overrides={"WGT": {"rmax": 0.0}}, ) - self.assertTrue("option 'rmax' is invalid for initializers" in str(context.exception)) + self.assertIn("option 'rmax' is invalid for initializers", str(context.exception)) if __name__ == "__main__":