diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e547eefd5..c17a83c32 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2053,7 +2053,7 @@ def aten_convolution_overrideable( raise NotImplementedError() -@torch_op(("aten::copy", "aten::_to_copy")) +@torch_op("aten::copy") def aten_copy( self: TTensor, src: TTensor, non_blocking: bool = False # pylint: disable=unused-argument ) -> TTensor: @@ -2063,6 +2063,20 @@ def aten_copy( return self +@torch_op("aten::_to_copy", trace_only=True) +def aten__to_copy( + self: TTensor, + dtype: int = -1, + non_blocking: bool = False, # pylint: disable=unused-argument +) -> TTensor: + """_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor""" + + if dtype == -1: + return op.Identity(self) + else: + return common_ops.cast_to(self, dtype=dtype) + + def aten_copysign(self: TensorType, other: TensorType) -> TensorType: """copysign.Tensor(Tensor self, Tensor other) -> Tensor"""