Create tests for _fft_r2c | test(torchlib) (#1149)
Update the test for the _fft_r2c op and fix its implementation.

---------

Co-authored-by: Justin Chu <[email protected]>
fatcat-z and justinchuby authored Nov 28, 2023
1 parent 1567800 commit 49bd447
Showing 4 changed files with 65 additions and 16 deletions.
17 changes: 12 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/fft.py
@@ -95,11 +95,18 @@ def _fftn_onnx(
     # dimension at the beginning to represent the batch dimension.
     transformed = op.Unsqueeze(self, axes=[0])

-    for dim_ in dims:
-        if dim_ >= 0:
-            # Add 1 to account for the batch dimension when counting axes from the left
-            dim_ = dim_ + 1
-        transformed = op.DFT(transformed, axis=dim_, inverse=inverse, onesided=onesided)
+    # Add 1 to account for the batch dimension when counting axes from the left
+    new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
+
+    for dim in new_dims[:-1]:
+        transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)
+
+    # Torch computes one-sided FFT on the last dimension only.
+    if onesided:
+        transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
+    else:
+        transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)

     # Remove the batch dimension
     transformed = op.Squeeze(transformed, axes=[0])

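The rewrite matches torch.fft.rfftn semantics: the one-sided (half-spectrum) output is produced along the last transformed dimension only, while every other dimension receives a full complex transform. A minimal NumPy sketch (not part of the commit) of that equivalence:

```python
import numpy as np

x = np.random.rand(3, 4, 5)

# One-sided transform on the last axis first (rfft expects real input),
# then full transforms on the remaining axes.
step = np.fft.rfft(x, axis=2)
step = np.fft.fft(step, axis=0)
step = np.fft.fft(step, axis=1)

# rfftn applies rfft along the last axis and fft along the others,
# so the two results agree.
assert np.allclose(np.fft.rfftn(x, axes=(0, 1, 2)), step)
```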
56 changes: 46 additions & 10 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -190,21 +190,20 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
     )


-def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
-    del self  # Unused
+def _prepare_data_for_fft_ops(device, dtype, requires_grad=False):
     # Adapted from https://github.com/pytorch/pytorch/blob/01069ad4be449f376cf88a56d842b8eb50f6e9b6/torch/testing/_internal/opinfo/core.py#L2448C1-L2541C79
     is_fp16_or_chalf = dtype in (torch.complex32, torch.half)
     if not is_fp16_or_chalf:
-        nd_tensor = functools.partial(
+        oned_tensor = functools.partial(
             opinfo_core.make_tensor,
-            (S, S + 1, S + 2),
+            (31,),
             device=device,
             dtype=dtype,
             requires_grad=requires_grad,
         )
-        oned_tensor = functools.partial(
+        nd_tensor = functools.partial(
             opinfo_core.make_tensor,
-            (31,),
+            (S, S + 1, S + 2),
             device=device,
             dtype=dtype,
             requires_grad=requires_grad,
@@ -214,25 +213,32 @@ def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
         high = None
         shapes = ((2, 8, 9), (33,))

-        nd_tensor = functools.partial(
+        oned_tensor = functools.partial(
             opinfo_core.make_tensor,
-            shapes[0],
+            shapes[1],
             device=device,
             low=low,
             high=high,
             dtype=dtype,
             requires_grad=requires_grad,
         )
-        oned_tensor = functools.partial(
+        nd_tensor = functools.partial(
             opinfo_core.make_tensor,
-            shapes[1],
+            shapes[0],
             device=device,
             low=low,
             high=high,
             dtype=dtype,
             requires_grad=requires_grad,
         )

+    return oned_tensor, nd_tensor
+
+
+def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
+    del self  # Unused
+    oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad)
+
     for normalization, forward in itertools.product((0, 1, 2), (True, False)):
         # 1-D
         yield opinfo_core.SampleInput(
@@ -252,6 +258,29 @@ def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
         )


+def sample_inputs__fft_r2c(self, device, dtype, requires_grad=False, **_):
+    del self  # Unused
+    oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad)
+
+    for normalization, one_sided in itertools.product((0, 1, 2), (True, True)):
+        # 1-D
+        yield opinfo_core.SampleInput(
+            oned_tensor(), dim=(0,), normalization=normalization, onesided=one_sided
+        )
+        # N-D
+        for dim in [
+            (0,),
+            (1,),
+            (2,),
+            (1, 2),
+            (0, 1),
+            (0, 1, 2),
+        ]:
+            yield opinfo_core.SampleInput(
+                nd_tensor(), dim=dim, normalization=normalization, onesided=one_sided
+            )
+
+
 def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):
     del op_info  # unused
     del kwargs
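Each sample generated by sample_inputs__fft_r2c corresponds to one call of the aten op under test. A hypothetical spot-check (not part of the commit) of what the parameters mean: dim lists the axes to transform, normalization is torch's integer scaling mode (0, 1 or 2), and onesided keeps only the non-redundant half of the spectrum along the last axis in dim.

```python
import torch

x = torch.randn(5, 6, 7)  # real input; _fft_r2c returns a complex tensor

out = torch.ops.aten._fft_r2c(x, dim=[1, 2], normalization=1, onesided=True)

print(out.shape)  # torch.Size([5, 6, 4]): last transformed dim becomes 7 // 2 + 1
print(out.dtype)  # torch.complex64 for float32 input
```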
@@ -1358,6 +1387,13 @@ def sample_inputs__native_batch_norm_legit_no_stats(
         sample_inputs_func=sample_inputs__fft_c2c,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._fft_r2c",
+        aten_name="_fft_r2c",
+        dtypes=common_dtype.floating_types(),
+        sample_inputs_func=sample_inputs__fft_r2c,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten._local_scalar_dense",
         aten_name="_local_scalar_dense",
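The OpInfo entry is what makes the custom samples visible to the test runner. A rough sketch of how such an entry is consumed, assuming the module's registry list is named OP_DB as referenced from ops_test_data.py:

```python
import torch

from onnxscript.tests.function_libs.torch_lib import extra_opinfo

# Look up the entry registered above and run the eager op on its samples.
op_info = next(op for op in extra_opinfo.OP_DB if op.name == "ops.aten._fft_r2c")
for sample in op_info.sample_inputs("cpu", torch.float32, requires_grad=False):
    expected = torch.ops.aten._fft_r2c(sample.input, *sample.args, **sample.kwargs)
```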
6 changes: 6 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -465,6 +465,12 @@ def _where_input_wrangler(
         trace_only=True,
         complex=True,
     ),
+    TorchLibOpInfo(
+        "ops.aten._fft_r2c",  # Custom from extra_opinfo
+        fft_ops.aten__fft_r2c,
+        tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)},
+        trace_only=True,
+    ),
     TorchLibOpInfo(
         "ops.aten._local_scalar_dense",
         core_ops.aten__local_scalar_dense,
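The tolerance mapping is per dtype; assuming each tuple is (rtol, atol), the comparison the harness performs for float32 inputs is roughly the following sketch, with the ONNX-side result stubbed out:

```python
import numpy as np
import torch

x = torch.randn(31)
expected = torch.ops.aten._fft_r2c(x, dim=[0], normalization=0, onesided=True)

# In the real test, actual comes from running the exported ONNX function
# (fft_ops.aten__fft_r2c); the eager result is reused here only to show the check.
actual = expected.clone()

np.testing.assert_allclose(actual.numpy(), expected.numpy(), rtol=3e-2, atol=3e-4)
```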
2 changes: 1 addition & 1 deletion tools/onnx2script.py
@@ -55,7 +55,7 @@ def convert2script(
         "--verbose",
         action="store_true",
         help="Verbose mode, suppresses use of overloaded operators and inline constants",
-        default=False
+        default=False,
     )

     args = parser.parse_args()
