Skip to content

Commit

Permalink
Add expm1 operator
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Jul 17, 2024
1 parent f8ee736 commit f12d737
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
5 changes: 3 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,11 @@ def aten_special_expit(self: TensorType) -> TensorType:
raise NotImplementedError()


def aten_special_expm1(self: TensorType) -> TensorType:
@torch_op(("aten::expm1", "aten::special_expm"))
def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""special_expm1(Tensor self) -> Tensor"""

raise NotImplementedError()
return op.Sub(op.Exp(self), 1)


def aten_special_gammainc(self: TensorType, other: TensorType) -> TensorType:
Expand Down
3 changes: 3 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,9 @@ def _where_input_wrangler(
TorchLibOpInfo(
"erfc", special_ops.aten_special_erfc, tolerance={torch.float16: (1e-2, 2e-4)}
),
TorchLibOpInfo(
"expm1", special_ops.aten_special_expm1, tolerance={torch.float16: (1e-2, 2e-4)}
),
TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail(
reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223"
),
Expand Down

0 comments on commit f12d737

Please sign in to comment.