diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 09c4d8f99..cf2fae3db 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2048,13 +2048,12 @@ def aten_convolution_overrideable( @torch_op("aten::copy") def aten_copy( self: TTensor, - src: TTensor, + src: TTensor2, non_blocking: bool = False, # pylint: disable=unused-argument ) -> TTensor: """copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor""" - self = op.Identity(src) - return self + return op.CastLike(src, self) @torch_op("aten::_to_copy", trace_only=True)