From da65bfff2aa73ee651aff2392521aa05e4d9db46 Mon Sep 17 00:00:00 2001
From: Brian Ward
Date: Fri, 13 Oct 2023 10:22:57 -0400
Subject: [PATCH 1/2] Remove unused nn module

---
 pytorch_finufft/__init__.py |   4 +-
 pytorch_finufft/nn.py       | 125 ------------------------------------
 2 files changed, 2 insertions(+), 127 deletions(-)
 delete mode 100644 pytorch_finufft/nn.py

diff --git a/pytorch_finufft/__init__.py b/pytorch_finufft/__init__.py
index 8a4bbf5..7ef2cd2 100644
--- a/pytorch_finufft/__init__.py
+++ b/pytorch_finufft/__init__.py
@@ -1,6 +1,6 @@
 """Pytorch bindings for the FINUFFT Library"""
 
-from . import functional, nn
+from . import functional
 
-__all__ = ["functional", "nn"]
+__all__ = ["functional"]
 __version__ = "0.1.0"

diff --git a/pytorch_finufft/nn.py b/pytorch_finufft/nn.py
deleted file mode 100644
index 496d996..0000000
--- a/pytorch_finufft/nn.py
+++ /dev/null
@@ -1,125 +0,0 @@
-from typing import Optional, Union
-
-import torch
-
-import pytorch_finufft.functional as func
-
-
-class Finufft1D1(torch.nn.Module):
-    def __init__(
-        self,
-    ):
-        """
-        TODO
-        """
-        super().__init__()
-
-    def forward(
-        self,
-        points: torch.Tensor,
-        values: torch.Tensor,
-        output_shape: int,
-        out: Optional[torch.Tensor] = None,
-        fftshift: Optional[bool] = False,
-        **finufftkwargs: Union[int, float],
-    ) -> torch.Tensor:
-        """
-        Evaluates the Type 1 NUFFT on the inputs.
-
-        NOTE: By default, the ordering is set to match that of Pytorch,
-        Numpy, and Scipy's FFT APIs. To match the mode ordering
-        native to FINUFFT, set `fftshift = True`.
-        ```
-                M-1
-        f[k1] = SUM c[j] exp(+/-i k1 x(j))
-                j=0
-
-        for -N1/2 <= k1 <= (N1-1)/2
-        ```
-
-        Parameters
-        ----------
-        points : torch.Tensor
-            The non-uniform points x_j. Valid only between -3pi and 3pi.
-        values : torch.Tensor
-            The source strengths c_j.
-        output_shape : int
-            Number of Fourier modes to use in the computation (which
-            coincides with the length of the resultant array).
-        out : Optional[torch.Tensor], optional
-            Array to populate with result in-place, by default None
-        fftshift : bool, optional
-            If True, centers the 0 mode in the resultant torch.Tensor; by default False
-        **finufftkwargs : Union[int, float]
-            Additional arguments will be passed into FINUFFT. See
-            https://finufft.readthedocs.io/en/latest/python.html
-
-
-        Returns
-        -------
-        torch.Tensor
-            The resultant array f[k]
-        """
-        return func.finufft1D1.apply(
-            points, values, output_shape, out, fftshift, **finufftkwargs
-        )
-
-
-class Finufft1D2(torch.nn.Module):
-    def __init__(
-        self,
-    ):
-        """
-        TODO
-        """
-        super().__init__()
-
-    # TODO
-
-    def forward(
-        self,
-        points: torch.Tensor,
-        values: torch.Tensor,
-        output_shape: int,
-        out: Optional[torch.Tensor] = None,
-        fftshift: Optional[bool] = False,
-        **finufftkwargs: Union[int, float],
-    ) -> torch.Tensor:
-        """
-        Evaluates the Type 2 NUFFT on the inputs.
-
-        NOTE: By default, the ordering is set to match that of Pytorch,
-        Numpy, and Scipy's FFT APIs. To match the mode ordering
-        native to FINUFFT, set `fftshift = True`.
-        ```
-        c[j] = SUM f[k1] exp(+/-i k1 x(j))
-               k1
-
-        for j = 0, ..., M-1, where the sum is over -N1/2 <= k1 <= (N1-1)/2
-        ```
-
-        Parameters
-        ----------
-        points : torch.Tensor
-            The non-uniform points x_j. Valid only between -3pi and 3pi.
-        values : torch.Tensor
-            The source strengths c_j.
-        output_shape : int
-            Number of Fourier modes to use in the computation
-        out : Optional[torch.Tensor], optional
-            _description_, by default None
-        fftshift : bool, optional
-            _description_, by default False
-        **finufftkwargs : Union[int, float]
-            Additional arguments will be passed into FINUFFT. See
-            https://finufft.readthedocs.io/en/latest/python.html
-
-        Returns
-        -------
-        torch.Tensor
-            _description_
-        """
-        return func.finufft1D2.apply(
-            points, values, output_shape, out, fftshift, **finufftkwargs
-        )

From 3a6ea1736097d9be7feb676e52bc546ef60ddab0 Mon Sep 17 00:00:00 2001
From: Brian Ward
Date: Fri, 13 Oct 2023 10:32:19 -0400
Subject: [PATCH 2/2] Remove old t2 functions, all tests/lints passing

---
 pytorch_finufft/checks.py         |   4 +-
 pytorch_finufft/functional.py     | 723 ++----------------------------
 tests/test_1d/test_backward_1d.py |  77 +---
 tests/test_1d/test_forward_1d.py  |  36 +-
 tests/test_2d/test_backward_2d.py | 106 +----
 tests/test_2d/test_forward_2d.py  |  59 +--
 tests/test_3d/test_backward_3d.py | 149 +-----
 tests/test_3d/test_forward_3d.py  |  42 +-
 tests/test_errors.py              |   2 +-
 9 files changed, 55 insertions(+), 1143 deletions(-)

diff --git a/pytorch_finufft/checks.py b/pytorch_finufft/checks.py
index 1dd439a..95431a4 100644
--- a/pytorch_finufft/checks.py
+++ b/pytorch_finufft/checks.py
@@ -87,9 +87,7 @@ def check_output_shape(ndim: int, output_shape: Union[int, Tuple[int, ...]]) ->
 _COORD_CHAR_TABLE = "xyz"
 
 
