Replace numpy mgrid with torch.cartesian_product
WardBrian committed Oct 10, 2023
1 parent aa3c05d · commit cea3856
Showing 1 changed file with 18 additions and 10 deletions.
pytorch_finufft/functional.py: 28 changes (18 additions & 10 deletions)
@@ -774,6 +774,7 @@ def forward(
             finufft_out = torch.fft.ifftshift(finufft_out)
 
         return finufft_out
+
     @staticmethod
     def backward(  # type: ignore[override]
         ctx: Any, grad_output: torch.Tensor
@@ -798,29 +799,36 @@ def backward(  # type: ignore[override]
         finufftkwargs = ctx.finufftkwargs
 
         points, values = ctx.saved_tensors
+        device = points.device
 
         grads_points = None
         grad_values = None
 
         ndim = points.shape[0]
 
-        nufft_func = get_nufft_func(ndim, 2, points.device.type)
+        nufft_func = get_nufft_func(ndim, 2, device.type)
 
         if any(ctx.needs_input_grad) and _mode_ordering:
             grad_output = torch.fft.fftshift(grad_output)
 
         if ctx.needs_input_grad[0]:
             # wrt points
-            # CPU idiosyncracy that needs to be done differently
-            start_points = -(np.array(grad_output.shape) // 2)
-            end_points = start_points + grad_output.shape
-            slices = tuple(
-                slice(start, end) for start, end in zip(start_points, end_points)
-            )
-            coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device)
+            start_points = -(torch.tensor(grad_output.shape, device=device) // 2)
+            end_points = start_points + torch.tensor(grad_output.shape, device=device)
+            coord_ramps = torch.cartesian_prod(
+                *(
+                    torch.arange(start, end, device=device)
+                    for start, end in zip(start_points, end_points)
+                )
+            ).to(device)
 
-            ramped_grad_output = (coord_ramps * grad_output[np.newaxis] * 1j * _i_sign).squeeze()
-            backprop_ramp = nufft_func(*points, ramped_grad_output, isign=_i_sign, **finufftkwargs)
+            # we can't batch in 1d case so we squeeze and fix up the ouput later
+            ramped_grad_output = (
+                coord_ramps * grad_output[np.newaxis] * 1j * _i_sign
+            ).squeeze()
+            backprop_ramp = nufft_func(
+                *points, ramped_grad_output, isign=_i_sign, **finufftkwargs
+            )
             grads_points = torch.atleast_2d((backprop_ramp.conj() * values).real)
 
         if ctx.needs_input_grad[1]:
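For reference, here is a minimal standalone sketch of the substitution (not code from the repository; the shape and variable names are hypothetical) showing that torch.cartesian_prod over per-dimension aranges can reproduce the centered coordinate ramps np.mgrid produced, once its row-per-coordinate output is transposed and reshaped into mgrid's (ndim, *shape) layout:

    import numpy as np
    import torch

    shape = (4, 6)  # hypothetical grad_output shape, for illustration only
    starts = [-(s // 2) for s in shape]

    # Old construction: centered ramps via np.mgrid (always built on the CPU).
    slices = tuple(slice(st, st + s) for st, s in zip(starts, shape))
    ramps_np = np.mgrid[slices]  # shape (ndim, *shape)

    # New construction: cartesian_prod of aranges (can live on any device).
    # cartesian_prod yields one coordinate tuple per row, shape (prod(shape), ndim),
    # so a transpose + reshape recovers the (ndim, *shape) mgrid layout.
    ramps_torch = torch.cartesian_prod(
        *(torch.arange(st, st + s) for st, s in zip(starts, shape))
    ).T.reshape(len(shape), *shape)

    assert np.array_equal(ramps_np, ramps_torch.numpy())

The practical gain is device placement: the old path materialized the grid with NumPy on the host and copied it over via torch.from_numpy(...).to(points.device), whereas torch.arange with device= builds it directly where the gradient lives. Note that for a single 1-D input torch.cartesian_prod returns a flat tensor rather than rows of tuples, so the 1-D case needs separate handling, which the commit's comment about squeezing and fixing up the output appears to address.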
