Skip to content

Commit

Permalink
TST tests passing for 1D. CAVEAT: we may need to check cases for 1-d …
Browse files Browse the repository at this point in the history
…arrays and add a channel axis somewhere
  • Loading branch information
eickenberg committed Oct 5, 2023
1 parent 2b3ac0f commit 290a800
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/test_1d/test_forward_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,41 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor):
)


@pytest.mark.parametrize("N", Ns)
def test_t1_forward_CPU(N: int) -> None:
"""
Tests against implementations of the FFT by setting up a uniform grid
over which to call FINUFFT through the API.
"""
g = np.mgrid[:N] * 2 * np.pi / N
g.shape = 1, -1
points = torch.from_numpy(g.reshape(1, -1))

values = torch.randn(*points[0].shape, dtype=torch.complex128)

print("N is " + str(N))
print("shape of points is " + str(points.shape))
print("shape of values is " + str(values.shape))

finufft_out = pytorch_finufft.functional.finufft_type1.apply(
points,
values,
(N,),
)

against_torch = torch.fft.fft(values.reshape(g[0].shape))

abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 4.5e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3



# @pytest.mark.parametrize("values", cases)
# def test_1d_t3_forward_CPU(values: torch.Tensor) -> None:
# """
Expand Down

0 comments on commit 290a800

Please sign in to comment.