From 49bd447f28c76af2ba95f2625dc7f8e6f2b7d98e Mon Sep 17 00:00:00 2001 From: Jay Zhang <36183870+fatcat-z@users.noreply.github.com> Date: Tue, 28 Nov 2023 14:07:00 +0800 Subject: [PATCH] Create tests for _fft_r2c | test(torchlib) (#1149) Update test for _fft_r2c op and fixed its implementation. --------- Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/fft.py | 17 ++++-- .../function_libs/torch_lib/extra_opinfo.py | 56 +++++++++++++++---- .../function_libs/torch_lib/ops_test_data.py | 6 ++ tools/onnx2script.py | 2 +- 4 files changed, 65 insertions(+), 16 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index cf8ec866b..f5f9b3bdb 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -95,11 +95,18 @@ def _fftn_onnx( # dimension at the beginning to represent the batch dimension. transformed = op.Unsqueeze(self, axes=[0]) - for dim_ in dims: - if dim_ >= 0: - # Add 1 to account for the batch dimension when counting axes from the left - dim_ = dim_ + 1 - transformed = op.DFT(transformed, axis=dim_, inverse=inverse, onesided=onesided) + # Add 1 to account for the batch dimension when counting axes from the left + new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] + + for dim in new_dims[:-1]: + transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False) + + # Torch computes one-sided FFT on the last dimension only. 
+ if onesided: + transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True) + else: + transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False) + # Remove the batch dimension transformed = op.Squeeze(transformed, axes=[0]) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index e2e5b52e9..1ee6a4a92 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -190,21 +190,20 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs): ) -def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_): - del self # Unused +def _prepare_data_for_fft_ops(device, dtype, requires_grad=False): # Adapted from https://github.com/pytorch/pytorch/blob/01069ad4be449f376cf88a56d842b8eb50f6e9b6/torch/testing/_internal/opinfo/core.py#L2448C1-L2541C79 is_fp16_or_chalf = dtype in (torch.complex32, torch.half) if not is_fp16_or_chalf: - nd_tensor = functools.partial( + oned_tensor = functools.partial( opinfo_core.make_tensor, - (S, S + 1, S + 2), + (31,), device=device, dtype=dtype, requires_grad=requires_grad, ) - oned_tensor = functools.partial( + nd_tensor = functools.partial( opinfo_core.make_tensor, - (31,), + (S, S + 1, S + 2), device=device, dtype=dtype, requires_grad=requires_grad, @@ -214,18 +213,18 @@ def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_): high = None shapes = ((2, 8, 9), (33,)) - nd_tensor = functools.partial( + oned_tensor = functools.partial( opinfo_core.make_tensor, - shapes[0], + shapes[1], device=device, low=low, high=high, dtype=dtype, requires_grad=requires_grad, ) - oned_tensor = functools.partial( + nd_tensor = functools.partial( opinfo_core.make_tensor, - shapes[1], + shapes[0], device=device, low=low, high=high, @@ -233,6 +232,13 @@ def sample_inputs__fft_c2c(self, device, dtype, 
requires_grad=False, **_): requires_grad=requires_grad, ) + return oned_tensor, nd_tensor + + +def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_): + del self # Unused + oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad) + for normalization, forward in itertools.product((0, 1, 2), (True, False)): # 1-D yield opinfo_core.SampleInput( @@ -252,6 +258,29 @@ def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_): ) +def sample_inputs__fft_r2c(self, device, dtype, requires_grad=False, **_): + del self # Unused + oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad) + + for normalization, one_sided in itertools.product((0, 1, 2), (True, False)): + # 1-D + yield opinfo_core.SampleInput( + oned_tensor(), dim=(0,), normalization=normalization, onesided=one_sided + ) + # N-D + for dim in [ + (0,), + (1,), + (2,), + (1, 2), + (0, 1), + (0, 1, 2), + ]: + yield opinfo_core.SampleInput( + nd_tensor(), dim=dim, normalization=normalization, onesided=one_sided + ) + + def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs): del op_info # unused del kwargs @@ -1358,6 +1387,13 @@ def sample_inputs__native_batch_norm_legit_no_stats( sample_inputs_func=sample_inputs__fft_c2c, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._fft_r2c", + aten_name="_fft_r2c", + dtypes=common_dtype.floating_types(), + sample_inputs_func=sample_inputs__fft_r2c, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten._local_scalar_dense", aten_name="_local_scalar_dense", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index a2a3bebd8..b6dc88ace 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -465,6 +465,12 @@ def _where_input_wrangler( trace_only=True, complex=True, ), + TorchLibOpInfo( + "ops.aten._fft_r2c", # 
Custom from extra_opinfo + fft_ops.aten__fft_r2c, + tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)}, + trace_only=True, + ), TorchLibOpInfo( "ops.aten._local_scalar_dense", core_ops.aten__local_scalar_dense, diff --git a/tools/onnx2script.py b/tools/onnx2script.py index 5d82e2a35..24556e755 100644 --- a/tools/onnx2script.py +++ b/tools/onnx2script.py @@ -55,7 +55,7 @@ def convert2script( "--verbose", action="store_true", help="Verbose mode, suppresses use of overloaded operators and inline constants", - default=False + default=False, ) args = parser.parse_args()