Skip to content

Commit

Permalink
Batch t1 backward pass for points
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Oct 10, 2023
1 parent 541895d commit aa3c05d
Showing 1 changed file with 3 additions and 17 deletions.
20 changes: 3 additions & 17 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,23 +819,9 @@ def backward( # type: ignore[override]
)
coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device)

ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign

grads_points_ = []
for ramp in ramped_grad_output: # we can batch this into finufft

backprop_ramp = nufft_func(
*points,
ramp,
isign=_i_sign,
**finufftkwargs,
)

grad_points = (backprop_ramp.conj() * values).real

grads_points_.append(grad_points)

grads_points = torch.stack(grads_points_)
ramped_grad_output = (coord_ramps * grad_output[np.newaxis] * 1j * _i_sign).squeeze()
backprop_ramp = nufft_func(*points, ramped_grad_output, isign=_i_sign, **finufftkwargs)
grads_points = torch.atleast_2d((backprop_ramp.conj() * values).real)

if ctx.needs_input_grad[1]:
grad_values = nufft_func(
Expand Down

0 comments on commit aa3c05d

Please sign in to comment.