Skip to content

Commit

Permalink
Consolidate Softmax range adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Mar 22, 2024
1 parent ae3f3a3 commit d4cd04d
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/python/tools/quantization/base_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def quantize_weight_per_channel_impl(

return q_weight_name, zp_name, scale_name

def adjust_tensor_ranges(self, softmax_0_to_1=False):
def adjust_tensor_ranges(self):
if self.tensors_range is None:
return

Expand All @@ -471,6 +471,6 @@ def adjust_tensor_ranges(self, softmax_0_to_1=False):
if not isinstance(td, TensorData):
raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.")
self.tensors_range[node.input[0]] = td
# Optionally, adjust Softmax to range from 0.0 to 1.0
elif node.op_type == "Softmax" and softmax_0_to_1:
# Adjust Softmax to range from 0.0 to 1.0
elif node.op_type == "Softmax":
self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/quantization/onnx_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def calculate_quantization_params(self):
if self.tensors_range is None:
return None

self.adjust_tensor_ranges(softmax_0_to_1=False)
self.adjust_tensor_ranges()

quantization_params = {}
for tensor_name in self.tensors_range:
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/quantization/qdq_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]:
if self.tensors_range is None:
return {}

self.adjust_tensor_ranges(softmax_0_to_1=True) # Ensure Softmax ranges from 0.0 to 1.0 for QDQ models.
self.adjust_tensor_ranges()

quantization_params = {}
for tensor_name in self.tensors_range:
Expand Down

0 comments on commit d4cd04d

Please sign in to comment.