Skip to content

Commit

Permalink
FIX 2D backward working
Browse files Browse the repository at this point in the history
  • Loading branch information
eickenberg committed Oct 12, 2023
1 parent c329e01 commit 729dac0
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,6 @@ def forward(
points: torch.Tensor,
targets: torch.Tensor,
out: Optional[torch.Tensor] = None,
fftshift: bool = False,
finufftkwargs: Dict[str, Union[int, float]] = {},
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -908,13 +907,13 @@ def forward(
_mode_ordering = finufftkwargs.pop("modeord", 1)
_i_sign = finufftkwargs.pop("isign", -1)

if fftshift:
if _mode_ordering != 1: # This seems like it is the wrong way round???????
raise ValueError(
"Double specification of ordering; only one of fftshift and "
"modeord should be provided."
)
_mode_ordering = 0
# if fftshift:
# if _mode_ordering != 1: # This seems like it is the wrong way round???????
# raise ValueError(
# "Double specification of ordering; only one of fftshift and "
# "modeord should be provided."
# )
# _mode_ordering = 0

ndim = points.shape[0]
if _mode_ordering == 1:
Expand All @@ -923,7 +922,6 @@ def forward(

ctx.isign = _i_sign
ctx.mode_ordering = _mode_ordering
ctx.fftshift = fftshift
ctx.finufftkwargs = finufftkwargs

ctx.save_for_backward(points, targets)
Expand Down Expand Up @@ -980,10 +978,8 @@ def backward(
# k_ramps = torch.from_numpy(np.mgrid[slices], dtype=points.dtype)

grad_points = grad_targets = None
ndim = points.shape[0]

## From type 1, commenting for now to understand whether needed
# if any(ctx.needs_input_grad) and _mode_ordering:
# grad_output = torch.fft.fftshift(grad_output)

if ctx.needs_input_grad[0]:
# wrt points
Expand All @@ -999,8 +995,6 @@ def backward(
)
)

ndim = points.shape[0]

if ctx.needs_input_grad[0]:
ramped_targets = coord_ramps * targets[np.newaxis] * 1j * _i_sign
nufft_func = get_nufft_func(ndim, 2, points.device.type)
Expand Down Expand Up @@ -1029,6 +1023,8 @@ def backward(
**finufftkwargs,
)

if _mode_ordering == 1:
grad_targets = torch.fft.ifftshift(grad_targets, dim=tuple(range(-ndim, 0)))

return (
grad_points,
Expand Down

0 comments on commit 729dac0

Please sign in to comment.