
Commit

Merge pull request #64 from flatironinstitute/mike-consolidate-dimensionalities-type-1

WIP consolidation of dimensionalities for nufft type 1
WardBrian authored Oct 6, 2023
2 parents d8a8e4b + f63c4d9 commit de1fa85
Showing 5 changed files with 352 additions and 32 deletions.
159 changes: 159 additions & 0 deletions pytorch_finufft/functional.py
@@ -4,6 +4,7 @@

from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import finufft
import torch

@@ -1592,3 +1593,161 @@ def backward(
None,
None,
)





###############################################################################
# Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1
###############################################################################

def get_nufft_func(dim, nufft_type):
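# e.g. get_nufft_func(2, 1) resolves to finufft.nufft2d1; dim is the spatial
# dimensionality (1, 2, or 3) and nufft_type selects the FINUFFT transform type.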
return getattr(finufft, f"nufft{dim}d{nufft_type}")


class finufft_type1(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
points: torch.Tensor,
values: torch.Tensor,
output_shape: Union[int, tuple[int, int], tuple[int, int, int]],
out: Optional[torch.Tensor]=None,
fftshift: bool=False,
finufftkwargs: Optional[dict[str, Union[int, float]]] = None):
"""
Evaluates the Type 1 NUFFT on the inputs.
"""

if out is not None:
print("In-place results are not yet implemented")
# All this requires is a check on the out array to make sure it is the
# correct shape.

# TODO: revisit these error checks to take the shape of points into account
# directly, and make sure they verify consistency between output_shape and
# len(points).
err._type1_checks(points, values, output_shape)

if finufftkwargs is None:
finufftkwargs = dict()
finufftkwargs = {k: v for k, v in finufftkwargs.items()}
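# (shallow copy, so the pops below do not mutate the caller's dict)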
_mode_ordering = finufftkwargs.pop("modeord", 1)
_i_sign = finufftkwargs.pop("isign", -1)

if fftshift:
# TODO -- this check should be done elsewhere? or error msg changed
# to note instead that there is a conflict in fftshift
if _mode_ordering != 1:
raise ValueError(
"Double specification of ordering; only one of fftshift and modeord should be provided"
)
_mode_ordering = 0

ctx.save_for_backward(points, values)

ctx.isign = _i_sign
ctx.mode_ordering = _mode_ordering
ctx.finufftkwargs = finufftkwargs

# this below should be a pre-check
ndim = points.shape[0]
assert len(output_shape) == ndim

nufft_func = get_nufft_func(ndim, 1)
finufft_out = torch.from_numpy(
nufft_func(
*points.data.numpy(),
values.data.numpy(),
output_shape,
modeord=_mode_ordering,
isign=_i_sign,
**finufftkwargs,
)
)

return finufft_out

@staticmethod
def backward(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Union[torch.Tensor, None], ...]:
"""
Implements derivatives w.r.t. each argument in the forward method.

Parameters
----------
ctx : Any
PyTorch context object.
grad_output : torch.Tensor
Backpass gradient output.

Returns
-------
tuple[Union[torch.Tensor, None], ...]
A tuple of derivatives w.r.t. each argument in the forward method.
"""
_i_sign = -1 * ctx.isign
_mode_ordering = ctx.mode_ordering
finufftkwargs = ctx.finufftkwargs

points, values = ctx.saved_tensors

start_points = -(np.array(grad_output.shape) // 2)
end_points = start_points + grad_output.shape
slices = tuple(slice(start, end) for start, end in zip(start_points, end_points))

# CPU idiosyncrasy that needs to be handled differently
coord_ramps = torch.from_numpy(np.mgrid[slices])
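# coord_ramps has shape (ndim, *grad_output.shape); coord_ramps[d] holds the
# (signed) integer frequency index along dimension d for every output mode.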

grads_points = None
grad_values = None

ndim = points.shape[0]

nufft_func = get_nufft_func(ndim, 2)

if ctx.needs_input_grad[0]:
# wrt points
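# Since f[k] = sum_j values[j] * exp(isign * 1j * k . x_j), we have
# d f[k] / d x[d, j] = isign * 1j * k_d * values[j] * exp(isign * 1j * k . x_j).
# The loop below therefore ramps grad_output by the frequencies, pushes it back
# to the nonuniform points with a type-2 NUFFT, and takes the real part against
# values, one spatial dimension at a time.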

if _mode_ordering != 0:
coord_ramps = torch.fft.ifftshift(coord_ramps, dim=tuple(range(1, ndim+1)))

ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign

grads_points = []
for ramp in ramped_grad_output: # we can batch this into finufft
backprop_ramp = torch.from_numpy(
nufft_func(
*points.numpy(),
ramp.data.numpy(),
isign=_i_sign,
modeord=_mode_ordering,
**finufftkwargs,
))
grad_points = (backprop_ramp.conj() * values).real
grads_points.append(grad_points)

grads_points = torch.stack(grads_points)

if ctx.needs_input_grad[1]:
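# d f[k] / d values[j] = exp(isign * 1j * k . x_j), so the gradient w.r.t. values
# is grad_output pushed back to the nonuniform points by a type-2 NUFFT (note
# _i_sign is already the flipped sign of the forward transform).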
np_grad_output = grad_output.data.numpy()

grad_values = torch.from_numpy(
nufft_func(
*points.numpy(),
np_grad_output,
isign=_i_sign,
modeord=_mode_ordering,
**finufftkwargs,
)
)

return (
grads_points,
grad_values,
None,
None,
None,
None,
)
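

# A minimal usage sketch (hypothetical sizes; mirrors the tests added in this PR).
# The class is exposed to callers as pytorch_finufft.functional.finufft_type1:
#
#   M, N1, N2 = 1000, 32, 34
#   points = torch.rand((2, M), dtype=torch.float64) * 2 * np.pi   # shape (ndim, M)
#   values = torch.randn(M, dtype=torch.complex128)                # shape (M,)
#   f = finufft_type1.apply(points, values, (N1, N2))              # shape (N1, N2)
#
# Gradients flow back to points and/or values through the backward method above.
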
45 changes: 45 additions & 0 deletions tests/test_1d/test_forward_1d.py
@@ -67,6 +67,16 @@ def test_1d_t1_forward_CPU(values: torch.Tensor) -> None:
) == pytest.approx(0, abs=1e-06)


abs_errors = torch.abs(finufft1D1_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 3.5e-3 * N ** .6
assert l_2_error < 7.5e-4 * N ** 1.1
assert l_1_error < 5e-4 * N ** 1.6


@pytest.mark.parametrize("targets", cases)
def test_1d_t2_forward_CPU(targets: torch.Tensor):
"""
@@ -96,6 +106,41 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor):
)


@pytest.mark.parametrize("N", Ns)
def test_t1_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N] * 2 * np.pi / N
g.shape = 1, -1
points = torch.from_numpy(g.reshape(1, -1))

values = torch.randn(*points[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N,),
)

against_torch = torch.fft.fft(values.reshape(g[0].shape))
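
# With points on the uniform grid and the wrapper's defaults (modeord=1, isign=-1),
# the type-1 NUFFT reduces (up to tolerance) to an ordinary DFT, so torch.fft.fft
# on the same values is the reference.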

abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 4.5e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3



# @pytest.mark.parametrize("values", cases)
# def test_1d_t3_forward_CPU(values: torch.Tensor) -> None:
# """
47 changes: 47 additions & 0 deletions tests/test_2d/test_backward_2d.py
@@ -5,8 +5,11 @@

import pytorch_finufft

from functools import partial

torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.float64)
torch.manual_seed(0)
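
# Note: gradcheck below uses finite differences, which needs double precision to be
# reliable; the fixed seed keeps the randomly drawn test points reproducible.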

######################################################################
# APPLY WRAPPERS
@@ -97,6 +100,50 @@ def test_t1_backward_CPU_values(
assert gradcheck(apply_finufft2d1(modifier, fftshift, isign), inputs)


@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t1_consolidated_backward_CPU_values(N: int, modifier: int, fftshift: bool, isign: int) -> None:

points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi
values = torch.randn(N, dtype=torch.complex128)

points.requires_grad = False
values.requires_grad = True

inputs = (points, values)

def func(points, values):
return pytorch_finufft.functional.finufft_type1.apply(
points, values, (N,N + modifier), None, fftshift, dict(isign=isign)
)

assert gradcheck(func, inputs)


@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t1_consolidated_backward_CPU_points(N: int, modifier: int, fftshift: bool, isign: int) -> None:

points = torch.rand((2, N), dtype=torch.float64) * 2 * np.pi
values = torch.randn(N, dtype=torch.complex128)

points.requires_grad = True
values.requires_grad = False

inputs = (points, values)

def func(points, values):
return pytorch_finufft.functional.finufft_type1.apply(
points, values, (N,N + modifier), None, fftshift, dict(isign=isign)
)

assert gradcheck(func, inputs, atol=1e-5 * N)


@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [True, False])
74 changes: 50 additions & 24 deletions tests/test_2d/test_forward_2d.py
@@ -1,6 +1,7 @@
import numpy as np
import pytest
import torch
torch.manual_seed(0)

import pytorch_finufft

@@ -45,28 +46,14 @@ def test_2d_t1_forward_CPU(N: int) -> None:

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-6
)
abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

values = torch.randn(*x.shape, dtype=torch.complex64)

finufft_out = pytorch_finufft.functional.finufft2D1.apply(
torch.from_numpy(x).to(torch.float32),
torch.from_numpy(y).to(torch.float32),
values,
N,
)

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

# NOTE -- the below tolerance is set to 1e-5 instead of -6 due
# to the occasional failing case that seems to be caused by
# the randomness of the test cases in addition to the expected
# accumulation of numerical inaccuracies
assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-5
)
assert l_inf_error < 5e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3


@pytest.mark.parametrize("N", Ns)
@@ -109,9 +96,14 @@ def test_2d_t2_forward_CPU(N: int) -> None:

against_torch = torch.fft.ifft2(values)

assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-6
)
abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 1e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3


# @pytest.mark.parametrize("N", Ns)
@@ -128,3 +120,37 @@ def test_2d_t2_forward_CPU(N: int) -> None:
# assert abs((f - comparison).sum()) / (N**3) == pytest.approx(0, abs=1e-6)

# pass


@pytest.mark.parametrize("N", Ns)
def test_t1_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N, :N] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(2, -1))

values = torch.randn(*points[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N, N),
)

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 4.5e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3

