diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py
index d32094651..db51a1495 100644
--- a/onnxscript/function_libs/torch_lib/ops/special.py
+++ b/onnxscript/function_libs/torch_lib/ops/special.py
@@ -11,14 +11,16 @@
 """
 
 from __future__ import annotations
 
+import math
 from typing import Optional, Sequence
 
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
-from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
+_MATH_PI = math.pi
 IsScalar = common_ops.IsScalar
 
@@ -338,10 +340,15 @@ def aten_special_shifted_chebyshev_polynomial_w(x: TensorType, n: TensorType) ->
     raise NotImplementedError()
 
 
-def aten_special_sinc(self: TensorType) -> TensorType:
+@torch_op(("aten::special_sinc", "aten::sinc"))
+def aten_special_sinc(self: TFloat) -> TFloat:
     """special_sinc(Tensor self) -> Tensor"""
 
-    raise NotImplementedError()
+    # This computes the normalized sinc, where the input is multiplied by pi.
+    # https://pytorch.org/docs/stable/special.html#torch.special.sinc
+    pi_self = self * _MATH_PI
+
+    return op.Where(self == 0.0, op.CastLike(1, self), op.Sin(pi_self) / pi_self)
 
 
 def aten_special_spherical_bessel_j0(x: TensorType) -> TensorType:
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 6aea6822d..415d6a1ed 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1428,6 +1428,9 @@ def _where_input_wrangler(
     TorchLibOpInfo("sigmoid", core_ops.aten_sigmoid),
     TorchLibOpInfo("sign", core_ops.aten_sign),
     TorchLibOpInfo("sin", core_ops.aten_sin),
+    TorchLibOpInfo(
+        "sinc", special_ops.aten_special_sinc, tolerance={torch.float16: (1e-2, 6e-4)}
+    ),
     TorchLibOpInfo("sinh", core_ops.aten_sinh),
     TorchLibOpInfo(
         "softmax",
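
For context, the new `aten_special_sinc` follows PyTorch's normalized sinc definition: sinc(x) = sin(πx) / (πx), with the value at x == 0 defined as 1. Below is a minimal NumPy sketch, illustrative only and not part of the patch, that checks this identity against `np.sinc`, which computes the same normalized form.

```python
# Illustrative only (not part of the patch): a NumPy check of the normalized
# sinc identity the new function encodes: sin(pi*x) / (pi*x), with 1 at x == 0.
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
pi_x = x * np.pi
with np.errstate(invalid="ignore"):
    # Both branches of np.where are evaluated, so suppress the 0/0 warning.
    manual = np.where(x == 0.0, np.float32(1.0), np.sin(pi_x) / pi_x)

# np.sinc already implements the normalized sinc, including the x == 0 case.
assert np.allclose(manual, np.sinc(x))
```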