From f53348c9f21fd255d0495fbfee32f5b1a9308e86 Mon Sep 17 00:00:00 2001 From: Michael Eickenberg Date: Thu, 12 Oct 2023 16:28:46 -0400 Subject: [PATCH] FIX apply suggestions from Brian's PR comments including putting coordinate ramps in a helper --- pytorch_finufft/functional.py | 106 +++++++++++----------------------- 1 file changed, 35 insertions(+), 71 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index daf1ef0..3e423bb 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -712,6 +712,21 @@ def f(*args, **kwargs): return f +def coordinate_ramps(shape, device): + start_points = -(torch.tensor(shape, device=device) // 2) + end_points = start_points + torch.tensor(shape, device=device) + coord_ramps = torch.stack( + torch.meshgrid( + *( + torch.arange(start, end, device=device) + for start, end in zip(start_points, end_points) + ), + indexing="ij", + ) + ) + + return coord_ramps + class finufft_type1(torch.autograd.Function): @staticmethod def forward( # type: ignore[override] @@ -799,17 +814,7 @@ def backward( # type: ignore[override] if ctx.needs_input_grad[0]: # wrt points - start_points = -(torch.tensor(grad_output.shape, device=device) // 2) - end_points = start_points + torch.tensor(grad_output.shape, device=device) - coord_ramps = torch.stack( - torch.meshgrid( - *( - torch.arange(start, end, device=device) - for start, end in zip(start_points, end_points) - ), - indexing="ij", - ) - ) + coord_ramps = coordinate_ramps(grad_output.shape, device) # we can't batch in 1d case so we squeeze and fix up the ouput later ramped_grad_output = ( @@ -850,51 +855,37 @@ def forward( points: torch.Tensor, targets: torch.Tensor, out: Optional[torch.Tensor] = None, - finufftkwargs: Dict[str, Union[int, float]] = {}, + finufftkwargs: Dict[str, Union[int, float]] = None, ) -> torch.Tensor: """ Evaluates the Type 2 NUFFT on the inputs. NOTE: By default, the ordering is set to match that of Pytorch, Numpy, and Scipy's FFT APIs. To match the mode ordering - native to FINUFFT, set fftshift=True. - - ``` - c[j] = SUM f[k1, k2] exp(+/-i (k1 x(j) + k2 y(j))) - k1, k2 - - for j = 0, ..., M-1, where the sum is over -N1/2 <= k1 <= (N1-1)/2, - -N2/2 <= k2 <= (N2-1)/2 - ``` + native to FINUFFT, add {'modeord': 0} to finufftkwargs. Parameters ---------- ctx : Any Pytorch context objecy - points_x : torch.Tensor - The non-uniform points x_j - points_y : torch.Tensor - The non-uniform points y_j + points : torch.Tensor, shape=(ndim, num_points) + The non-uniform points x targets : torch.Tensor - The target Fourier mode coefficients f[k1, k2] + The values on the input grid out : Optional[torch.Tensor], optional Array to take the result in-place, by default None - fftshift : bool - If True centers the 0 mode in the resultant torch.Tensor, by default False finufftkwargs : Dict[str, Union[int, float]] Additional arguments will be passed into FINUFFT. See - https://finufft.readthedocs.io/en/latest/python.html. By default - an empty dictionary + https://finufft.readthedocs.io/en/latest/python.html. Returns ------- torch.Tensor - The resultant array c[j] + The Fourier transform of the targets grid evaluated at the points `points` Raises ------ - ValueError - In the case of conflicting specification of the wave-mode ordering. + """ if out is not None: @@ -903,21 +894,17 @@ def forward( # TODO -- extend checks to 2d checks._type2_checks(points, targets) - finufftkwargs = {k: v for k, v in finufftkwargs.items()} - _mode_ordering = finufftkwargs.pop("modeord", 1) - _i_sign = finufftkwargs.pop("isign", -1) - # if fftshift: - # if _mode_ordering != 1: # This seems like it is the wrong way round??????? - # raise ValueError( - # "Double specification of ordering; only one of fftshift and " - # "modeord should be provided." - # ) - # _mode_ordering = 0 + if finufftkwargs is None: + finufftkwargs = dict() + + finufftkwargs = {k: v for k, v in finufftkwargs.items()} + _mode_ordering = finufftkwargs.pop("modeord", 1) # not finufft default, but corresponds to pytorch default + _i_sign = finufftkwargs.pop("isign", -1) # isign=-1 is finufft default for type 2 ndim = points.shape[0] if _mode_ordering == 1: - targets = torch.fft.fftshift(targets, dim=tuple(range(-ndim, 0))) + targets = torch.fft.fftshift(targets) ctx.isign = _i_sign @@ -929,8 +916,8 @@ def forward( nufft_func = get_nufft_func(ndim, 2, points.device.type) finufft_out = nufft_func( - *points.data.numpy(), - targets.data.numpy(), + *points, + targets, isign=_i_sign, **finufftkwargs, ) @@ -970,32 +957,11 @@ def backward( points, targets = ctx.saved_tensors device = points.device - # start_points = -(np.array(targets.shape) // 2) - # end_points = start_points + targets.shape - # slices = tuple(slice(start, end) for start, end in zip(start_points, end_points)) - - # # CPU idiosyncracy that needs to be done differently - # k_ramps = torch.from_numpy(np.mgrid[slices], dtype=points.dtype) - grad_points = grad_targets = None ndim = points.shape[0] - - if ctx.needs_input_grad[0]: - # wrt points - start_points = -(torch.tensor(targets.shape, device=device) // 2) - end_points = start_points + torch.tensor(targets.shape, device=device) - coord_ramps = torch.stack( - torch.meshgrid( - *( - torch.arange(start, end, device=device) - for start, end in zip(start_points, end_points) - ), - indexing="ij", - ) - ) - if ctx.needs_input_grad[0]: + coord_ramps = coordinate_ramps(targets.shape, device=device) ramped_targets = coord_ramps * targets[np.newaxis] * 1j * _i_sign nufft_func = get_nufft_func(ndim, 2, points.device.type) @@ -1003,7 +969,6 @@ def backward( *points, ramped_targets.squeeze(), isign=_i_sign, - #modeord=_mode_ordering, **finufftkwargs, ).conj() # Currently don't really get why this is hard to replace with a flipped isign @@ -1018,13 +983,12 @@ def backward( *points, grad_output, targets.shape, - #modeord=_mode_ordering, isign=-_i_sign, **finufftkwargs, ) if _mode_ordering == 1: - grad_targets = torch.fft.ifftshift(grad_targets, dim=tuple(range(-ndim, 0))) + grad_targets = torch.fft.ifftshift(grad_targets) return ( grad_points,