Skip to content

Commit

Permalink
fix ortmodule cast shape
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamLouly committed Jun 28, 2024
1 parent 587e92c commit cb57c82
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,12 @@ def int_or_float(value, allow_float_values):
# If casting into int has precision loss: keep float output
if allow_float_values and value % 1 != 0:
return value
return int(value)
# Handle NaN and inf values explicitly
if np.isinf(value):
# Use the maximum float value as the replacement
return np.finfo(np.float32).max
if np.isnan(value):
return 0

values = [self._try_get_value(node, i) for i in range(len(node.input))]
if all([v is not None for v in values]):
Expand Down

0 comments on commit cb57c82

Please sign in to comment.