Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidy up the backwards pass of type 1 #74

Merged
merged 4 commits into from
Oct 10, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 22 additions & 38 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def forward(
return finufft_out

@staticmethod
def backward(
def backward( # type: ignore[override]
ctx: Any, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""
Expand All @@ -799,55 +799,39 @@ def backward(
finufftkwargs = ctx.finufftkwargs

points, values = ctx.saved_tensors

start_points = -(np.array(grad_output.shape) // 2)
end_points = start_points + grad_output.shape
slices = tuple(
slice(start, end) for start, end in zip(start_points, end_points)
)

# CPU idiosyncracy that needs to be done differently
coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device)
device = points.device

grads_points = None
grad_values = None

ndim = points.shape[0]

nufft_func = get_nufft_func(ndim, 2, points.device.type)
nufft_func = get_nufft_func(ndim, 2, device.type)

if any(ctx.needs_input_grad) and _mode_ordering:
grad_output = torch.fft.fftshift(grad_output)

if ctx.needs_input_grad[0]:
# wrt points

if _mode_ordering:
coord_ramps = torch.fft.ifftshift(
coord_ramps, dim=tuple(range(1, ndim + 1))
start_points = -(torch.tensor(grad_output.shape, device=device) // 2)
end_points = start_points + torch.tensor(grad_output.shape, device=device)
coord_ramps = torch.cartesian_prod(
*(
torch.arange(start, end, device=device)
for start, end in zip(start_points, end_points)
)

ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign

grads_points = []
for ramp in ramped_grad_output: # we can batch this into finufft
if _mode_ordering:
ramp = torch.fft.fftshift(ramp)

backprop_ramp = nufft_func(
*points,
ramp,
isign=_i_sign,
**finufftkwargs,
)

grad_points = (backprop_ramp.conj() * values).real

grads_points.append(grad_points)

grads_points = torch.stack(grads_points)
).to(device)
WardBrian marked this conversation as resolved.
Show resolved Hide resolved

# we can't batch in 1d case so we squeeze and fix up the ouput later
ramped_grad_output = (
coord_ramps * grad_output[np.newaxis] * 1j * _i_sign
).squeeze()
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
backprop_ramp = nufft_func(
*points, ramped_grad_output, isign=_i_sign, **finufftkwargs
)
grads_points = torch.atleast_2d((backprop_ramp.conj() * values).real)
WardBrian marked this conversation as resolved.
Show resolved Hide resolved

if ctx.needs_input_grad[1]:
if _mode_ordering:
grad_output = torch.fft.fftshift(grad_output)

grad_values = nufft_func(
*points,
grad_output,
Expand Down
Loading