diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1b2ee6bea..cf2fae3db 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2048,13 +2048,12 @@ def aten_convolution_overrideable( @torch_op("aten::copy") def aten_copy( self: TTensor, - src: TTensor, + src: TTensor2, non_blocking: bool = False, # pylint: disable=unused-argument ) -> TTensor: """copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor""" - self = op.Identity(src) - return self + return op.CastLike(src, self) @torch_op("aten::_to_copy", trace_only=True) @@ -4059,12 +4058,9 @@ def aten_index_put_bool( # change array([F,F,T,F,F]) to array([2]) index = op.ArgMax(index_int) # assume index only have 1 True # change array([2]) to array([2,2,2,2,2]) - self_dim_1 = op.Gather(op.Shape(self), 1) - index_dim_0 = op.Gather(op.Shape(index), 0) - neg_1 = op.Constant(value_ints=[-1]) - shape = op.Concat( - op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0 - ) + self_dim_1 = op.Shape(self, start=1, end=2) + index_dim_0 = op.Shape(index, start=0, end=1) + shape = op.Concat(self_dim_1, index_dim_0, axis=0) new_ind = op.Expand(index, shape) new_ind_t = op.Transpose(new_ind) @@ -7513,7 +7509,7 @@ def _center_window_around_zeros_if_needed( window: TFloatOrBFloat16, n_fft: int ) -> TFloatOrBFloat16: # first dimension - n_win = op.Gather(op.Shape(window), 0) + n_win = op.Shape(window, start=0, end=1) # Center window around zeros if needed (required by ONNX's STFT) if n_win < n_fft: left = (n_fft - n_win) / 2