Skip to content

Commit

Permalink
FIX apply suggestions from Brian's PR comments including putting coor…
Browse files Browse the repository at this point in the history
…dinate ramps in a helper
  • Loading branch information
eickenberg committed Oct 12, 2023
1 parent 5db77b2 commit f53348c
Showing 1 changed file with 35 additions and 71 deletions.
106 changes: 35 additions & 71 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,21 @@ def f(*args, **kwargs):
return f


def coordinate_ramps(shape, device):
start_points = -(torch.tensor(shape, device=device) // 2)
end_points = start_points + torch.tensor(shape, device=device)
coord_ramps = torch.stack(
torch.meshgrid(
*(
torch.arange(start, end, device=device)
for start, end in zip(start_points, end_points)
),
indexing="ij",
)
)

return coord_ramps

class finufft_type1(torch.autograd.Function):
@staticmethod
def forward( # type: ignore[override]
Expand Down Expand Up @@ -799,17 +814,7 @@ def backward( # type: ignore[override]

if ctx.needs_input_grad[0]:
# wrt points
start_points = -(torch.tensor(grad_output.shape, device=device) // 2)
end_points = start_points + torch.tensor(grad_output.shape, device=device)
coord_ramps = torch.stack(
torch.meshgrid(
*(
torch.arange(start, end, device=device)
for start, end in zip(start_points, end_points)
),
indexing="ij",
)
)
coord_ramps = coordinate_ramps(grad_output.shape, device)

# we can't batch in 1d case so we squeeze and fix up the ouput later
ramped_grad_output = (
Expand Down Expand Up @@ -850,51 +855,37 @@ def forward(
points: torch.Tensor,
targets: torch.Tensor,
out: Optional[torch.Tensor] = None,
finufftkwargs: Dict[str, Union[int, float]] = {},
finufftkwargs: Dict[str, Union[int, float]] = None,
) -> torch.Tensor:
"""
Evaluates the Type 2 NUFFT on the inputs.
NOTE: By default, the ordering is set to match that of Pytorch,
Numpy, and Scipy's FFT APIs. To match the mode ordering
native to FINUFFT, set fftshift=True.
```
c[j] = SUM f[k1, k2] exp(+/-i (k1 x(j) + k2 y(j)))
k1, k2
for j = 0, ..., M-1, where the sum is over -N1/2 <= k1 <= (N1-1)/2,
-N2/2 <= k2 <= (N2-1)/2
```
native to FINUFFT, add {'modeord': 0} to finufftkwargs.
Parameters
----------
ctx : Any
Pytorch context objecy
points_x : torch.Tensor
The non-uniform points x_j
points_y : torch.Tensor
The non-uniform points y_j
points : torch.Tensor, shape=(ndim, num_points)
The non-uniform points x
targets : torch.Tensor
The target Fourier mode coefficients f[k1, k2]
The values on the input grid
out : Optional[torch.Tensor], optional
Array to take the result in-place, by default None
fftshift : bool
If True centers the 0 mode in the resultant torch.Tensor, by default False
finufftkwargs : Dict[str, Union[int, float]]
Additional arguments will be passed into FINUFFT. See
https://finufft.readthedocs.io/en/latest/python.html. By default
an empty dictionary
https://finufft.readthedocs.io/en/latest/python.html.
Returns
-------
torch.Tensor
The resultant array c[j]
The Fourier transform of the targets grid evaluated at the points `points`
Raises
------
ValueError
In the case of conflicting specification of the wave-mode ordering.
"""

if out is not None:
Expand All @@ -903,21 +894,17 @@ def forward(
# TODO -- extend checks to 2d
checks._type2_checks(points, targets)

finufftkwargs = {k: v for k, v in finufftkwargs.items()}
_mode_ordering = finufftkwargs.pop("modeord", 1)
_i_sign = finufftkwargs.pop("isign", -1)

# if fftshift:
# if _mode_ordering != 1: # This seems like it is the wrong way round???????
# raise ValueError(
# "Double specification of ordering; only one of fftshift and "
# "modeord should be provided."
# )
# _mode_ordering = 0
if finufftkwargs is None:
finufftkwargs = dict()

finufftkwargs = {k: v for k, v in finufftkwargs.items()}
_mode_ordering = finufftkwargs.pop("modeord", 1) # not finufft default, but corresponds to pytorch default
_i_sign = finufftkwargs.pop("isign", -1) # isign=-1 is finufft default for type 2

ndim = points.shape[0]
if _mode_ordering == 1:
targets = torch.fft.fftshift(targets, dim=tuple(range(-ndim, 0)))
targets = torch.fft.fftshift(targets)


ctx.isign = _i_sign
Expand All @@ -929,8 +916,8 @@ def forward(
nufft_func = get_nufft_func(ndim, 2, points.device.type)

finufft_out = nufft_func(
*points.data.numpy(),
targets.data.numpy(),
*points,
targets,
isign=_i_sign,
**finufftkwargs,
)
Expand Down Expand Up @@ -970,40 +957,18 @@ def backward(
points, targets = ctx.saved_tensors
device = points.device

# start_points = -(np.array(targets.shape) // 2)
# end_points = start_points + targets.shape
# slices = tuple(slice(start, end) for start, end in zip(start_points, end_points))

# # CPU idiosyncracy that needs to be done differently
# k_ramps = torch.from_numpy(np.mgrid[slices], dtype=points.dtype)

grad_points = grad_targets = None
ndim = points.shape[0]


if ctx.needs_input_grad[0]:
# wrt points
start_points = -(torch.tensor(targets.shape, device=device) // 2)
end_points = start_points + torch.tensor(targets.shape, device=device)
coord_ramps = torch.stack(
torch.meshgrid(
*(
torch.arange(start, end, device=device)
for start, end in zip(start_points, end_points)
),
indexing="ij",
)
)

if ctx.needs_input_grad[0]:
coord_ramps = coordinate_ramps(targets.shape, device=device)
ramped_targets = coord_ramps * targets[np.newaxis] * 1j * _i_sign
nufft_func = get_nufft_func(ndim, 2, points.device.type)

grad_points = nufft_func(
*points,
ramped_targets.squeeze(),
isign=_i_sign,
#modeord=_mode_ordering,
**finufftkwargs,
).conj() # Currently don't really get why this is hard to replace with a flipped isign

Expand All @@ -1018,13 +983,12 @@ def backward(
*points,
grad_output,
targets.shape,
#modeord=_mode_ordering,
isign=-_i_sign,
**finufftkwargs,
)

if _mode_ordering == 1:
grad_targets = torch.fft.ifftshift(grad_targets, dim=tuple(range(-ndim, 0)))
grad_targets = torch.fft.ifftshift(grad_targets)

return (
grad_points,
Expand Down

0 comments on commit f53348c

Please sign in to comment.