[torchlib] Fix _log_softmax (#1789)
Fix _log_softmax by moving the IsScalar call to the top of the function so it can be eagerly evaluated before the conditional cast to float.

Also specify the squeeze axis explicitly to improve compatibility with ORT: microsoft/onnxruntime#21661

This should fix a runtime error in XGLMForCausalLM.
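
As background (an illustration, not part of this commit): an ONNX Squeeze without an axes input removes every size-1 dimension it can find, while an explicit axes=[0] removes only the leading dimension that the earlier Unsqueeze added. A rough numpy analogue of the difference:

import numpy as np

x = np.zeros((1, 1, 5), dtype=np.float32)
print(np.squeeze(x).shape)          # (5,)   all size-1 dimensions removed
print(np.squeeze(x, axis=0).shape)  # (1, 5) only the specified axis removed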
justinchuby authored Aug 8, 2024
1 parent b1f4942 commit 9bae2b5
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -82,11 +82,15 @@ def aten__log_softmax_half(
) -> FLOAT:
    """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""

    # trace_only because we need to cast conditionally based on half_to_float
    self_is_scalar = IsScalar(self)
    if half_to_float:
        self = op.Cast(self, to=FLOAT.dtype)

    return aten__log_softmax(self, dim, half_to_float)
    if self_is_scalar:
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
    result = op.LogSoftmax(self, axis=dim)
    if self_is_scalar:
        result = op.Squeeze(result, op.Constant(value_ints=[0]))
    return result


@torch_op("aten::_log_softmax", traceable=True)
@@ -101,7 +105,7 @@ def aten__log_softmax(
    if self_is_scalar:
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
    result = op.LogSoftmax(self, axis=dim)
    if self_is_scalar:  # squeeze to scalar due to input is scalar
    if self_is_scalar:
        result = op.Squeeze(result)
    return result
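
A rough numpy sketch (an illustration assumed here, not code from this change) of the scalar path the functions above implement: unsqueeze the 0-d input, take log-softmax over the single element, then squeeze the added axis back off.

import numpy as np

x = np.float32(2.5)                           # 0-d (scalar) input
x1 = np.expand_dims(x, axis=0)                # Unsqueeze(axes=[0]) -> shape (1,)
log_softmax = x1 - np.log(np.exp(x1).sum())   # LogSoftmax over one element is 0
result = np.squeeze(log_softmax, axis=0)      # Squeeze(axes=[0]) -> 0-d again
print(result.shape, result)                   # () 0.0 (up to float rounding)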
