From aee77e63741e460bd998eec53ec20abc29e774fb Mon Sep 17 00:00:00 2001
From: Xavier Dupre
Date: Mon, 24 Jun 2024 16:27:51 +0200
Subject: [PATCH] Revert one breaking change from #1613

Signed-off-by: Xavier Dupre
---
 .../function_libs/torch_lib/ops/core.py | 50 +++++++++----------
 1 file changed, 25 insertions(+), 25 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index e50489c38..ddd836c4a 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -308,7 +308,7 @@ def aten_affine_grid_generator_backward(
 def aten_alias(self: TTensor) -> TTensor:
     """alias(Tensor(a) self) -> Tensor(a)"""
 
-    return self
+    return op.Identity(self)
 
 
 def aten_alias_copy(self: TensorType) -> TensorType:
@@ -374,7 +374,7 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
         self = aten_all_dim(self, d, keepdim=True)
     if not keepdim:
         self = op.Squeeze(self, list(dim))
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::all.dims", traceable=True)
@@ -499,7 +499,7 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
         self = aten_any_dim(self, d, keepdim=True)
     if not keepdim:
         self = op.Squeeze(self, list(dim))
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::any.dims", traceable=True)
@@ -940,7 +940,7 @@ def aten_atleast_1d(self: TTensor) -> TTensor:
 
     if IsScalar(self):
         self = op.Reshape(self, op.Constant(value_ints=[1]))
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::atleast_1d.Sequence")
@@ -964,7 +964,7 @@ def aten_atleast_2d(self: TTensor) -> TTensor:
 
     if Rank(self) <= 1:
         self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::atleast_2d.Sequence")
@@ -991,7 +991,7 @@ def aten_atleast_3d(self: TTensor) -> TTensor:
         self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
     elif rank == 2:
         self = op.Unsqueeze(self, op.Constant(value_ints=[-1]))
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::atleast_3d.Sequence")
@@ -1691,7 +1691,7 @@ def aten_clone(
 ) -> TTensor:
     """clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor"""
 
-    return self
+    return op.Identity(self)
 
 
 def aten_coalesce(self: TensorType) -> TensorType:
@@ -1749,7 +1749,7 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
 def aten_conj(self: TTensor) -> TTensor:
     """conj(Tensor(a) self) -> Tensor(a)"""
 
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::conj", complex=True, private=True)
@@ -1825,7 +1825,7 @@ def aten_contiguous(
     """contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)"""
 
     # ONNX does not have the notion of memory_format. It is always treated as a no-op.
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::conv1d", trace_only=True)
@@ -2168,7 +2168,7 @@ def aten__to_copy(
     """_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor"""
 
     if dtype == -1:
-        return self
+        return op.Identity(self)
     else:
         return common_ops.cast_to(self, dtype=dtype)
 
@@ -2493,7 +2493,7 @@ def aten_dense_dim(self: TensorType) -> int:
 def aten_detach(self: TensorType) -> TensorType:
     """detach(Tensor(a) self) -> Tensor(a)"""
 
-    return self
+    return op.Identity(self)
 
 
 def aten_detach_copy(self: TensorType) -> TensorType:
@@ -4061,7 +4061,7 @@ def _aten_index_onnx(
     if _has_none_in_middle(indices):
         # If there is None in the middle, Advanced Indexing cannot decide where to put
         # the new dimensions. So it places them in the front, like GatherND does.
-        return self
+        return op.Identity(self)
 
     # When the indices are consecutive, Advanced Indexing will place the new dimensions
     # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
@@ -4227,7 +4227,7 @@ def aten_index_put_bool(
         index = op.SequenceAt(indices, 0)  # assume indices only have 1 element
         # FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
        index_int = op.Cast(index, to=INT32.dtype)
-        # if all False, return self
+        # if all False, return op.Identity(self)
         if op.ReduceSum(index_int) == 0:
             result = self
         else:
@@ -4700,7 +4700,7 @@ def aten_lift_fresh(self: TensorType) -> TensorType:
 def aten_lift_fresh_copy(self: TensorType) -> TensorType:
     """lift_fresh_copy(Tensor self) -> Tensor"""
 
-    return self
+    return op.Identity(self)
 
 
 def aten_linear_backward(
@@ -7082,14 +7082,14 @@ def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType:
 def aten_resolve_conj(self: TTensor) -> TTensor:
     """resolve_conj(Tensor(a) self) -> Tensor(a)"""
 
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::resolve_neg", trace_only=True)
 def aten_resolve_neg(self: TTensor) -> TTensor:
     """resolve_neg(Tensor(a) self) -> Tensor(a)"""
 
-    return self
+    return op.Identity(self)
 
 
 def aten_result_type(tensor: TensorType, other: TensorType) -> int:
@@ -7142,9 +7142,9 @@ def aten_roll(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor
 
     self_rank = len(self.shape)
     if self_rank == 0:
-        return self
+        return op.Identity(self)
     elif self.shape[0] == 0:  # empty tensor
-        return self
+        return op.Identity(self)
     else:
         # NOTE: In pytorch, default value of dims is an empty list.
         if len(dims) == 0:  # Empty sequence
@@ -7166,10 +7166,10 @@ def aten_roll_complex(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) ->
 
     self_rank = len(self.shape)
     if self_rank == 1:
-        return self
+        return op.Identity(self)
 
     if self.shape[0] == 0:  # empty tensor
-        return self
+        return op.Identity(self)
 
     self_real = op.Slice(self, [0], [1], axes=[-1])
     self_imag = op.Slice(self, [1], [2], axes=[-1])
@@ -7819,7 +7819,7 @@ def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT6
     if signal_rank == 1:
         # Add a batch dimension
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
-    return self, signal_rank
+    return op.Identity(self), signal_rank
 
 
 @torch_op("aten::stft", private=True)
@@ -8768,7 +8768,7 @@ def aten_view_as_complex(self: TTensor) -> TTensor:
 
     # We always operate on the real representation of a complex number in torchlib
     # So this is a no-op
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::view_as_complex_copy", trace_only=True)
@@ -8777,7 +8777,7 @@ def aten_view_as_complex_copy(self: TTensor) -> TTensor:
 
     # We always operate on the real representation of a complex number in torchlib
     # So this is a no-op
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::view_as_real", complex=True, trace_only=True)
@@ -8786,7 +8786,7 @@ def aten_view_as_real(self: TTensor) -> TTensor:
 
     # We always operate on the real representation of a complex number in torchlib
     # So this is a no-op
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::view_as_real_copy", complex=True, trace_only=True)
@@ -8795,7 +8795,7 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:
 
     # We always operate on the real representation of a complex number in torchlib
     # So this is a no-op
-    return self
+    return op.Identity(self)
 
 
 @torch_op("aten::view_copy")
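
---

Context for the change: this patch restores the pre-#1613 convention that aliasing and no-op ATen functions (alias, clone, detach, contiguous, resolve_conj, view_as_*, ...) end with an explicit op.Identity rather than returning their input value directly, so the exported function produces a distinct output value instead of aliasing its input. The sketch below is not part of the patch; it is a minimal illustration built with plain onnx.helper (an assumption for self-containment, since torchlib itself goes through onnxscript), and the graph and value names are hypothetical:

# Minimal sketch: the graph shape an aliasing op such as aten_detach
# lowers to after the revert. The Identity node makes the graph output a
# value produced by a node, rather than the graph input itself.
import onnx
from onnx import TensorProto, helper

identity = helper.make_node("Identity", inputs=["self"], outputs=["output"])
graph = helper.make_graph(
    [identity],
    "aliasing_op_as_identity",  # hypothetical graph name
    inputs=[helper.make_tensor_value_info("self", TensorProto.FLOAT, ["N"])],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, ["N"])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model)  # well-formed: every output is node-produced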