Skip to content

Commit

Permalink
[torchlib] Specify the squeeze axis
Browse files Browse the repository at this point in the history
Specify the squeeze axis explicitly to improve compatibility with ORT: microsoft/onnxruntime#21661
  • Loading branch information
justinchuby authored Aug 7, 2024
1 parent cf5ddd9 commit 5235261
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def aten__log_softmax(
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
result = op.LogSoftmax(self, axis=dim)
if self_is_scalar: # squeeze to scalar due to input is scalar
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))
return result


Expand Down Expand Up @@ -728,7 +728,7 @@ def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.ArgMax(self, keepdims=keepdim)
if self_is_scaler:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))

return result

Expand All @@ -743,7 +743,7 @@ def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = Fal

result = op.ArgMax(self, axis=dim, keepdims=keepdim)
if self_is_scaler:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))

return result

Expand All @@ -769,7 +769,7 @@ def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.ArgMin(self, keepdims=keepdim)
if self_is_scaler:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))

return result

Expand All @@ -784,7 +784,7 @@ def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = Fal

result = op.ArgMin(self, axis=dim, keepdims=keepdim)
if self_is_scaler:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))

return result

Expand Down Expand Up @@ -2821,7 +2821,7 @@ def aten_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat:
if IsScalar(input):
input = op.Reshape(input, op.Constant(value_ints=[-1]))
result, _ = op.Dropout(input, p, train)
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))
else:
result, _ = op.Dropout(input, p, train)

Expand Down Expand Up @@ -4250,7 +4250,7 @@ def aten_index_select(self: TTensor, dim: int, index: IntType) -> TTensor:
result = op.Gather(self, index, axis=dim)

if self_is_scalar:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))

return result

Expand Down Expand Up @@ -5094,7 +5094,7 @@ def aten_max(self: TReal) -> TReal:
result = op.ReduceMax(self, keepdims=False)

if self_is_scalar:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))

return result

Expand Down Expand Up @@ -5598,7 +5598,7 @@ def aten_multinomial(
log_input = op.Log(unsqueezed_input)
result = op.Multinomial(log_input, dtype=INT64.dtype, sample_size=num_samples)
if Rank(self) == 1:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))
return result


Expand Down Expand Up @@ -7415,7 +7415,7 @@ def aten_scatter_reduce(
src = op.Reshape(src, neg_1)
result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce)
if self_is_scalar:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))
return result


Expand Down Expand Up @@ -7667,7 +7667,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB
result = op.Cast(result, to=dtype)
if self_is_scalar:
# Convert to scalar when input is scalar
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))

return result

Expand All @@ -7682,7 +7682,7 @@ def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
result = op.Softmax(self, axis=dim)
if self_is_scalar:
# Convert to scalar when input is scalar
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))

return result

Expand Down Expand Up @@ -8101,7 +8101,7 @@ def _aten_sum_dim_onnx(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
result = op.ReduceSum(self, dim, keepdims=keepdim)

if self_is_scalar:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))
return result


Expand All @@ -8114,7 +8114,7 @@ def _aten_sum_dim_none(self: TReal, keepdim: bool = False) -> TReal:
result = op.ReduceSum(self, keepdims=keepdim)

if self_is_scalar:
result = op.Squeeze(result)
result = op.Squeeze(result, op.Constant(value_ints=[-1]))
return result


Expand Down

0 comments on commit 5235261

Please sign in to comment.