Skip to content

Commit

Permalink
fix missing dtype
Browse files (browse the repository at this point in the history)
  • Loading branch information
xadupre committed Dec 21, 2023
1 parent 63b84d1 commit a90f9f5
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 16 deletions.
6 changes: 4 additions & 2 deletions onnxruntime/python/tools/quantization/operators/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ def quantize(self):
out_zero_point, out_scale = quant_overrides["zero_point"], quant_overrides["scale"]
else:
# Unless overridden by the user, force Softmax to range from 0.0 to 1.0
rmin = quant_overrides.get("rmin", 0.0)
rmax = quant_overrides.get("rmax", 1.0)
qparams = self.quantizer.quantization_params[output_name]
dtype = qparams.data["scale"].dtype
rmin = quant_overrides.get("rmin", np.array(0, dtype=dtype))
rmax = quant_overrides.get("rmax", np.array(1, dtype=dtype))
symmetric = quant_overrides.get("symmetric", self.quantizer.is_activation_symmetric)
reduce_range = quant_overrides.get("reduce_range", False)
qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_qdq_overrides1(self):
inp_zp, inp_sc, sig_out_zp, sig_out_sc, wgt_zp, wgt_sc, bias_zp, bias_sc, _, _ = self.perform_qdq_quantization(
"model_quant_overrides1.onnx",
tensor_quant_overrides={
"SIG_OUT": [{"scale": 1.0, "zero_point": 127}],
"SIG_OUT": [{"scale": np.array(1.0, dtype=np.float32), "zero_point": np.array(127, dtype=np.uint8)}],
"WGT": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": True}],
"BIAS": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": True}],
},
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_qdq_overrides3(self):
"""
Test overriding rmin and rmax for Conv weight
"""
wgt_rmin, wgt_rmax = 0.0, 1.0
wgt_rmin, wgt_rmax = np.array(0.0, dtype=np.float32), np.array(1.0, dtype=np.float32)
_, _, _, _, wgt_zp, wgt_sc, _, _, _, _ = self.perform_qdq_quantization(
"model_quant_overrides3.onnx",
tensor_quant_overrides={
Expand All @@ -298,7 +298,7 @@ def test_qdq_overrides4(self):
"""
Test overriding scale and zero_point for Conv weight
"""
wgt_zp_val, wgt_scale_val = 4, 0.5
wgt_zp_val, wgt_scale_val = np.array(4, dtype=np.float32), np.array(0.5, dtype=np.float32)
_, _, _, _, wgt_zp, wgt_sc, _, _, _, _ = self.perform_qdq_quantization(
"model_quant_overrides4.onnx",
tensor_quant_overrides={
Expand All @@ -315,7 +315,7 @@ def test_qdq_overrides_per_channel1(self):
"""
Test per-channel overriding of scale/zero_point for Conv weight and bias.
"""
zp_vals, scale_vals = [2, 4], [0.5, 0.2]
zp_vals, scale_vals = np.array([2, 4], dtype=np.float32), np.array([0.5, 0.2], dtype=np.float32)
(
_,
_,
Expand Down Expand Up @@ -380,14 +380,14 @@ def test_qdq_overrides_per_channel2(self):
"WGT": [
{
"quant_type": quant_type,
"rmin": rmin_vals[0],
"rmax": rmax_vals[0],
"rmin": np.array(rmin_vals[0], dtype=np.float32),
"rmax": np.array(rmax_vals[0], dtype=np.float32),
"reduce_range": reduce_ranges[0],
},
{
"quant_type": quant_type,
"rmin": rmin_vals[1],
"rmax": rmax_vals[1],
"rmin": np.array(rmin_vals[1], dtype=np.float32),
"rmax": np.array(rmax_vals[1], dtype=np.float32),
"reduce_range": reduce_ranges[1],
},
],
Expand All @@ -409,7 +409,9 @@ def test_override_validation_nonexisting_tensor(self):
with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"NON_EXISTING": [{"rmin": 0.0, "rmax": 0.5}]},
tensor_quant_overrides={
"NON_EXISTING": [{"rmin": np.array(0.0, dtype=np.float32), "rmax": np.array(0.5, dtype=np.float32)}]
},
)

self.assertIn("is not present in the model", str(context.exception))
Expand All @@ -421,7 +423,7 @@ def test_override_validation_scale_missing_zp(self):
with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0}]},
tensor_quant_overrides={"SIG_OUT": [{"scale": np.array(0.0, dtype=np.float32)}]},
)

self.assertIn("Must provide both 'scale' and 'zero_point'", str(context.exception))
Expand All @@ -433,31 +435,63 @@ def test_override_validation_bad_combination(self):
with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0, "zero_point": 0, "rmax": 10.0}]},
tensor_quant_overrides={
"SIG_OUT": [
{
"scale": np.array(0, dtype=np.float32),
"zero_point": np.array(0, dtype=np.int8),
"rmax": np.array(10.0, dtype=np.float32),
}
]
},
)

self.assertIn("option 'rmax' is invalid with 'scale' and 'zero_point'", str(context.exception))

with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0, "zero_point": 0, "rmin": 10.0}]},
tensor_quant_overrides={
"SIG_OUT": [
{
"scale": np.array(0, dtype=np.float32),
"zero_point": np.array(0, dtype=np.int8),
"rmax": np.array(10.0, dtype=np.float32),
}
]
},
)

self.assertIn("option 'rmin' is invalid with 'scale' and 'zero_point'", str(context.exception))

with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0, "zero_point": 0, "symmetric": True}]},
tensor_quant_overrides={
"SIG_OUT": [
{
"scale": np.array(0, dtype=np.float32),
"zero_point": np.array(0, dtype=np.int8),
"symmetric": True,
}
]
},
)

self.assertIn("option 'symmetric' is invalid with 'scale' and 'zero_point'", str(context.exception))

with self.assertRaises(ValueError) as context:
self.perform_qdq_quantization(
"model_validation.onnx",
tensor_quant_overrides={"SIG_OUT": [{"scale": 0.0, "zero_point": 0, "reduce_range": True}]},
tensor_quant_overrides={
"SIG_OUT": [
{
"scale": np.array(0, dtype=np.float32),
"zero_point": np.array(0, dtype=np.int8),
"reduce_range": True,
}
]
},
)

self.assertIn("option 'reduce_range' is invalid with 'scale' and 'zero_point'", str(context.exception))
Expand Down

0 comments on commit a90f9f5

Please sign in to comment.