Skip to content

Commit

Permalink
WIP adding tests but found bug in other tests, going to fix those first
Browse files Browse the repository at this point in the history
  • Loading branch information
eickenberg committed Oct 4, 2023
1 parent 18d6ceb commit e2baee3
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 8 deletions.
21 changes: 15 additions & 6 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,10 @@ def backward(
# Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1
###############################################################################

def get_nufft_func(dim, nufft_type):
return getattr(finufft, f"nufft{dim}d{nufft_type}")


class finufft_type1(torch.autograd.Function):
@staticmethod
def forward(
Expand All @@ -1621,8 +1625,8 @@ def forward(
# All this requires is a check on the out array to make sure it is the
# correct shape.

err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately

err._type1_checks(points, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately
# ^ make sure these checks check for consistency between output shape and len(points)

if finufftkwargs is None:
finufftkwargs = dict()
Expand All @@ -1645,9 +1649,14 @@ def forward(
ctx.mode_ordering = _mode_ordering
ctx.finufftkwargs = finufftkwargs

# this below should be a pre-check
ndim = points.shape[0]
assert len(output_shape) == ndim

nufft_func = get_nufft_func(ndim, 1)
finufft_out = torch.from_numpy(
finufft.nufft3d1(
*points.data.T.numpy(),
nufft_func(
*points.data.numpy(),
values.data.numpy(),
output_shape,
modeord=_mode_ordering,
Expand Down Expand Up @@ -1705,7 +1714,7 @@ def backward(
for ramp in ramped_grad_output: # we can batch this into finufft
backprop_ramp = torch.from_numpy(
finufft.nufft3d2(
*points.T.numpy(),
*points.numpy(),
ramp.data.numpy(),
isign=_i_sign,
modeord=_mode_ordering,
Expand All @@ -1721,7 +1730,7 @@ def backward(

grad_values = torch.from_numpy(
finufft.nufft3d2(
*points.T.numpy()
*points.numpy(),
np_grad_output,
isign=_i_sign,
modeord=_mode_ordering,
Expand Down
32 changes: 30 additions & 2 deletions tests/test_2d/test_forward_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def test_2d_t2_forward_CPU(N: int) -> None:

against_torch = torch.fft.ifft2(values)

assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-6
assert torch.abs(finufft_out - against_torch).sum() == pytest.approx(
0, abs=1e-4
)


Expand All @@ -128,3 +128,31 @@ def test_2d_t2_forward_CPU(N: int) -> None:
# assert abs((f - comparison).sum()) / (N**3) == pytest.approx(0, abs=1e-6)

# pass


@pytest.mark.parametrize("N", Ns)
def test_t1_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N, :N] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(2, -1))

values = torch.randn(*points[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N, N),
)

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

assert (finufft_out - against_torch) == pytest.approx(
0, abs=1e-6
)
29 changes: 29 additions & 0 deletions tests/test_3d/test_forward_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,32 @@ def test_3d_t2_forward_CPU(N: int) -> None:
assert (abs((finufft_out - against_torch).sum())) / (N**4) == pytest.approx(
0, abs=1e-6
)



@pytest.mark.parametrize("N", Ns)
def test_t1_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N, :N, :N] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(3, -1))

values = torch.randn(*points[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N, N, N),
)

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

assert abs((finufft_out - against_torch).sum()) / N**3 == pytest.approx(
0, abs=1e-6
)

0 comments on commit e2baee3

Please sign in to comment.