Skip to content

Commit

Permalink
Implement operator sinc | feat(torchlib) (#1228)
Browse files Browse the repository at this point in the history
Fixes #1221
  • Loading branch information
justinchuby authored Dec 15, 2023
1 parent 0c0dc1b commit cf3f998
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
13 changes: 10 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
"""
from __future__ import annotations

import math
from typing import Optional, Sequence

from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType

_MATH_PI = math.pi
IsScalar = common_ops.IsScalar


Expand Down Expand Up @@ -338,10 +340,15 @@ def aten_special_shifted_chebyshev_polynomial_w(x: TensorType, n: TensorType) ->
raise NotImplementedError()


@torch_op(("aten::special_sinc", "aten::sinc"))
def aten_special_sinc(self: TFloat) -> TFloat:
    """special_sinc(Tensor self) -> Tensor

    Computes the normalized sinc of the input: sin(pi * x) / (pi * x),
    with the removable singularity at x == 0 defined as 1.
    https://pytorch.org/docs/stable/special.html#torch.special.sinc
    """

    # PyTorch's sinc is the *normalized* sinc, so the input is scaled by pi
    # before taking the sine.
    pi_self = self * _MATH_PI

    # At x == 0, Sin(0) / 0 would produce NaN; select the limit value 1 there.
    # CastLike keeps the constant 1 in the same dtype as the input tensor.
    return op.Where(self == 0.0, op.CastLike(1, self), op.Sin(pi_self) / pi_self)


def aten_special_spherical_bessel_j0(x: TensorType) -> TensorType:
Expand Down
3 changes: 3 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,9 @@ def _where_input_wrangler(
TorchLibOpInfo("sigmoid", core_ops.aten_sigmoid),
TorchLibOpInfo("sign", core_ops.aten_sign),
TorchLibOpInfo("sin", core_ops.aten_sin),
TorchLibOpInfo(
"sinc", special_ops.aten_special_sinc, tolerance={torch.float16: (1e-2, 6e-4)}
),
TorchLibOpInfo("sinh", core_ops.aten_sinh),
TorchLibOpInfo(
"softmax",
Expand Down

0 comments on commit cf3f998

Please sign in to comment.