-def _type2_checks(
-    points_tuple: Tuple[torch.Tensor, ...], targets: torch.Tensor
-) -> None:
+def _type2_checks(points_tuple: torch.Tensor, targets: torch.Tensor) -> None:
     """
     Performs all type, precision, size, device, ... checks for the type 2
     FINUFFT

diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py
index affb51a..0109b66 100644
--- a/pytorch_finufft/functional.py
+++ b/pytorch_finufft/functional.py
@@ -4,7 +4,6 @@
 
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
-import numpy as np
 import torch
 
 try:
@@ -29,667 +28,7 @@
 import pytorch_finufft.checks as checks
 
-###############################################################################
-# 1d Functions
-###############################################################################
-
-
-class finufft1D2(torch.autograd.Function):
-    """
-    FINUFFT 1d Problem type 2
-    """
-
-    @staticmethod
-    def forward(
-        ctx: Any,
-        points: torch.Tensor,
-        targets: torch.Tensor,
-        out: Optional[torch.Tensor] = None,
-        fftshift: bool = False,
-        finufftkwargs: Dict[str, Union[int, float]] = {},
-    ) -> torch.Tensor:
-        """
-        Evaluates the Type 2 NUFFT on the inputs.
-
-        NOTE: By default, the ordering is set to match that of Pytorch,
-        Numpy, and Scipy's FFT APIs. To match the mode ordering
-        native to FINUFFT, set fftshift=True.
-
-        ```
-        c[j] = SUM f[k1] exp(+/-i k1 x(j))
-               k1
-
-        for j = 0, ..., M-1, where the sum is over -N1/2 <= k1 <= (N1-1)/2
-        ```
-
-        Parameters
-        ----------
-        ctx : Any
-            PyTorch context object
-        points : torch.Tensor
-            The non-uniform points x_j. Valid only between -3pi and 3pi.
-        targets : torch.Tensor
-            The target Fourier mode coefficients f_k.
-        out : Optional[torch.Tensor], optional
-            Array to take the result in-place, by default None
-        fftshift : bool
-            If True centers the 0 mode in the resultant array, by default False
-        finufftkwargs : Dict[str, Union[int, float]]
-            Additional arguments will be passed into FINUFFT. See
-            https://finufft.readthedocs.io/en/latest/python.html. By default
-            an empty dictionary
-
-        Returns
-        -------
-        torch.Tensor
-            The resultant array c[j]
-
-        Raises
-        ------
-        ValueError
-            In the case that the mode ordering is double-specified with both
-            fftshift and the kwarg modeord (only one should be provided).
-        """
-        if out is not None:
-            print("In-place results are not yet implemented")
-
-        checks._type2_checks((points,), targets)
-
-        finufftkwargs = {k: v for k, v in finufftkwargs.items()}
-        _mode_ordering = finufftkwargs.pop("modeord", 1)
-        _i_sign = finufftkwargs.pop("isign", 1)
-
-        if fftshift:
-            if _mode_ordering != 1:
-                raise ValueError(
-                    "Double specification of ordering; only one of fftshift "
-                    "and modeord should be provided"
-                )
-            _mode_ordering = 0
-
-        ctx.isign = _i_sign
-        ctx.mode_ordering = _mode_ordering
-        ctx.fftshift = fftshift
-        ctx.finufftkwargs = finufftkwargs
-
-        ctx.save_for_backward(points, targets)
-
-        finufft_out = finufft.nufft1d2(
-            points.data.numpy(),
-            targets.data.numpy(),
-            modeord=_mode_ordering,
-            isign=_i_sign,
-            **finufftkwargs,
-        )
-
-        return torch.from_numpy(finufft_out)
-
-    @staticmethod
-    def backward(
-        ctx: Any, grad_output: torch.Tensor
-    ) -> Tuple[Union[torch.Tensor, None], ...]:
-        """
-        Implements derivatives wrt. each argument in the forward method
-
-        Parameters
-        ----------
-        ctx : Any
-            PyTorch context object
-        grad_output : torch.Tensor
-            Backpass gradient output
-
-        Returns
-        -------
-        Tuple[Union[torch.Tensor, None], ...]
-            Tuple of derivatives wrt. each argument in the forward method
-        """
-        _i_sign = ctx.isign
-        _mode_ordering = ctx.mode_ordering
-        finufftkwargs = ctx.finufftkwargs
-
-        points, targets = ctx.saved_tensors
-
-        grad_points = grad_targets = None
-
-        if ctx.needs_input_grad[0]:
-            # w.r.t. the points x_j
-
-            k_ramp = torch.arange(0, targets.shape[-1], dtype=points.dtype) - (
-                targets.shape[-1] // 2
-            )
-            if _mode_ordering != 0:
-                k_ramp = torch.fft.ifftshift(k_ramp)
-
-            # TODO analytically work out if we can simplify this *1j,
-            # the below conj, and below *values
-            ramped_targets = k_ramp * targets * 1j * _i_sign
-
-            np_points = (points.data).numpy()
-            np_ramped_targets = (ramped_targets.data).numpy()
-
-            grad_points = torch.from_numpy(
-                finufft.nufft1d2(
-                    np_points,
-                    np_ramped_targets,
-                    isign=_i_sign,
-                    modeord=_mode_ordering,
-                    **finufftkwargs,
-                )
-            ).to(targets.dtype)
-
-            grad_points = grad_points.conj()
-            grad_points *= grad_output
-
-            grad_points = grad_points.real
-
-        if ctx.needs_input_grad[1]:
-            np_points = points.data.numpy()
-            np_grad_output = grad_output.data.numpy()
-
-            grad_targets = torch.from_numpy(
-                finufft.nufft1d1(
-                    np_points,
-                    np_grad_output,
-                    len(targets),
-                    modeord=_mode_ordering,
-                    isign=(-1 * _i_sign),
-                    **finufftkwargs,
-                )
-            )
-
-        return grad_points, grad_targets, None, None, None
-
-
-###############################################################################
-# 2d Functions
-###############################################################################
-
-
-class finufft2D2(torch.autograd.Function):
-    """
-    FINUFFT 2D problem type 2
-    """
-
-    @staticmethod
-    def forward(
-        ctx: Any,
-        points_x: torch.Tensor,
-        points_y: torch.Tensor,
-        targets: torch.Tensor,
-        out: Optional[torch.Tensor] = None,
-        fftshift: bool = False,
-        finufftkwargs: Dict[str, Union[int, float]] = {},
-    ) -> torch.Tensor:
-        """
-        Evaluates the Type 2 NUFFT on the inputs.
-
-        NOTE: By default, the ordering is set to match that of Pytorch,
-        Numpy, and Scipy's FFT APIs. To match the mode ordering
-        native to FINUFFT, set fftshift=True.
-
-        ```
-        c[j] = SUM f[k1, k2] exp(+/-i (k1 x(j) + k2 y(j)))
-               k1, k2
-
-        for j = 0, ..., M-1, where the sum is over -N1/2 <= k1 <= (N1-1)/2,
-        -N2/2 <= k2 <= (N2-1)/2
-        ```
-
-        Parameters
-        ----------
-        ctx : Any
-            Pytorch context object
-        points_x : torch.Tensor
-            The non-uniform points x_j
-        points_y : torch.Tensor
-            The non-uniform points y_j
-        targets : torch.Tensor
-            The target Fourier mode coefficients f[k1, k2]
-        out : Optional[torch.Tensor], optional
-            Array to take the result in-place, by default None
-        fftshift : bool
-            If True centers the 0 mode in the resultant torch.Tensor, by default False
-        finufftkwargs : Dict[str, Union[int, float]]
-            Additional arguments will be passed into FINUFFT. See
-            https://finufft.readthedocs.io/en/latest/python.html. By default
-            an empty dictionary
-
-        Returns
-        -------
-        torch.Tensor
-            The resultant array c[j]
-
-        Raises
-        ------
-        ValueError
-            In the case of conflicting specification of the wave-mode ordering.
-        """
-
-        if out is not None:
-            print("In-place results are not yet implemented")
-
-        # TODO -- extend checks to 2d
-        checks._type2_checks((points_x, points_y), targets)
-
-        finufftkwargs = {k: v for k, v in finufftkwargs.items()}
-        _mode_ordering = finufftkwargs.pop("modeord", 1)
-        _i_sign = finufftkwargs.pop("isign", 1)
-
-        if fftshift:
-            if _mode_ordering != 1:
-                raise ValueError(
-                    "Double specification of ordering; only one of fftshift and "
-                    "modeord should be provided."
-                )
-            _mode_ordering = 0
-
-        ctx.isign = _i_sign
-        ctx.mode_ordering = _mode_ordering
-        ctx.fftshift = fftshift
-        ctx.finufftkwargs = finufftkwargs
-
-        ctx.save_for_backward(points_x, points_y, targets)
-
-        finufft_out = finufft.nufft2d2(
-            points_x.data.numpy(),
-            points_y.data.numpy(),
-            targets.data.numpy(),
-            modeord=_mode_ordering,
-            isign=_i_sign,
-            **finufftkwargs,
-        )
-
-        return torch.from_numpy(finufft_out)
-
-    @staticmethod
-    def backward(
-        ctx: Any, grad_output: torch.Tensor
-    ) -> Tuple[
-        Union[torch.Tensor, None],
-        Union[torch.Tensor, None],
-        Union[torch.Tensor, None],
-        None,
-        None,
-        None,
-    ]:
-        """
-        Implements derivatives wrt. each argument in the forward method.
-
-        Parameters
-        ----------
-        ctx : Any
-            Pytorch context object
-        grad_output : torch.Tensor
-            Backpass gradient output.
-
-        Returns
-        -------
-        Tuple[ Union[torch.Tensor, None], ...]
-            A tuple of derivatives wrt. each argument in the forward method
-        """
-        _i_sign = ctx.isign
-        _mode_ordering = ctx.mode_ordering
-        finufftkwargs = ctx.finufftkwargs
-
-        points_x, points_y, targets = ctx.saved_tensors
-
-        x_ramp = torch.arange(0, targets.shape[0], dtype=points_x.dtype) - (
-            targets.shape[0] // 2
-        )
-        y_ramp = torch.arange(0, targets.shape[1], dtype=points_y.dtype) - (
-            targets.shape[1] // 2
-        )
-        XX, YY = torch.meshgrid(x_ramp, y_ramp)
-
-        grad_points_x = grad_points_y = grad_targets = None
-
-        if ctx.needs_input_grad[0]:
-            # wrt. points_x
-            if _mode_ordering != 0:
-                XX = torch.fft.ifftshift(XX)
-
-            # TODO analytically work out if we can simplify this *1j,
-            # the below conj, and below *values
-            ramped_targets = XX * targets * 1j * _i_sign
-
-            np_points_x = points_x.data.numpy()
-            np_points_y = points_y.data.numpy()
-            np_ramped_targets = ramped_targets.data.numpy()
-
-            grad_points_x = torch.from_numpy(
-                finufft.nufft2d2(
-                    np_points_x,
-                    np_points_y,
-                    np_ramped_targets,
-                    isign=_i_sign,
-                    modeord=_mode_ordering,
-                    **finufftkwargs,
-                )
-            ).to(targets.dtype)
-
-            grad_points_x = grad_points_x.conj()
-            grad_points_x *= grad_output
-
-            grad_points_x = grad_points_x.real
-
-        if ctx.needs_input_grad[1]:
-            # wrt. points_y
-
-            if _mode_ordering != 0:
-                YY = torch.fft.ifftshift(YY)
-
-            # TODO analytically work out if we can simplify this *1j,
-            # the below conj, and below *values
-            ramped_targets = YY * targets * 1j * _i_sign
-
-            np_points_x = points_x.data.numpy()
-            np_points_y = points_y.data.numpy()
-            np_ramped_targets = ramped_targets.data.numpy()
-
-            grad_points_y = torch.from_numpy(
-                finufft.nufft2d2(
-                    np_points_x,
-                    np_points_y,
-                    np_ramped_targets,
-                    isign=_i_sign,
-                    modeord=_mode_ordering,
-                    **finufftkwargs,
-                )
-            ).to(targets.dtype)
-
-            grad_points_y = grad_points_y.conj()
-            grad_points_y *= grad_output
-
-            grad_points_y = grad_points_y.real
-
-        if ctx.needs_input_grad[2]:
-            # wrt. targets
-
-            np_points_x = points_x.data.numpy()
-            np_points_y = points_y.data.numpy()
-
-            np_grad_output = grad_output.data.numpy()
-
-            grad_targets = torch.from_numpy(
-                finufft.nufft2d1(
-                    np_points_x,
-                    np_points_y,
-                    np_grad_output,
-                    len(targets),
-                    modeord=_mode_ordering,
-                    isign=(-1 * _i_sign),
-                    **finufftkwargs,
-                )
-            )
-
-        return (
-            grad_points_x,
-            grad_points_y,
-            grad_targets,
-            None,
-            None,
-            None,
-        )
-
-
-###############################################################################
-# 3d Functions
-###############################################################################
-
-
-class finufft3D2(torch.autograd.Function):
-    """
-    FINUFFT 3D problem type 2
-    """
-
-    @staticmethod
-    def forward(
-        ctx: Any,
-        points_x: torch.Tensor,
-        points_y: torch.Tensor,
-        points_z: torch.Tensor,
-        targets: torch.Tensor,
-        out: Optional[torch.Tensor] = None,
-        fftshift: bool = False,
-        finufftkwargs: Dict[str, Union[int, float]] = {},
-    ) -> torch.Tensor:
-        """
-        Evaluates the Type 2 NUFFT on the inputs.
-
-        NOTE: By default, the ordering is set to match that of Pytorch,
-        Numpy, and Scipy's FFT APIs. To match the mode ordering
-        native to FINUFFT, set fftshift=True.
-        ```
-        c[j] = SUM f[k1, k2, k3] exp(+/-i (k1 x(j) + k2 y(j) + k3 z(j)))
-               k1, k2, k3
-
-        for j = 0, ..., M-1, where the sum is over -N1/2 <= k1 <= (N1-1)/2,
-        -N2/2 <= k2 <= (N2-1)/2, -N3/2 <= k3 <= (N3-1)/2
-        ```
-
-        Parameters
-        ----------
-        ctx : Any
-            Pytorch context object
-        points_x : torch.Tensor
-            The non-uniform points x_j. Valid only between -3pi and 3pi.
-        points_y : torch.Tensor
-            The non-uniform points y_j. Valid only between -3pi and 3pi.
-        points_z : torch.Tensor
-            The non-uniform points z_j. Valid only between -3pi and 3pi.
-        targets : torch.Tensor
-            The target Fourier mode coefficients f_{k1, k2, k3}
-        out : Optional[torch.Tensor], optional
-            Array to use for in-place result, by default None
-        fftshift : bool
-            If True centers the 0 mode in the resultant array, by default False
-        finufftkwargs : Dict[str, Union[int, float]]
-            Additional arguments will be passed into FINUFFT. See
-            https://finufft.readthedocs.io/en/latest/python.html. By default
-            an empty dictionary
-
-        Returns
-        -------
-        torch.Tensor
-            The resultant array c[j]
-
-        Raises
-        ------
-        ValueError
-            In the case that the mode ordering is double-specified with both
-            fftshift and the kwarg modeord (only one should be provided).
-        """
-        if out is not None:
-            print("In-place results are not yet implemented")
-
-        checks._type2_checks((points_x, points_y, points_z), targets)
-
-        finufftkwargs = {k: v for k, v in finufftkwargs.items()}
-        _mode_ordering = finufftkwargs.pop("modeord", 1)
-        _i_sign = finufftkwargs.pop("isign", 1)
-
-        if fftshift:
-            if _mode_ordering != 1:
-                raise ValueError(
-                    "Double specification of ordering; only one of fftshift "
-                    "and modeord should be provided."
-                )
-            _mode_ordering = 0
-
-        ctx.isign = _i_sign
-        ctx.mode_ordering = _mode_ordering
-        ctx.fftshift = fftshift
-        ctx.finufftkwargs = finufftkwargs
-
-        ctx.save_for_backward(points_x, points_y, points_z, targets)
-
-        finufft_out = finufft.nufft3d2(
-            points_x.data.numpy(),
-            points_y.data.numpy(),
-            points_z.data.numpy(),
-            targets.data.numpy(),
-            modeord=_mode_ordering,
-            isign=_i_sign,
-            **finufftkwargs,
-        )
-
-        return torch.from_numpy(finufft_out)
-
-    @staticmethod
-    def backward(
-        ctx: Any, grad_output: torch.Tensor
-    ) -> Tuple[Union[torch.Tensor, None], ...]:
-        """
-        Implements derivatives wrt. each argument in the forward method
-
-        Parameters
-        ----------
-        ctx : Any
-            Pytorch context object
-        grad_output : torch.Tensor
-            Backpass gradient output
-
-        Returns
-        -------
-        Tuple[Union[torch.Tensor, None], ...]
-            Tuple of derivatives wrt. each argument in the forward method
-        """
-        _i_sign = ctx.isign
-        _mode_ordering = ctx.mode_ordering
-        finufftkwargs = ctx.finufftkwargs
-
-        points_x, points_y, points_z, targets = ctx.saved_tensors
-
-        x_ramp = torch.arange(0, targets.shape[0], dtype=points_x.dtype) - (
-            targets.shape[0] // 2
-        )
-        y_ramp = torch.arange(0, targets.shape[0], dtype=points_x.dtype) - (
-            targets.shape[0] // 2
-        )
-        z_ramp = torch.arange(0, targets.shape[0], dtype=points_x.dtype) - (
-            targets.shape[0] // 2
-        )
-        XX, YY, ZZ = torch.meshgrid(x_ramp, y_ramp, z_ramp)
-
-        grad_points_x = grad_points_y = grad_points_z = grad_values = None
-
-        if ctx.needs_input_grad[0]:
-            # wrt. points_x
-            if _mode_ordering != 0:
-                XX = torch.fft.ifftshift(XX)
-
-            # TODO analytically work out if we can simplify this *1j,
-            # the below conj, and below *values
-            ramped_targets = XX * targets * 1j * _i_sign
-
-            np_points_x = points_x.data.numpy()
-            np_points_y = points_y.data.numpy()
-            np_points_z = points_z.data.numpy()
-            np_ramped_targets = ramped_targets.data.numpy()
-
-            grad_points_x = torch.from_numpy(
-                finufft.nufft3d2(
-                    np_points_x,
-                    np_points_y,
-                    np_points_z,
-                    np_ramped_targets,
-                    isign=_i_sign,
-                    modeord=_mode_ordering,
-                    **finufftkwargs,
-                )
-            ).to(targets.dtype)
-
-            grad_points_x = (grad_points_x.conj() * grad_output).real
-
-        if ctx.needs_input_grad[1]:
-            # wrt. points_y
-            if _mode_ordering != 0:
-                YY = torch.fft.ifftshift(YY)
-
-            # TODO analytically work out if we can simplify this *1j,
-            # the below conj, and below *values
-            ramped_targets = YY * targets * 1j * _i_sign
-
-            np_points_x = points_x.data.numpy()
-            np_points_y = points_y.data.numpy()
-            np_points_z = points_z.data.numpy()
-            np_ramped_targets = ramped_targets.data.numpy()
-
-            grad_points_y = torch.from_numpy(
-                finufft.nufft3d2(
-                    np_points_x,
-                    np_points_y,
-                    np_points_z,
-                    np_ramped_targets,
-                    isign=_i_sign,
-                    modeord=_mode_ordering,
-                    **finufftkwargs,
-                )
-            ).to(targets.dtype)
-
-            grad_points_y = (grad_points_y.conj() * grad_output).real
-
-        if ctx.needs_input_grad[2]:
-            # wrt. points_z
-            if _mode_ordering != 0:
-                ZZ = torch.fft.ifftshift(ZZ)
-
-            # TODO analytically work out if we can simplify this *1j,
-            # the below conj, and below *values
-            ramped_targets = ZZ * targets * 1j * _i_sign
-
-            np_points_x = points_x.data.numpy()
-            np_points_y = points_y.data.numpy()
-            np_points_z = points_z.data.numpy()
-            np_ramped_targets = ramped_targets.data.numpy()
-
-            grad_points_z = torch.from_numpy(
-                finufft.nufft3d2(
-                    np_points_x,
-                    np_points_y,
-                    np_points_z,
-                    np_ramped_targets,
-                    isign=_i_sign,
-                    modeord=_mode_ordering,
-                    **finufftkwargs,
-                )
-            ).to(targets.dtype)
-
-            grad_points_z = (grad_points_z.conj() * grad_output).real
-
-        if ctx.needs_input_grad[3]:
-            np_points_x = points_x.data.numpy()
-            np_points_y = points_y.data.numpy()
-            np_points_z = points_z.data.numpy()
-            np_grad_output = grad_output.data.numpy()
-
-            grad_values = torch.from_numpy(
-                finufft.nufft3d1(
-                    np_points_x,
-                    np_points_y,
-                    np_points_z,
-                    np_grad_output,
-                    len(np_grad_output),
-                    isign=(-1 * _i_sign),
-                    modeord=_mode_ordering,
-                    **finufftkwargs,
-                )
-            )
-
-        return (
-            grad_points_x,
-            grad_points_y,
-            grad_points_z,
-            grad_values,
-            None,
-            None,
-            None,
-        )
-
-
-###############################################################################
-# Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1
-###############################################################################
+newaxis = None
 
 
 def get_nufft_func(
@@ -731,6 +70,7 @@ def coordinate_ramps(shape, device):
 
     return coord_ramps
 
+
 class finufft_type1(torch.autograd.Function):
     @staticmethod
     def forward(  # type: ignore[override]
@@ -822,7 +162,7 @@ def backward(  # type: ignore[override]
 
         # we can't batch in 1d case so we squeeze and fix up the output later
         ramped_grad_output = (
-            coord_ramps * grad_output[np.newaxis] * 1j * _i_sign
+            coord_ramps * grad_output[newaxis] * 1j * _i_sign
         ).squeeze()
         backprop_ramp = nufft_func(
             *points, ramped_grad_output, isign=_i_sign, **finufftkwargs
         )
@@ -847,19 +187,18 @@
         )
 
-
 class finufft_type2(torch.autograd.Function):
     """
     FINUFFT 2D problem type 2
     """
 
     @staticmethod
-    def forward(
+    def forward(  # type: ignore[override]
         ctx: Any,
         points: torch.Tensor,
         targets: torch.Tensor,
         out: Optional[torch.Tensor] = None,
-        finufftkwargs: Dict[str, Union[int, float]] = None,
+        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
    ) -> torch.Tensor:
         """
         Evaluates the Type 2 NUFFT on the inputs.
