From aa3c05d3bb138e55260c8306033b22183b3d863d Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 10 Oct 2023 14:16:50 -0400 Subject: [PATCH] Batch t1 backward pass for points --- pytorch_finufft/functional.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index e523758..7fbb166 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -819,23 +819,9 @@ def backward( # type: ignore[override] ) coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device) - ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign - - grads_points_ = [] - for ramp in ramped_grad_output: # we can batch this into finufft - - backprop_ramp = nufft_func( - *points, - ramp, - isign=_i_sign, - **finufftkwargs, - ) - - grad_points = (backprop_ramp.conj() * values).real - - grads_points_.append(grad_points) - - grads_points = torch.stack(grads_points_) + ramped_grad_output = (coord_ramps * grad_output[np.newaxis] * 1j * _i_sign).squeeze() + backprop_ramp = nufft_func(*points, ramped_grad_output, isign=_i_sign, **finufftkwargs) + grads_points = torch.atleast_2d((backprop_ramp.conj() * values).real) if ctx.needs_input_grad[1]: grad_values = nufft_func(