Remove double fftshift in t1 backward
WardBrian committed Oct 10, 2023
1 parent c15d2c1 commit 541895d
Showing 1 changed file with 14 additions and 24 deletions.
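The change hinges on the fact that torch.fft.fftshift and torch.fft.ifftshift are pure index permutations and inverses of each other, so ifftshift-ing the coordinate ramps and then fftshift-ing each ramped product (the old points branch) is equivalent to a single fftshift of grad_output applied before both gradient branches (the new code). Below is a minimal sketch of that equivalence; it is not taken from the library, and the tensors are stand-ins that merely borrow the diff's variable names.

import torch

torch.manual_seed(0)
coord_ramp = torch.randn(8, 8)                          # stand-in frequency ramp
grad_output = torch.randn(8, 8, dtype=torch.complex64)  # stand-in upstream gradient

# Old points-branch path: ifftshift the ramp, multiply, then fftshift the product.
old = torch.fft.fftshift(torch.fft.ifftshift(coord_ramp) * grad_output)

# New path: fftshift grad_output once up front and leave the ramp untouched.
new = coord_ramp * torch.fft.fftshift(grad_output)

print(torch.allclose(old, new))  # prints: True

The values-gradient branch previously applied its own fftshift to grad_output; after this commit the single hoisted shift serves both branches.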
38 changes: 14 additions & 24 deletions pytorch_finufft/functional.py
@@ -774,9 +774,8 @@ def forward(
finufft_out = torch.fft.ifftshift(finufft_out)

return finufft_out

@staticmethod
def backward(
def backward( # type: ignore[override]
ctx: Any, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""
@@ -800,36 +799,30 @@ def backward(

points, values = ctx.saved_tensors

start_points = -(np.array(grad_output.shape) // 2)
end_points = start_points + grad_output.shape
slices = tuple(
slice(start, end) for start, end in zip(start_points, end_points)
)

# CPU idiosyncrasy that needs to be done differently
coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device)

grads_points = None
grad_values = None

ndim = points.shape[0]

nufft_func = get_nufft_func(ndim, 2, points.device.type)

if any(ctx.needs_input_grad) and _mode_ordering:
grad_output = torch.fft.fftshift(grad_output)

if ctx.needs_input_grad[0]:
# wrt points

if _mode_ordering:
coord_ramps = torch.fft.ifftshift(
coord_ramps, dim=tuple(range(1, ndim + 1))
)
# CPU idiosyncrasy that needs to be done differently
start_points = -(np.array(grad_output.shape) // 2)
end_points = start_points + grad_output.shape
slices = tuple(
slice(start, end) for start, end in zip(start_points, end_points)
)
coord_ramps = torch.from_numpy(np.mgrid[slices]).to(points.device)

ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign

grads_points = []
grads_points_ = []
for ramp in ramped_grad_output: # we can batch this into finufft
if _mode_ordering:
ramp = torch.fft.fftshift(ramp)

backprop_ramp = nufft_func(
*points,
@@ -840,14 +833,11 @@ def backward(

grad_points = (backprop_ramp.conj() * values).real

grads_points.append(grad_points)
grads_points_.append(grad_points)

grads_points = torch.stack(grads_points)
grads_points = torch.stack(grads_points_)

if ctx.needs_input_grad[1]:
if _mode_ordering:
grad_output = torch.fft.fftshift(grad_output)

grad_values = nufft_func(
*points,
grad_output,