@@ -880,7 +219,7 @@ def forward(
             Array to take the result in-place, by default None
         finufftkwargs : Dict[str, Union[int, float]]
             Additional arguments will be passed into FINUFFT. See
-            https://finufft.readthedocs.io/en/latest/python.html.
+            https://finufft.readthedocs.io/en/latest/python.html.
 
         Returns
         -------
         torch.Tensor
             The resultant array c[j]
 
         Raises
         ------
         ValueError
             In the case that the mode ordering is double-specified with both
             fftshift and the kwarg modeord (only one should be provided).
         """
         if out is not None:
             print("In-place results are not yet implemented")
 
         # TODO -- extend checks to 2d
         checks._type2_checks(points, targets)
 
-
         if finufftkwargs is None:
             finufftkwargs = dict()
-
+
         finufftkwargs = {k: v for k, v in finufftkwargs.items()}
-        _mode_ordering = finufftkwargs.pop("modeord", 1)  # not finufft default, but corresponds to pytorch default
-        _i_sign = finufftkwargs.pop("isign", -1)  # isign=-1 is finufft default for type 2
+        _mode_ordering = finufftkwargs.pop(
+            "modeord", 1
+        )  # not finufft default, but corresponds to pytorch default
+        _i_sign = finufftkwargs.pop(
+            "isign", -1
+        )  # isign=-1 is finufft default for type 2
 
         ndim = points.shape[0]
         if _mode_ordering == 1:
             targets = torch.fft.fftshift(targets)
-
         ctx.isign = _i_sign
         ctx.mode_ordering = _mode_ordering
         ctx.finufftkwargs = finufftkwargs
@@ -929,16 +270,9 @@ def forward(
         return finufft_out
 
     @staticmethod
-    def backward(
+    def backward(  # type: ignore[override]
         ctx: Any, grad_output: torch.Tensor
-    ) -> Tuple[
-        Union[torch.Tensor, None],
-        Union[torch.Tensor, None],
-        Union[torch.Tensor, None],
-        None,
-        None,
-        None,
-    ]:
+    ) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], None, None, None,]:
         """
         Implements derivatives wrt. each argument in the forward method.
 
         Parameters
         ----------
         ctx : Any
             Pytorch context object
         grad_output : torch.Tensor
             Backpass gradient output.
 
         Returns
         -------
         Tuple[ Union[torch.Tensor, None], ...]
             A tuple of derivatives wrt. each argument in the forward method
         """
         _i_sign = ctx.isign
         _mode_ordering = ctx.mode_ordering
         finufftkwargs = ctx.finufftkwargs
 
         points, targets = ctx.saved_tensors
         device = points.device
 
         grad_points = grad_targets = None
         ndim = points.shape[0]
 
         if ctx.needs_input_grad[0]:
             coord_ramps = coordinate_ramps(targets.shape, device=device)
-            ramped_targets = coord_ramps * targets[np.newaxis] * 1j * _i_sign
+            ramped_targets = coord_ramps * targets[newaxis] * 1j * _i_sign
             nufft_func = get_nufft_func(ndim, 2, points.device.type)
 
             grad_points = nufft_func(
-                *points,
-                ramped_targets.squeeze(),
-                isign=_i_sign,
-                **finufftkwargs,
-            ).conj()  # Currently don't really get why this is hard to replace with a flipped isign
+                *points,
+                ramped_targets.squeeze(),
+                isign=_i_sign,
+                **finufftkwargs,
+            ).conj()  # Why can't this be replaced with a flipped isign
 
             grad_points = grad_points * grad_output
             grad_points = torch.atleast_2d(grad_points.real)
 
         if ctx.needs_input_grad[1]:
             nufft_func = get_nufft_func(ndim, 1, points.device.type)
 
             grad_targets = nufft_func(
-                *points,
-                grad_output,
-                targets.shape,
-                isign=-_i_sign,
-                **finufftkwargs,
-            )
-
+                *points,
+                grad_output,
+                targets.shape,
+                isign=-_i_sign,
+                **finufftkwargs,
+            )
+
             if _mode_ordering == 1:
                 grad_targets = torch.fft.ifftshift(grad_targets)
 
         return (
             grad_points,
             grad_targets,
             None,
             None,
             None,
         )
-

diff --git a/tests/test_1d/test_backward_1d.py b/tests/test_1d/test_backward_1d.py
index 988c578..dc33894 100644
--- a/tests/test_1d/test_backward_1d.py
+++ b/tests/test_1d/test_backward_1d.py
@@ -7,7 +7,7 @@
 torch.set_default_tensor_type(torch.DoubleTensor)
 torch.set_default_dtype(torch.float64)
 
-torch.manual_seed(0)
+torch.manual_seed(1234)
 
 
 # Case generation
@@ -94,75 +94,6 @@ def test_t1_backward_cuda_points(
     check_t1_backward(N, modifier, fftshift, isign, "cuda", True)
 
 
-######################################################################
-# TYPE 2 TESTS
-######################################################################
-
-
-def apply_finufft1d2(fftshift: bool, isign: int):
-    """Wrapper around finufft1D2.apply(...)"""
-
-    def f(points: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        return pytorch_finufft.functional.finufft1D2.apply(
-            points,
-            targets,
-            None,
-            fftshift,
-            dict(isign=isign),
-        )
-
-    return f
-
-
-"""
-NOTE: A few of the below do NOT pass due to strict tolerance
-"""
-
-
-@pytest.mark.parametrize("N", Ns)
-@pytest.mark.parametrize("modifier", length_modifiers)
-@pytest.mark.parametrize("fftshift", [True, False])
-@pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_targets(
-    N: int, modifier: int, fftshift: bool, isign: int
-) -> None:
-    """
-    Uses gradcheck to test the correctness of the implementation of
-    targets gradients for NUFFT type 2 in functional.
-    """
-    points = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    targets = torch.randn(N + modifier, dtype=torch.complex128)
-
-    targets.requires_grad = True
-    points.requires_grad = False
-
-    inputs = (points, targets)
-
-    assert gradcheck(apply_finufft1d2(fftshift, isign), inputs)
-
-
-@pytest.mark.parametrize("N", Ns)
-@pytest.mark.parametrize("modifier", length_modifiers)
-@pytest.mark.parametrize("fftshift", [True, False])
-@pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_points(
-    N: int, modifier: int, fftshift: bool, isign: int
-) -> None:
-    """
-    Uses gradcheck to test the correctness of the implementation of
-    targets gradients for NUFFT type 2 in functional.
-    """
-    points = 3 * np.pi * ((2 * torch.rand(N, dtype=torch.float64)) - 1)
-    targets = torch.randn(N + modifier, dtype=torch.complex128)
-
-    targets.requires_grad = False
-    points.requires_grad = True
-
-    inputs = (points, targets)
-
-    assert gradcheck(apply_finufft1d2(fftshift, isign), inputs, eps=1e-8, atol=1e-5 * N)
-
-
 def check_t2_backward(
     N: int,
     modifier: int,
@@ -187,7 +118,7 @@ def func(points, targets):
             dict(modeord=int(not fftshift), isign=isign),
         )
 
-    assert gradcheck(func, inputs, atol=5e-3 * N)
+    assert gradcheck(func, inputs, eps=1e-8, atol=1.5e-3 * N)
 
 
 @pytest.mark.parametrize("N", Ns)
@@ -199,6 +130,7 @@ def test_t2_backward_CPU_values(
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cpu", False)
 
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
@@ -208,6 +140,7 @@ def test_t2_backward_CPU_points(
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cpu", True)
 
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
@@ -217,6 +150,7 @@ def test_t2_backward_cuda_values(
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cuda", False)
 
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
@@ -225,4 +159,3 @@ def test_t2_backward_cuda_points(
     N: int, modifier: int, fftshift: bool, isign: int
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cuda", True)
-

diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py
index 8ca8a9d..a57e537 100644
--- a/tests/test_1d/test_forward_1d.py
+++ b/tests/test_1d/test_forward_1d.py
@@ -74,35 +74,6 @@ def test_t1_forward_cuda(N: int) -> None:
     )
 
 
-@pytest.mark.parametrize("targets", cases)
-def test_1d_t2_forward_CPU(targets: torch.Tensor):
-    """
-    Test type 2 API against existing implementations by setting
-    """
-    N = len(targets)
-    inv_targets = torch.fft.fft(targets)
-    assert len(inv_targets) == N
-
-    against_torch = torch.fft.ifft(inv_targets)
-
-    data_type = torch.float64 if targets.dtype is torch.complex128 else torch.float32
-
-    finufft_out = (
-        pytorch_finufft.functional.finufft1D2.apply(
-            2 * np.pi * torch.arange(0, 1, 1 / N, dtype=data_type),
-            inv_targets,
-        )
-        / N
-    )
-
-    assert torch.norm(finufft_out - np.array(targets)) / N**2 == pytest.approx(
-        0, abs=1e-05
-    )
-    assert torch.norm(finufft_out - against_torch) / N**2 == pytest.approx(
-        0, abs=1e-05
-    )
-
-
 @pytest.mark.parametrize("N", Ns)
 def test_t2_forward_CPU(N: int) -> None:
     """
@@ -130,7 +101,6 @@ def test_t2_forward_CPU(N: int) -> None:
     l_2_error = torch.sqrt(torch.sum(abs_errors**2))
     l_1_error = torch.sum(abs_errors)
 
-    assert l_inf_error < 4.5e-5 * N ** 1.1
-    assert l_2_error < 6e-5 * N ** 2.1
-    assert l_1_error < 1.2e-4 * N ** 3.2
-
+    assert l_inf_error < 4.5e-5 * N**1.1
+    assert l_2_error < 6e-5 * N**2.1
+    assert l_1_error < 1.2e-4 * N**3.2

diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py
index 88a51bf..b560c45 100644
--- a/tests/test_2d/test_backward_2d.py
+++ b/tests/test_2d/test_backward_2d.py
@@ -7,7 +7,7 @@
 torch.set_default_tensor_type(torch.DoubleTensor)
 torch.set_default_dtype(torch.float64)
 
-torch.manual_seed(0)
+torch.manual_seed(1234)
 
 
 ######################################################################
@@ -105,106 +105,6 @@ def test_t1_backward_cuda_points(
     check_t1_backward(N, modifier, fftshift, isign, "cuda", True)
 
 
-######################################################################
-# TYPE 2 TESTS
-######################################################################
-
-
-def apply_finufft2d2(fftshift: bool, isign: int):
-    """Wrapper around finufft2D1.apply(...)"""
-
-    def f(
-        points_x: torch.Tensor, points_y: torch.Tensor, targets: torch.Tensor
-    ) -> torch.Tensor:
-        return pytorch_finufft.functional.finufft2D2.apply(
-            points_x, points_y, targets, None, fftshift, dict(isign=isign)
-        )
-
-    return f
-
-
-@pytest.mark.parametrize("N", Ns)
-@pytest.mark.parametrize("modifier", length_modifiers)
-@pytest.mark.parametrize("fftshift", [True, False])
-@pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_targets(
-    N: int, modifier: int, fftshift: bool, isign: int
-) -> None:
-    """
-    Uses gradcheck to test the correctness of the implementation
-    of the derivative in targets for 2d NUFFT type 2
-    """
-
-    # TODO -- need to make sure the points are uneven and varied sufficiently
-    points_x = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    points_y = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    targets = torch.randn((N, N), dtype=torch.complex128)
-
-    points_x.requires_grad = False
-    points_y.requires_grad = False
-    targets.requires_grad = True
-
-    inputs = (points_x, points_y, targets)
-
-    assert gradcheck(apply_finufft2d2(fftshift, isign), inputs)
-
-    # TODO -- have it test also over uneven points_x and points_y
-
-
-@pytest.mark.parametrize("N", Ns)
-@pytest.mark.parametrize("modifier", length_modifiers)
-@pytest.mark.parametrize("fftshift", [True, False])
-@pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_points_x(
-    N: int, modifier: int, fftshift: bool, isign: int
-) -> None:
-    """
-    Uses gradcheck to test the correctness of the implementation
-    of the derivative in targets for 2d NUFFT type 2
-    """
-
-    # TODO -- need to make sure the points are uneven and varied sufficiently
-
-    # points_x = 3 * np.pi * (torch.rand(N) - (torch.ones(N) / 2))
-    points_x = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    points_y = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    targets = torch.randn((N, N), dtype=torch.complex128)
-
-    points_x.requires_grad = True
-    points_y.requires_grad = False
-    targets.requires_grad = False
-
-    inputs = (points_x, points_y, targets)
-
-    assert gradcheck(apply_finufft2d2(fftshift, isign), inputs)
-
-
-@pytest.mark.parametrize("N", Ns)
-@pytest.mark.parametrize("modifier", length_modifiers)
-@pytest.mark.parametrize("fftshift", [True, False])
-@pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_points_y(
-    N: int, modifier: int, fftshift: bool, isign: int
-) -> None:
-    """
-    Uses gradcheck to test the correctness of the implementation
-    of the derivative in targets for 2d NUFFT type 2
-    """
-
-    points_x = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    points_y = 3 * np.pi * (torch.rand(N) - (torch.ones(N) / 2))
-    targets = torch.randn((N, N), dtype=torch.complex128)
-
-    points_x.requires_grad = False
-    points_y.requires_grad = True
-    targets.requires_grad = False
-
-    inputs = (points_x, points_y, targets)
-
-    assert gradcheck(apply_finufft2d2(fftshift, isign), inputs)
-
-
 def check_t2_backward(
     N: int,
     modifier: int,
@@ -241,6 +141,7 @@ def test_t2_backward_CPU_values(
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cpu", False)
 
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
@@ -250,6 +151,7 @@ def test_t2_backward_CPU_points(
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cpu", True)
 
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
@@ -259,6 +161,7 @@ def test_t2_backward_cuda_values(
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cuda", False)
 
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
@@ -267,4 +170,3 @@ def test_t2_backward_cuda_points(
     N: int, modifier: int, fftshift: bool, isign: int
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cuda", True)
-

diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py
index ba84f45..52a32c4 100644
--- a/tests/test_2d/test_forward_2d.py
+++ b/tests/test_2d/test_forward_2d.py
@@ -4,7 +4,7 @@
 
 import pytorch_finufft
 
-torch.manual_seed(0)
+torch.manual_seed(1234)
 
 
 # Case generation
@@ -63,56 +63,6 @@ def test_t1_forward_cuda(N: int) -> None:
     check_t1_forward(N, "cuda")
 
 
-@pytest.mark.parametrize("N", Ns)
-def test_2d_t2_forward_CPU(N: int) -> None:
-    """
-    Tests against implementations of the FFT by setting up a uniform grid
-    over which to call FINUFFT through the API.
-    """
-    # Double precision test
-    g = np.mgrid[:N, :N] * 2 * np.pi / N
-    x, y = g.reshape(2, -1)
-
-    values = torch.randn(*g[0].shape, dtype=torch.complex128)
-
-    finufft_out = (
-        pytorch_finufft.functional.finufft2D2.apply(
-            torch.from_numpy(x), torch.from_numpy(y), values
-        )
-    ).reshape(g[0].shape) / (N**2)
-
-    against_torch = torch.fft.ifft2(values)
-
-    assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
-        0, abs=1e-6
-    )
-
-    g = np.mgrid[:N, :N] * 2 * np.pi / N
-    x, y = g.reshape(2, -1)
-
-    # single precision test
-    values = torch.randn(*g[0].shape, dtype=torch.complex64)
-
-    finufft_out = (
-        pytorch_finufft.functional.finufft2D2.apply(
-            torch.from_numpy(x).to(torch.float32),
-            torch.from_numpy(y).to(torch.float32),
-            values,
-        )
-    ).reshape(g[0].shape) / (N**2)
-
-    against_torch = torch.fft.ifft2(values)
-
-    abs_errors = torch.abs(finufft_out - against_torch)
-    l_inf_error = abs_errors.max()
-    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
-    l_1_error = torch.sum(abs_errors)
-
-    assert l_inf_error < 1e-5 * N
-    assert l_2_error < 1e-5 * N**2
-    assert l_1_error < 1e-5 * N**3
-
-
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("fftshift", [False, True])
 def test_t2_forward_CPU(N: int, fftshift: bool) -> None:
@@ -133,7 +83,7 @@ def test_t2_forward_CPU(N: int, fftshift: bool) -> None:
         points,
         targets,
         None,
-        {'modeord': int(not fftshift)},
+        {"modeord": int(not fftshift)},
     )
 
     if fftshift:
@@ -147,6 +97,5 @@ def test_t2_forward_CPU(N: int, fftshift: bool) -> None:
     l_1_error = torch.sum(abs_errors)
 
     assert l_inf_error < 4.5e-5 * N
-    assert l_2_error < 1e-5 * N ** 2
-    assert l_1_error < 1e-5 * N ** 3
-
+    assert l_2_error < 1e-5 * N**2
+    assert l_1_error < 1e-5 * N**3

diff --git a/tests/test_3d/test_backward_3d.py b/tests/test_3d/test_backward_3d.py
index 3c63014..62c8e9c 100644
--- a/tests/test_3d/test_backward_3d.py
+++ b/tests/test_3d/test_backward_3d.py
@@ -7,34 +7,7 @@
 torch.set_default_tensor_type(torch.DoubleTensor)
 torch.set_default_dtype(torch.float64)
 
-torch.manual_seed(0)
-
-
-def apply_finufft3d2(fftshift: bool, isign: int):
-    """Wrapper around finufft2D1.apply(...)"""
-
-    def f(
-        points_x: torch.Tensor,
-        points_y: torch.Tensor,
-        points_z: torch.Tensor,
-        targets: torch.Tensor,
-    ) -> torch.Tensor:
-        return pytorch_finufft.functional.finufft3D2.apply(
-            points_x,
-            points_y,
-            points_z,
-            targets,
-            None,
-            fftshift,
-            dict(isign=isign),
-        )
-
-    return f
-
-
-######################################################################
-# TEST CASES
-######################################################################
+torch.manual_seed(1234)
 
 
 Ns = [
@@ -81,7 +54,7 @@ def func(points, values):
         dict(modeord=int(not fftshift), isign=isign),
     )
 
-    assert gradcheck(func, inputs, atol=1e-5 * N)
+    assert gradcheck(func, inputs, eps=1e-8, atol=1e-5 * N)
 
 
 @pytest.mark.parametrize("N", Ns)
@@ -129,120 +102,6 @@ def test_t1_backward_cuda_points(
 ######################################################################
 
 
-@pytest.mark.parametrize("N", Ns)
-@pytest.mark.parametrize("modifier", length_modifiers)
-@pytest.mark.parametrize("fftshift", [True, False])
-@pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_targets(
-    N: int, modifier: int, fftshift: bool, isign: int
-) -> None:
-    """
-    Uses gradcheck to test the correctness of the implementation
-    of the derivative in targets for 2d NUFFT type 2
-    """
-
-    points_x = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    points_y = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    points_z = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    targets = torch.randn((N, N, N), dtype=torch.complex128)
-
-    points_x.requires_grad = False
-    points_y.requires_grad = False
-    points_z.requires_grad = False
-    targets.requires_grad = True
-
-    inputs = (points_x, points_y, points_z, targets)
-
-    assert gradcheck(apply_finufft3d2(fftshift, isign), inputs)
-
-
-@pytest.mark.parametrize("N", Ns)
-@pytest.mark.parametrize("modifier", length_modifiers)
-@pytest.mark.parametrize("fftshift", [True, False])
-@pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_points_x(
-    N: int, modifier: int, fftshift: bool, isign: int
-) -> None:
-    """
-    Uses gradcheck to test the correctness of the implementation
-    of the derivative in targets for 2d NUFFT type 2
-
-    N -- size of arrays
-    modifier -- perturb the size of the arrays with size modifier
-    fftshift -- test different fftshift values (T/F)
-    isign -- test different isign values (+/- 1)
-    """
-
-    points_x = 3 * np.pi * (torch.rand(N) - (torch.ones(N) / 2))
-    points_y = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    points_z = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    targets = torch.randn((N, N, N), dtype=torch.complex128)
-
-    points_x.requires_grad = True
-    points_y.requires_grad = False
-    points_z.requires_grad = False
-    targets.requires_grad = False
-
-    inputs = (points_x, points_y, points_z, targets)
-
-    assert gradcheck(apply_finufft3d2(fftshift, isign), inputs)
-
-
-@pytest.mark.parametrize("N", Ns)
-@pytest.mark.parametrize("modifier", length_modifiers)
-@pytest.mark.parametrize("fftshift", [True, False])
-@pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_points_y(
-    N: int, modifier: int, fftshift: bool, isign: int
-) -> None:
-    """
-    Uses gradcheck to test the correctness of the implementation
-    of the derivative in targets for 2d NUFFT type 2
-    """
-
-    points_x = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    points_y = 3 * np.pi * (torch.rand(N) - (torch.ones(N) / 2))
-    points_z = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    targets = torch.randn((N, N, N), dtype=torch.complex128)
-
-    points_x.requires_grad = False
-    points_y.requires_grad = True
-    points_z.requires_grad = False
-    targets.requires_grad = False
-
-    inputs = (points_x, points_y, points_z, targets)
-
-    assert gradcheck(apply_finufft3d2(fftshift, isign), inputs)
-
-
-@pytest.mark.parametrize("N", Ns)
-@pytest.mark.parametrize("modifier", length_modifiers)
-@pytest.mark.parametrize("fftshift", [True, False])
-@pytest.mark.parametrize("isign", [-1, 1])
-def test_t2_backward_CPU_points_z(
-    N: int, modifier: int, fftshift: bool, isign: int
-) -> None:
-    """
-    Uses gradcheck to test the correctness of the implementation
-    of the derivative in targets for 2d NUFFT type 2
-    """
-
-    points_x = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    points_y = 2 * np.pi * torch.arange(0, 1, 1 / N, dtype=torch.float64)
-    points_z = 3 * np.pi * (torch.rand(N) - (torch.ones(N) / 2))
-    targets = torch.randn((N, N, N), dtype=torch.complex128)
-
-    points_x.requires_grad = False
-    points_y.requires_grad = False
-    points_z.requires_grad = True
-    targets.requires_grad = False
-
-    inputs = (points_x, points_y, points_z, targets)
-
-    assert gradcheck(apply_finufft3d2(fftshift, isign), inputs)
-
-
 def check_t2_backward(
     N: int,
     modifier: int,
@@ -279,6 +138,7 @@ def test_t2_backward_CPU_values(
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cpu", False)
 
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
@@ -288,6 +148,7 @@ def test_t2_backward_CPU_points(
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cpu", True)
 
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
@@ -297,6 +158,7 @@ def test_t2_backward_cuda_values(
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cuda", False)
 
+
 @pytest.mark.parametrize("N", Ns)
 @pytest.mark.parametrize("modifier", length_modifiers)
 @pytest.mark.parametrize("fftshift", [False, True])
@@ -305,4 +167,3 @@ def test_t2_backward_cuda_points(
     N: int, modifier: int, fftshift: bool, isign: int
 ) -> None:
     check_t2_backward(N, modifier, fftshift, isign, "cuda", True)
-

diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py
index f4cbd98..6ac24ae 100644
--- a/tests/test_3d/test_forward_3d.py
+++ b/tests/test_3d/test_forward_3d.py
@@ -4,7 +4,7 @@
 
 import pytorch_finufft
 
-torch.manual_seed(0)
+torch.manual_seed(1234)
 
 
 # Case generation
@@ -62,39 +62,6 @@ def test_t1_forward_cuda(N: int) -> None:
     check_t1_forward(N, "cuda")
 
 
-@pytest.mark.parametrize("N", Ns)
-def test_3d_t2_forward_CPU(N: int) -> None:
-    """
-    Tests against implementations of the FFT by setting up a uniform grid
-    over which to call FINUFFT through the API
-    """
-    # Double precision test
-
-    for _ in range(5):
-        g = np.mgrid[:N, :N, :N] * 2 * np.pi / N
-        x, y, z = g.reshape(3, -1)
-
-        values = torch.randn(*g[0].shape, dtype=torch.complex128)
-
-        finufft_out = pytorch_finufft.functional.finufft3D2.apply(
-            torch.from_numpy(x),
-            torch.from_numpy(y),
-            torch.from_numpy(z),
-            values,
-        ).reshape(g[0].shape) / (N**3)
-
-        against_torch = torch.fft.ifftn(values)
-
-        abs_errors = torch.abs(finufft_out - against_torch)
-        l_inf_error = abs_errors.max()
-        l_2_error = torch.sqrt(torch.sum(abs_errors**2))
-        l_1_error = torch.sum(abs_errors)
-
-        assert l_inf_error < 1e-5 * N**1.5
-        assert l_2_error < 1e-5 * N**3
-        assert l_1_error < 1e-5 * N**4.5
-
-
 @pytest.mark.parametrize("N", Ns)
 def test_t2_forward_CPU(N: int) -> None:
     """
@@ -122,7 +89,6 @@ def test_t2_forward_CPU(N: int) -> None:
     l_2_error = torch.sqrt(torch.sum(abs_errors**2))
     l_1_error = torch.sum(abs_errors)
 
-    assert l_inf_error < 4.5e-5 * N ** 1.1
-    assert l_2_error < 6e-5 * N ** 2.1
-    assert l_1_error < 1.2e-4 * N ** 3.2
-
+    assert l_inf_error < 4.5e-5 * N**1.1
+    assert l_2_error < 6e-5 * N**2.1
+    assert l_1_error < 1.2e-4 * N**3.2

diff --git a/tests/test_errors.py b/tests/test_errors.py
index 853388a..9cc34ff 100644
--- a/tests/test_errors.py
+++ b/tests/test_errors.py
@@ -5,7 +5,7 @@
 
 import pytorch_finufft
 
-torch.manual_seed(0)
+torch.manual_seed(1234)
 
 
 # devices