Skip to content

Commit

Permalink
Fix aten_copy dtype | fix(torchlib)
Browse files Browse the repository at this point in the history
Cast the output of `aten_copy` to `self`'s type.
  • Loading branch information
justinchuby authored Nov 17, 2023
1 parent 577f51a commit 498d1a3
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2048,13 +2048,12 @@ def aten_convolution_overrideable(
@torch_op("aten::copy")
def aten_copy(
self: TTensor,
src: TTensor,
src: TTensor2,
non_blocking: bool = False, # pylint: disable=unused-argument
) -> TTensor:
"""copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor"""

self = op.Identity(src)
return self
return op.CastLike(src, self)

Check warning on line 2056 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2056

Added line #L2056 was not covered by tests


@torch_op("aten::_to_copy", trace_only=True)
Expand Down

0 comments on commit 498d1a3

Please sign in to comment.