Skip to content

Commit

Permalink
[torchlib] Mark add/sub as trace_only (#1840)
Browse files Browse the repository at this point in the history
Simplify implementation
  • Loading branch information
justinchuby authored Sep 2, 2024
1 parent 0052b90 commit 2e45a32
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
"""torch.ops.aten operators under the `core` module.
Expand Down Expand Up @@ -167,12 +165,13 @@ def aten_acosh(self: TFloat) -> TFloat:
return op.Acosh(self)


@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"))
@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# TODO(microsoft/onnxruntime#15977): Improve fp16 precision
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)
if alpha != 1.0:
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)
return op.Add(self, other)


Expand Down Expand Up @@ -8112,13 +8111,14 @@ def aten_stft(
"aten::subtract.Tensor",
"aten::subtract.Scalar",
"_operator::sub",
)
),
trace_only=True,
)
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)

if alpha != 1.0:
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)
return op.Sub(self, other)


Expand Down

0 comments on commit 2e45a32

Please sign in to comment.