diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4b851b369..d7e97e98d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -82,11 +82,15 @@ def aten__log_softmax_half( ) -> FLOAT: """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - # trace_only because we need to cast conditionally based on half_to_float + self_is_scalar = IsScalar(self) if half_to_float: self = op.Cast(self, to=FLOAT.dtype) - - return aten__log_softmax(self, dim, half_to_float) + if self_is_scalar: + self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + result = op.LogSoftmax(self, axis=dim) + if self_is_scalar: + result = op.Squeeze(result, op.Constant(value_ints=[0])) + return result @torch_op("aten::_log_softmax", traceable=True) @@ -101,7 +105,7 @@ def aten__log_softmax( if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.LogSoftmax(self, axis=dim) - if self_is_scalar: # squeeze to scalar due to input is scalar + if self_is_scalar: result = op.Squeeze(result) return result