Skip to content

Commit

Permalink
Implement _fft_* ops | feat(torchlib) (#926)
Browse files Browse the repository at this point in the history
The change implements `_fft_c2c`, `_fft_c2r` and `_fft_r2c`. I extracted
the common logic to `_fftn_onnx`, with the hope that we will be able to
express this as a function when `DFT` supports dynamic axes:
onnx/onnx#5447

---------

Co-authored-by: Jay Zhang <[email protected]>
Co-authored-by: Jay Zhang <[email protected]>
  • Loading branch information
3 people authored Oct 26, 2023
1 parent 4d7ac4d commit 70843ef
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 1 deletion.
156 changes: 156 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,165 @@

from typing import Optional, Sequence

from onnxscript import INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
private=True,
complex=True,
)
def _fftn_onnx_normalization(
self,
transformed: TFloat,
normalization: int,
forward: bool,
dims: Sequence[int],
) -> TFloat:
# Obtain the total_sample_count (n) for normalization
self_shape = op.Shape(self)
total_sample_count = op.ReduceProd(self_shape[dims], keepdims=0)
total_sample_count = op.CastLike(total_sample_count, transformed)

# Normalize the result
# Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
# Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
if normalization == 1:
# "forward" - normalize by 1/n
if forward:
result = op.Div(transformed, op.Sqrt(total_sample_count))
else:
result = op.Mul(transformed, op.Sqrt(total_sample_count))
elif normalization == 2:
# "ortho" - normalize by 1/sqrt(n)
if forward:
result = op.Div(transformed, total_sample_count)
else:
result = transformed
else:
# "backward" - no normalization
if forward:
result = transformed
else:
result = op.Mul(transformed, total_sample_count)

return result


@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
trace_only=True,
private=True,
complex=True,
)
def _fftn_onnx(
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
) -> TFloat:
"""Standard complex to complex or real to complex FFT (forward or backward).
This is a private shared function for implementing the various FFT functions.
Args:
self: The input tensor.
dims: The dimensions to apply FFT.
normalization: The normalization mode.
inverse: Whether to compute the inverse FFT.
onesided: Whether to compute the one-sided FFT, which retains only the
positive frequencies.
Returns:
The transformed tensor.
"""

# NOTE: trace_only because we need to process each dimension in a loop
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support

# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
# dimension at the beginning to represent the batch dimension.
transformed = op.Unsqueeze(self, axes=[0])

for dim_ in dims:
if dim_ >= 0:
# Add 1 to account for the batch dimension when counting axes from the left
dim_ = dim_ + 1
transformed = op.DFT(transformed, axis=dim_, inverse=inverse, onesided=onesided)
# Remove the batch dimension
transformed = op.Squeeze(transformed, axes=[0])

return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)


@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
def aten__fft_c2c(
self: TFloat, dim: Sequence[int], normalization: int, forward: bool
) -> TFloat:
"""_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
Standard complex to complex FFT (forward or backward).
"""

# NOTE: trace_only because we need to negate forward
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support

# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [d - 1 if d < 0 else d for d in dim]
return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False)


@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
def aten__fft_c2r(
self: TFloat,
dim: Sequence[int],
normalization: int,
last_dim_size: INT64, # pylint: disable=unused-argument
) -> TFloat:
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
Complex to real inverse FFT.
"""

# TODO(justinchuby): Figure out what last_dim_size does

self_rank = len(self.shape)
# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
# Take only the real part
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])

return op.Squeeze(real_part, axes=[-1])


@torch_op("aten::_fft_r2c", trace_only=True)
def aten__fft_r2c(
self: TFloat, dim: Sequence[int], normalization: int, onesided: bool
) -> TFloat:
"""_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
Real to complex forward FFT.
"""

# Add a new dimension at the end
signal = op.Unsqueeze(self, axes=[-1])
# No need to fill the imaginary part because ONNX DFT accepts real inputs
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs

self_rank = len(self.shape)
# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]

return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided)


def aten_fft_fft(
self: TensorType, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None
) -> TensorType:
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def torch_op(
trace_only: Whether the function should only be traced and not compiled.
private: Whether the function is private (not directly exposed). It should
be true for all functions with names starting with "_".
complex: Whether the function supports complex.
complex: Whether the function expects complex-valued inputs.
"""
if registry is None:
registry = default_registry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def create_mismatch_report(
expected,
error: Exception,
) -> None:
torch.set_printoptions(threshold=sys.maxsize)

error_text = str(error)
error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
short_test_name = test_name.split(".")[-1]
Expand Down
69 changes: 69 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,68 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
)


def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
del self # Unused
# Adapted from https://github.com/pytorch/pytorch/blob/01069ad4be449f376cf88a56d842b8eb50f6e9b6/torch/testing/_internal/opinfo/core.py#L2448C1-L2541C79
is_fp16_or_chalf = dtype in (torch.complex32, torch.half)
if not is_fp16_or_chalf:
nd_tensor = functools.partial(
opinfo_core.make_tensor,
(S, S + 1, S + 2),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
oned_tensor = functools.partial(
opinfo_core.make_tensor,
(31,),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
else:
low = None
high = None
shapes = ((2, 8, 9), (33,))

nd_tensor = functools.partial(
opinfo_core.make_tensor,
shapes[0],
device=device,
low=low,
high=high,
dtype=dtype,
requires_grad=requires_grad,
)
oned_tensor = functools.partial(
opinfo_core.make_tensor,
shapes[1],
device=device,
low=low,
high=high,
dtype=dtype,
requires_grad=requires_grad,
)

for normalization, forward in itertools.product((0, 1, 2), (True, False)):
# 1-D
yield opinfo_core.SampleInput(
oned_tensor(), dim=(0,), normalization=normalization, forward=forward
)
# N-D
for dim in [
(0,),
(1,),
(2,),
(1, 2),
(0, 1),
(0, 1, 2),
]:
yield opinfo_core.SampleInput(
nd_tensor(), dim=dim, normalization=normalization, forward=forward
)


def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):
del op_info # unused
del kwargs
Expand Down Expand Up @@ -1242,6 +1304,13 @@ def sample_inputs_scaled_dot_product_flash_attention(
# To avoid name duplication, it is possible to rename the OpInfo and specify
# the `op` field explicitly.
OP_DB: List[opinfo_core.OpInfo] = [
opinfo_core.OpInfo(
"ops.aten._fft_c2c",
aten_name="_fft_c2c",
dtypes=common_dtype.complex_types(),
sample_inputs_func=sample_inputs__fft_c2c,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten._local_scalar_dense",
aten_name="_local_scalar_dense",
Expand Down
8 changes: 8 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_lib.ops import core as core_ops
from onnxscript.function_libs.torch_lib.ops import fft as fft_ops
from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops
from onnxscript.function_libs.torch_lib.ops import nn as nn_ops
from onnxscript.function_libs.torch_lib.ops import special as special_ops
Expand Down Expand Up @@ -450,6 +451,13 @@ def _where_input_wrangler(
# Ops to be tested for numerical consistency between onnx and pytorch
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
TorchLibOpInfo(
"ops.aten._fft_c2c", # Custom from extra_opinfo
fft_ops.aten__fft_c2c,
tolerance={torch.complex64: (3e-3, 1.8e-4)},
trace_only=True,
complex=True,
),
TorchLibOpInfo(
"ops.aten._local_scalar_dense",
core_ops.aten__local_scalar_dense,
Expand Down

0 comments on commit 70843ef

Please sign in to comment.