From 9bae2b566ebbeb55ba6d27b368e456ccd4444175 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 7 Aug 2024 17:08:39 -0700 Subject: [PATCH] [torchlib] Fix _log_softmax (#1789) Fix _log_softmax by moving the IsScalar call to the top so it can be eagerly evaluated. Also specify the squeeze axis explicitly to improve compatibility with ORT: https://github.com/microsoft/onnxruntime/issues/21661 This should fix a runtime error in XGLMForCausalLM --- onnxscript/function_libs/torch_lib/ops/core.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4b851b369..d7e97e98d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -82,11 +82,15 @@ def aten__log_softmax_half( ) -> FLOAT: """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - # trace_only because we need to cast conditionally based on half_to_float + self_is_scalar = IsScalar(self) if half_to_float: self = op.Cast(self, to=FLOAT.dtype) - - return aten__log_softmax(self, dim, half_to_float) + if self_is_scalar: + self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + result = op.LogSoftmax(self, axis=dim) + if self_is_scalar: + result = op.Squeeze(result, op.Constant(value_ints=[0])) + return result @torch_op("aten::_log_softmax", traceable=True) @@ -101,7 +105,7 @@ def aten__log_softmax( if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.LogSoftmax(self, axis=dim) - if self_is_scalar: # squeeze to scalar due to input is scalar + if self_is_scalar: result = op.Squeeze(result) return result