From 804ed0192ae8d3a86380f1fd29ef7f182dc751ef Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Nov 2023 17:31:25 -0800 Subject: [PATCH] Fix `aten_copy` dtype | fix(torchlib) (#1164) Cast the output of `aten_copy` to `self`'s type. Fixes https://github.com/microsoft/onnxscript/issues/1162 --- onnxscript/function_libs/torch_lib/ops/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1b2ee6bea..8d61be90b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2048,13 +2048,12 @@ def aten_convolution_overrideable( @torch_op("aten::copy") def aten_copy( self: TTensor, - src: TTensor, + src: TTensor2, non_blocking: bool = False, # pylint: disable=unused-argument ) -> TTensor: """copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor""" - self = op.Identity(src) - return self + return op.CastLike(src, self) @torch_op("aten::_to_copy", trace_only=True)