Skip to content

Commit

Permalink
WIP adding tests but found bug in other tests, going to fix those first
Browse files Browse the repository at this point in the history
rebased on top of those changes
  • Loading branch information
eickenberg committed Oct 5, 2023
1 parent 48410ff commit 745c343
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 6 deletions.
21 changes: 15 additions & 6 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,10 @@ def backward(
# Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1
###############################################################################

def get_nufft_func(dim, nufft_type):
return getattr(finufft, f"nufft{dim}d{nufft_type}")


class finufft_type1(torch.autograd.Function):
@staticmethod
def forward(
Expand All @@ -1621,8 +1625,8 @@ def forward(
# All this requires is a check on the out array to make sure it is the
# correct shape.

err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately

err._type1_checks(points, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately
# ^ make sure these checks check for consistency between output shape and len(points)

if finufftkwargs is None:
finufftkwargs = dict()
Expand All @@ -1645,9 +1649,14 @@ def forward(
ctx.mode_ordering = _mode_ordering
ctx.finufftkwargs = finufftkwargs

# this below should be a pre-check
ndim = points.shape[0]
assert len(output_shape) == ndim

nufft_func = get_nufft_func(ndim, 1)
finufft_out = torch.from_numpy(
finufft.nufft3d1(
*points.data.T.numpy(),
nufft_func(
*points.data.numpy(),
values.data.numpy(),
output_shape,
modeord=_mode_ordering,
Expand Down Expand Up @@ -1705,7 +1714,7 @@ def backward(
for ramp in ramped_grad_output: # we can batch this into finufft
backprop_ramp = torch.from_numpy(
finufft.nufft3d2(
*points.T.numpy(),
*points.numpy(),
ramp.data.numpy(),
isign=_i_sign,
modeord=_mode_ordering,
Expand All @@ -1721,7 +1730,7 @@ def backward(

grad_values = torch.from_numpy(
finufft.nufft3d2(
*points.T.numpy()
*points.numpy(),
np_grad_output,
isign=_i_sign,
modeord=_mode_ordering,
Expand Down
28 changes: 28 additions & 0 deletions tests/test_2d/test_forward_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,31 @@ def test_2d_t2_forward_CPU(N: int) -> None:
# assert abs((f - comparison).sum()) / (N**3) == pytest.approx(0, abs=1e-6)

# pass


@pytest.mark.parametrize("N", Ns)
def test_t1_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N, :N] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(2, -1))

values = torch.randn(*points[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N, N),
)

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

assert (finufft_out - against_torch) == pytest.approx(
0, abs=1e-6
)
28 changes: 28 additions & 0 deletions tests/test_3d/test_forward_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,31 @@ def test_3d_t2_forward_CPU(N: int) -> None:
assert l_inf_error < 1e-5 * N ** 1.5
assert l_2_error < 1e-5 * N ** 3
assert l_1_error < 1e-5 * N ** 4.5


@pytest.mark.parametrize("N", Ns)
def test_t1_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N, :N, :N] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(3, -1))

values = torch.randn(*points[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N, N, N),
)

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

assert abs((finufft_out - against_torch).sum()) / N**3 == pytest.approx(
0, abs=1e-6
)

0 comments on commit 745c343

Please sign in to comment.