diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 0546f62..ab88103 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -1601,6 +1601,10 @@ def backward( # Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1 ############################################################################### +def get_nufft_func(dim, nufft_type): + return getattr(finufft, f"nufft{dim}d{nufft_type}") + + class finufft_type1(torch.autograd.Function): @staticmethod def forward( @@ -1621,8 +1625,8 @@ def forward( # All this requires is a check on the out array to make sure it is the # correct shape. - err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately - + err._type1_checks(points, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately + # ^ make sure these checks check for consistency between output shape and len(points) if finufftkwargs is None: finufftkwargs = dict() @@ -1645,9 +1649,14 @@ def forward( ctx.mode_ordering = _mode_ordering ctx.finufftkwargs = finufftkwargs + # this below should be a pre-check + ndim = points.shape[0] + assert len(output_shape) == ndim + + nufft_func = get_nufft_func(ndim, 1) finufft_out = torch.from_numpy( - finufft.nufft3d1( - *points.data.T.numpy(), + nufft_func( + *points.data.numpy(), values.data.numpy(), output_shape, modeord=_mode_ordering, @@ -1705,7 +1714,7 @@ def backward( for ramp in ramped_grad_output: # we can batch this into finufft backprop_ramp = torch.from_numpy( finufft.nufft3d2( - *points.T.numpy(), + *points.numpy(), ramp.data.numpy(), isign=_i_sign, modeord=_mode_ordering, @@ -1721,7 +1730,7 @@ def backward( grad_values = torch.from_numpy( finufft.nufft3d2( - *points.T.numpy() + *points.numpy(), np_grad_output, isign=_i_sign, modeord=_mode_ordering, diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index 3a61938..78ff7e4 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -109,8 +109,8 @@ def test_2d_t2_forward_CPU(N: int) -> None: against_torch = torch.fft.ifft2(values) - assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx( - 0, abs=1e-6 + assert torch.abs(finufft_out - against_torch).sum() == pytest.approx( + 0, abs=1e-4 ) @@ -128,3 +128,31 @@ def test_2d_t2_forward_CPU(N: int) -> None: # assert abs((f - comparison).sum()) / (N**3) == pytest.approx(0, abs=1e-6) # pass + + +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_CPU(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. + """ + g = np.mgrid[:N, :N] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(2, -1)) + + values = torch.randn(*points[0].shape, dtype=torch.complex128) + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1.apply( + points, + values, + (N, N), + ) + + against_torch = torch.fft.fft2(values.reshape(g[0].shape)) + + assert (finufft_out - against_torch) == pytest.approx( + 0, abs=1e-6 + ) diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index e470abc..1aca14b 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -72,3 +72,32 @@ def test_3d_t2_forward_CPU(N: int) -> None: assert (abs((finufft_out - against_torch).sum())) / (N**4) == pytest.approx( 0, abs=1e-6 ) + + + +@pytest.mark.parametrize("N", Ns) +def test_t1_forward_CPU(N: int) -> None: + """ + Tests against implementations of the FFT by setting up a uniform grid + over which to call FINUFFT through the API. + """ + g = np.mgrid[:N, :N, :N] * 2 * np.pi / N + points = torch.from_numpy(g.reshape(3, -1)) + + values = torch.randn(*points[0].shape, dtype=torch.complex128) + + print("N is " + str(N)) + print("shape of points is " + str(points.shape)) + print("shape of values is " + str(values.shape)) + + finufft_out = pytorch_finufft.functional.finufft_type1.apply( + points, + values, + (N, N, N), + ) + + against_torch = torch.fft.fft2(values.reshape(g[0].shape)) + + assert abs((finufft_out - against_torch).sum()) / N**3 == pytest.approx( + 0, abs=1e-6 + ) \ No newline at end of file