Skip to content

Commit

Permalink
Fix casting in lerp
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Jul 16, 2024
1 parent 36f96d5 commit 246efd8
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4663,14 +4663,15 @@ def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:


@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar"))
def aten_lerp(self: TReal, end: TReal, weight: TReal) -> TReal:
def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor:
"""lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"""

diff = op.CastLike(op.Sub(end, self), weight)
weight = op.CastLike(weight, self)
diff = op.Sub(end, self)
return op.Where(
op.Less(weight, 0.5),
op.Add(self, op.CastLike(op.Mul(weight, diff), self)),
op.Sub(end, op.CastLike(op.Mul(diff, op.Sub(1.0, weight)), end))
op.Add(self, op.Mul(weight, diff)),
op.Sub(end, op.Mul(diff, op.Sub(1.0, weight)))
)


Expand Down

0 comments on commit 246efd8

Please sign in to comment.