From a90f9f5f4dba15ca049ff6b542abbdf595ef816f Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Thu, 21 Dec 2023 13:47:49 +0100 Subject: [PATCH] fix missing dtype --- .../tools/quantization/operators/softmax.py | 6 +- .../test_tensor_quant_overrides_option.py | 62 ++++++++++++++----- 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/onnxruntime/python/tools/quantization/operators/softmax.py b/onnxruntime/python/tools/quantization/operators/softmax.py index fe9b855a02fb4..61a69ab3649dd 100644 --- a/onnxruntime/python/tools/quantization/operators/softmax.py +++ b/onnxruntime/python/tools/quantization/operators/softmax.py @@ -98,8 +98,10 @@ def quantize(self): out_zero_point, out_scale = quant_overrides["zero_point"], quant_overrides["scale"] else: # Unless overridden by the user, force Softmax to range from 0.0 to 1.0 - rmin = quant_overrides.get("rmin", 0.0) - rmax = quant_overrides.get("rmax", 1.0) + qparams = self.quantizer.quantization_params[output_name] + dtype = qparams.data["scale"].dtype + rmin = quant_overrides.get("rmin", np.array(0, dtype=dtype)) + rmax = quant_overrides.get("rmax", np.array(1, dtype=dtype)) symmetric = quant_overrides.get("symmetric", self.quantizer.is_activation_symmetric) reduce_range = quant_overrides.get("reduce_range", False) qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 770f292286982..100ae7d8a22d1 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -215,7 +215,7 @@ def test_qdq_overrides1(self): inp_zp, inp_sc, sig_out_zp, sig_out_sc, wgt_zp, wgt_sc, bias_zp, bias_sc, _, _ = self.perform_qdq_quantization( "model_quant_overrides1.onnx", tensor_quant_overrides={ - "SIG_OUT": 
[{"scale": 1.0, "zero_point": 127}], + "SIG_OUT": [{"scale": np.array(1.0, dtype=np.float32), "zero_point": np.array(127, dtype=np.uint8)}], "WGT": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": True}], "BIAS": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": True}], }, @@ -276,7 +276,7 @@ def test_qdq_overrides3(self): """ Test overriding rmin and rmax for Conv weight """ - wgt_rmin, wgt_rmax = 0.0, 1.0 + wgt_rmin, wgt_rmax = np.array(0.0, dtype=np.float32), np.array(1.0, dtype=np.float32) _, _, _, _, wgt_zp, wgt_sc, _, _, _, _ = self.perform_qdq_quantization( "model_quant_overrides3.onnx", tensor_quant_overrides={ @@ -298,7 +298,7 @@ def test_qdq_overrides4(self): """ Test overriding scale and zero_point for Conv weight """ - wgt_zp_val, wgt_scale_val = 4, 0.5 + wgt_zp_val, wgt_scale_val = np.array(4, dtype=np.float32), np.array(0.5, dtype=np.float32) _, _, _, _, wgt_zp, wgt_sc, _, _, _, _ = self.perform_qdq_quantization( "model_quant_overrides4.onnx", tensor_quant_overrides={ @@ -315,7 +315,7 @@ def test_qdq_overrides_per_channel1(self): """ Test per-channel overriding of scale/zero_point for Conv weight and bias. 
""" - zp_vals, scale_vals = [2, 4], [0.5, 0.2] + zp_vals, scale_vals = np.array([2, 4], dtype=np.float32), np.array([0.5, 0.2], dtype=np.float32) ( _, _, @@ -380,14 +380,14 @@ def test_qdq_overrides_per_channel2(self): "WGT": [ { "quant_type": quant_type, - "rmin": rmin_vals[0], - "rmax": rmax_vals[0], + "rmin": np.array(rmin_vals[0], dtype=np.float32), + "rmax": np.array(rmax_vals[0], dtype=np.float32), "reduce_range": reduce_ranges[0], }, { "quant_type": quant_type, - "rmin": rmin_vals[1], - "rmax": rmax_vals[1], + "rmin": np.array(rmin_vals[1], dtype=np.float32), + "rmax": np.array(rmax_vals[1], dtype=np.float32), "reduce_range": reduce_ranges[1], }, ], @@ -409,7 +409,9 @@ def test_override_validation_nonexisting_tensor(self): with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( "model_validation.onnx", - tensor_quant_overrides={"NON_EXISTING": [{"rmin": 0.0, "rmax": 0.5}]}, + tensor_quant_overrides={ + "NON_EXISTING": [{"rmin": np.array(0.0, dtype=np.float32), "rmax": np.array(0.5, dtype=np.float32)}] + }, ) self.assertIn("is not present in the model", str(context.exception)) @@ -421,7 +423,7 @@ def test_override_validation_scale_missing_zp(self): with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( "model_validation.onnx", - tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0}]}, + tensor_quant_overrides={"SIG_OUT": [{"scale": np.array(0.0, dtype=np.float32)}]}, ) self.assertIn("Must provide both 'scale' and 'zero_point'", str(context.exception)) @@ -433,7 +435,15 @@ def test_override_validation_bad_combination(self): with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( "model_validation.onnx", - tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0, "zero_point": 0, "rmax": 10.0}]}, + tensor_quant_overrides={ + "SIG_OUT": [ + { + "scale": np.array(0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int8), + "rmax": np.array(10.0, dtype=np.float32), + } + ] + }, ) self.assertIn("option 
'rmax' is invalid with 'scale' and 'zero_point'", str(context.exception)) @@ -441,7 +451,15 @@ def test_override_validation_bad_combination(self): with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( "model_validation.onnx", - tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0, "zero_point": 0, "rmin": 10.0}]}, + tensor_quant_overrides={ + "SIG_OUT": [ + { + "scale": np.array(0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int8), + "rmin": np.array(10.0, dtype=np.float32), + } + ] + }, ) self.assertIn("option 'rmin' is invalid with 'scale' and 'zero_point'", str(context.exception)) @@ -449,7 +467,15 @@ def test_override_validation_bad_combination(self): with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( "model_validation.onnx", - tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0, "zero_point": 0, "symmetric": True}]}, + tensor_quant_overrides={ + "SIG_OUT": [ + { + "scale": np.array(0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int8), + "symmetric": True, + } + ] + }, ) self.assertIn("option 'symmetric' is invalid with 'scale' and 'zero_point'", str(context.exception)) @@ -457,7 +483,15 @@ def test_override_validation_bad_combination(self): with self.assertRaises(ValueError) as context: self.perform_qdq_quantization( "model_validation.onnx", - tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0, "zero_point": 0, "reduce_range": True}]}, + tensor_quant_overrides={ + "SIG_OUT": [ + { + "scale": np.array(0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int8), + "reduce_range": True, + } + ] + }, ) self.assertIn("option 'reduce_range' is invalid with 'scale' and 'zero_point'", str(context.exception))