diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index cce4635..e523758 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -774,9 +774,8 @@ def forward( finufft_out = torch.fft.ifftshift(finufft_out) return finufft_out - @staticmethod - def backward( + def backward( # type: ignore[override] ctx: Any, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: """ @@ -800,15 +799,6 @@ def backward( points, values = ctx.saved_tensors - start_points = -(np.array(grad_output.shape) // 2) - end_points = start_points + grad_output.shape - slices = tuple( - slice(start, end) for start, end in zip(start_points, end_points) - ) - - # CPU idiosyncracy that needs to be done differently - coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device) - grads_points = None grad_values = None @@ -816,20 +806,23 @@ def backward( nufft_func = get_nufft_func(ndim, 2, points.device.type) + if any(ctx.needs_input_grad) and _mode_ordering: + grad_output = torch.fft.fftshift(grad_output) + if ctx.needs_input_grad[0]: # wrt points - - if _mode_ordering: - coord_ramps = torch.fft.ifftshift( - coord_ramps, dim=tuple(range(1, ndim + 1)) - ) + # CPU idiosyncracy that needs to be done differently + start_points = -(np.array(grad_output.shape) // 2) + end_points = start_points + grad_output.shape + slices = tuple( + slice(start, end) for start, end in zip(start_points, end_points) + ) + coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device) ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign - grads_points = [] + grads_points_ = [] for ramp in ramped_grad_output: # we can batch this into finufft - if _mode_ordering: - ramp = torch.fft.fftshift(ramp) backprop_ramp = nufft_func( *points, @@ -840,14 +833,11 @@ def backward( grad_points = (backprop_ramp.conj() * values).real - grads_points.append(grad_points) + grads_points_.append(grad_points) - grads_points = torch.stack(grads_points) + grads_points = torch.stack(grads_points_) if ctx.needs_input_grad[1]: - if _mode_ordering: - grad_output = torch.fft.fftshift(grad_output) - grad_values = nufft_func( *points, grad_output,