Skip to content

Commit

Permalink
Add tolerance limits
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Jul 16, 2024
1 parent 4ad38ad commit 36f96d5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
6 changes: 3 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4666,11 +4666,11 @@ def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
def aten_lerp(self: TReal, end: TReal, weight: TReal) -> TReal:
"""lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"""

diff = op.Sub(end, self)
diff = op.CastLike(op.Sub(end, self), weight)
return op.Where(
op.Less(weight, 0.5),
op.Add(self, op.Mul(weight, diff)),
op.Sub(end, op.Mul(diff, op.Sub(1.0, weight)))
op.Add(self, op.CastLike(op.Mul(weight, diff), self)),
op.Sub(end, op.CastLike(op.Mul(diff, op.Sub(1.0, weight)), end))
)


Expand Down
12 changes: 10 additions & 2 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,11 @@ def _where_input_wrangler(
TorchLibOpInfo("log", core_ops.aten_log),
TorchLibOpInfo("le", core_ops.aten_le),
TorchLibOpInfo("le_bool", core_ops.aten_le_bool),
TorchLibOpInfo("lerp", core_ops.aten_lerp),
TorchLibOpInfo(
"lerp",
core_ops.aten_lerp,
tolerance={torch.float16: (2e-3, 2e-1)},
),
TorchLibOpInfo("log10", core_ops.aten_log10),
TorchLibOpInfo("log1p", core_ops.aten_log1p),
TorchLibOpInfo(
Expand Down Expand Up @@ -1021,7 +1025,11 @@ def _where_input_wrangler(
TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True),
TorchLibOpInfo("mul", core_ops.aten_mul),
TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True),
TorchLibOpInfo("mv", core_ops.aten_mv),
TorchLibOpInfo(
"mv",
core_ops.aten_mv,
tolerance={torch.float16: (3e-2, 1e-2)},
),
TorchLibOpInfo("narrow", core_ops.aten_narrow),
TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout),
TorchLibOpInfo("ne", core_ops.aten_ne),
Expand Down

0 comments on commit 36f96d5

Please sign in to comment.