From cea3856d31cfd550f5af06b84beb891c24d623cb Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 10 Oct 2023 14:38:34 -0400 Subject: [PATCH] Replace numpy mgrid with torch.cartesian_product --- pytorch_finufft/functional.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 7fbb166..dc465c6 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -774,6 +774,7 @@ def forward( finufft_out = torch.fft.ifftshift(finufft_out) return finufft_out + @staticmethod def backward( # type: ignore[override] ctx: Any, grad_output: torch.Tensor @@ -798,29 +799,36 @@ def backward( # type: ignore[override] finufftkwargs = ctx.finufftkwargs points, values = ctx.saved_tensors + device = points.device grads_points = None grad_values = None ndim = points.shape[0] - nufft_func = get_nufft_func(ndim, 2, points.device.type) + nufft_func = get_nufft_func(ndim, 2, device.type) if any(ctx.needs_input_grad) and _mode_ordering: grad_output = torch.fft.fftshift(grad_output) if ctx.needs_input_grad[0]: # wrt points - # CPU idiosyncracy that needs to be done differently - start_points = -(np.array(grad_output.shape) // 2) - end_points = start_points + grad_output.shape - slices = tuple( - slice(start, end) for start, end in zip(start_points, end_points) + start_points = -(torch.tensor(grad_output.shape, device=device) // 2) + end_points = start_points + torch.tensor(grad_output.shape, device=device) + coord_ramps = torch.cartesian_prod( + *( + torch.arange(start, end, device=device) + for start, end in zip(start_points, end_points) + ) + ).to(device) + + # we can't batch in 1d case so we squeeze and fix up the ouput later + ramped_grad_output = ( + coord_ramps * grad_output[np.newaxis] * 1j * _i_sign + ).squeeze() + backprop_ramp = nufft_func( + *points, ramped_grad_output, isign=_i_sign, **finufftkwargs ) - coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device) - - ramped_grad_output = (coord_ramps * grad_output[np.newaxis] * 1j * _i_sign).squeeze() - backprop_ramp = nufft_func(*points, ramped_grad_output, isign=_i_sign, **finufftkwargs) grads_points = torch.atleast_2d((backprop_ramp.conj() * values).real) if ctx.needs_input_grad[1]: