diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py
index 079ce9a..affb51a 100644
--- a/pytorch_finufft/functional.py
+++ b/pytorch_finufft/functional.py
@@ -716,6 +716,21 @@ def f(*args, **kwargs):
     return f
 
+def coordinate_ramps(shape, device):
+    start_points = -(torch.tensor(shape, device=device) // 2)
+    end_points = start_points + torch.tensor(shape, device=device)
+    coord_ramps = torch.stack(
+        torch.meshgrid(
+            *(
+                torch.arange(start, end, device=device)
+                for start, end in zip(start_points, end_points)
+            ),
+            indexing="ij",
+        )
+    )
+
+    return coord_ramps
+
 
 class finufft_type1(torch.autograd.Function):
     @staticmethod
     def forward(  # type: ignore[override]
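The helper added above factors out the construction of the centered integer mode indices (one ramp per dimension) that the backward passes multiply against the gridded values. A minimal sketch of what it returns, assuming this patch is applied so that `coordinate_ramps` is importable from `pytorch_finufft.functional`:

```python
# Sketch only: inspect the ramps produced for a small 2D shape.
import torch

from pytorch_finufft.functional import coordinate_ramps  # added by this patch

ramps = coordinate_ramps((5, 4), device="cpu")
print(ramps.shape)     # torch.Size([2, 5, 4]) -- one ramp per dimension
print(ramps[0, :, 0])  # tensor([-2, -1,  0,  1,  2]) -- centered mode indices

# Each ramp is the fftshifted set of integer FFT frequencies for that axis,
# which is why the type 2 code fftshifts its targets when modeord=1 is used.
expected = torch.fft.fftshift(torch.fft.fftfreq(5) * 5).round().long()
assert torch.equal(ramps[0, :, 0], expected)
```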
@@ -803,17 +818,7 @@ def backward(  # type: ignore[override]
 
         if ctx.needs_input_grad[0]:
             # wrt points
-            start_points = -(torch.tensor(grad_output.shape, device=device) // 2)
-            end_points = start_points + torch.tensor(grad_output.shape, device=device)
-            coord_ramps = torch.stack(
-                torch.meshgrid(
-                    *(
-                        torch.arange(start, end, device=device)
-                        for start, end in zip(start_points, end_points)
-                    ),
-                    indexing="ij",
-                )
-            )
+            coord_ramps = coordinate_ramps(grad_output.shape, device)
 
             # we can't batch in 1d case so we squeeze and fix up the ouput later
             ramped_grad_output = (
@@ -840,3 +845,160 @@ def backward(  # type: ignore[override]
             None,
             None,
         )
+
+
+class finufft_type2(torch.autograd.Function):
+    """
+    FINUFFT problem type 2
+    """
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        points: torch.Tensor,
+        targets: torch.Tensor,
+        out: Optional[torch.Tensor] = None,
+        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
+    ) -> torch.Tensor:
+        """
+        Evaluates the Type 2 NUFFT on the inputs.
+
+        NOTE: By default, the ordering is set to match that of Pytorch,
+        Numpy, and Scipy's FFT APIs. To match the mode ordering
+        native to FINUFFT, add {'modeord': 0} to finufftkwargs.
+
+        Parameters
+        ----------
+        ctx : Any
+            Pytorch context object
+        points : torch.Tensor, shape=(ndim, num_points)
+            The non-uniform points x
+        targets : torch.Tensor
+            The values on the input grid
+        out : Optional[torch.Tensor], optional
+            Array to which the result is written in-place, by default None
+        finufftkwargs : Optional[Dict[str, Union[int, float]]]
+            Additional arguments that will be passed into FINUFFT. See
+            https://finufft.readthedocs.io/en/latest/python.html.
+
+        Returns
+        -------
+        torch.Tensor
+            The Fourier transform of the targets grid evaluated at the points `points`
+        """
+
+        if out is not None:
+            print("In-place results are not yet implemented")
+
+        # TODO -- extend checks to 2d
+        checks._type2_checks(points, targets)
+
+        if finufftkwargs is None:
+            finufftkwargs = dict()
+
+        finufftkwargs = {k: v for k, v in finufftkwargs.items()}
+        # modeord=1 is not the FINUFFT default, but corresponds to the PyTorch FFT ordering
+        _mode_ordering = finufftkwargs.pop("modeord", 1)
+        # isign=-1 is the FINUFFT default for type 2
+        _i_sign = finufftkwargs.pop("isign", -1)
+
+        ndim = points.shape[0]
+        if _mode_ordering == 1:
+            targets = torch.fft.fftshift(targets)
+
+        ctx.isign = _i_sign
+        ctx.mode_ordering = _mode_ordering
+        ctx.finufftkwargs = finufftkwargs
+
+        ctx.save_for_backward(points, targets)
+
+        nufft_func = get_nufft_func(ndim, 2, points.device.type)
+
+        finufft_out = nufft_func(
+            *points,
+            targets,
+            isign=_i_sign,
+            **finufftkwargs,
+        )
+
+        return finufft_out
+
+    @staticmethod
+    def backward(
+        ctx: Any, grad_output: torch.Tensor
+    ) -> Tuple[
+        Union[torch.Tensor, None],
+        Union[torch.Tensor, None],
+        Union[torch.Tensor, None],
+        None,
+        None,
+        None,
+    ]:
+        """
+        Implements derivatives wrt. each argument in the forward method.
+
+        Parameters
+        ----------
+        ctx : Any
+            Pytorch context object
+        grad_output : torch.Tensor
+            Backpass gradient output.
+
+        Returns
+        -------
+        Tuple[Union[torch.Tensor, None], ...]
+            A tuple of derivatives wrt. each argument in the forward method
+        """
+        _i_sign = ctx.isign
+        _mode_ordering = ctx.mode_ordering
+        finufftkwargs = ctx.finufftkwargs
+
+        points, targets = ctx.saved_tensors
+        device = points.device
+
+        grad_points = grad_targets = None
+        ndim = points.shape[0]
+
+        if ctx.needs_input_grad[0]:
+            # wrt. points
+            coord_ramps = coordinate_ramps(targets.shape, device=device)
+            ramped_targets = coord_ramps * targets[np.newaxis] * 1j * _i_sign
+            nufft_func = get_nufft_func(ndim, 2, points.device.type)
+
+            grad_points = nufft_func(
+                *points,
+                ramped_targets.squeeze(),
+                isign=_i_sign,
+                **finufftkwargs,
+            ).conj()  # TODO: unclear why this cannot simply be replaced with a flipped isign
+
+            grad_points = grad_points * grad_output
+            grad_points = torch.atleast_2d(grad_points.real)
+
+        if ctx.needs_input_grad[1]:
+            # wrt. targets
+            nufft_func = get_nufft_func(ndim, 1, points.device.type)
+
+            grad_targets = nufft_func(
+                *points,
+                grad_output,
+                targets.shape,
+                isign=-_i_sign,
+                **finufftkwargs,
+            )
+
+            if _mode_ordering == 1:
+                grad_targets = torch.fft.ifftshift(grad_targets)
+
+        return (
+            grad_points,
+            grad_targets,
+            None,
+            None,
+            None,
+        )
+
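For reference, a minimal usage sketch of the new `finufft_type2` functional interface, mirroring the forward tests added below (it assumes `finufft` and this patched `pytorch_finufft` are importable). On a uniform 2*pi/N grid the type 2 NUFFT reduces to an ordinary FFT, so the output should agree with `torch.fft.fft2` up to the FINUFFT tolerance:

```python
# Sketch only: type 2 NUFFT on a uniform grid vs. torch.fft.fft2.
import numpy as np
import torch

import pytorch_finufft

N = 16
g = np.mgrid[:N, :N] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(2, -1))          # (ndim, num_points)
targets = torch.randn(N, N, dtype=torch.complex128)  # values on the grid

# The default modeord=1 matches the torch/numpy/scipy FFT mode ordering.
out = pytorch_finufft.functional.finufft_type2.apply(points, targets)

reference = torch.fft.fft2(targets).ravel()
print(torch.abs(out - reference).max())  # small; set by the FINUFFT tolerance (1e-6 by default)
```

Passing `{"modeord": 0}` as the kwargs dict instead requests FINUFFT's native centered ordering, in which case the reference becomes `torch.fft.fft2(torch.fft.ifftshift(targets))`, as exercised by the new 2D forward test.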
diff --git a/tests/test_1d/test_backward_1d.py b/tests/test_1d/test_backward_1d.py
index 43793ed..988c578 100644
--- a/tests/test_1d/test_backward_1d.py
+++ b/tests/test_1d/test_backward_1d.py
@@ -161,3 +161,68 @@ def test_t2_backward_CPU_points(
     inputs = (points, targets)
 
     assert gradcheck(apply_finufft1d2(fftshift, isign), inputs, eps=1e-8, atol=1e-5 * N)
+
+
+def check_t2_backward(
+    N: int,
+    modifier: int,
+    fftshift: bool,
+    isign: int,
+    device: str,
+    points_or_targets: bool,
+) -> None:
+    points = torch.rand((1, N + modifier), dtype=torch.float64).to(device) * 2 * np.pi
+    targets = torch.randn(N, dtype=torch.complex128).to(device)
+
+    points.requires_grad = points_or_targets
+    targets.requires_grad = not points_or_targets
+
+    inputs = (points, targets)
+
+    def func(points, targets):
+        return pytorch_finufft.functional.finufft_type2.apply(
+            points,
+            targets,
+            None,
+            dict(modeord=int(not fftshift), isign=isign),
+        )
+
+    assert gradcheck(func, inputs, atol=5e-3 * N)
+
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_CPU_values(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cpu", False)
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_CPU_points(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cpu", True)
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_cuda_values(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cuda", False)
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_cuda_points(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cuda", True)
+
diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py
index 5f09c24..8ca8a9d 100644
--- a/tests/test_1d/test_forward_1d.py
+++ b/tests/test_1d/test_forward_1d.py
@@ -101,3 +101,36 @@ def test_1d_t2_forward_CPU(targets: torch.Tensor):
     assert torch.norm(finufft_out - against_torch) / N**2 == pytest.approx(
         0, abs=1e-05
     )
+
+
+@pytest.mark.parametrize("N", Ns)
+def test_t2_forward_CPU(N: int) -> None:
+    """
+    Test the type 2 NUFFT against the standard FFT by evaluating it on a
+    uniform grid through the functional API.
+    """
+    g = np.mgrid[:N,] * 2 * np.pi / N
+    points = torch.from_numpy(g.reshape(g.shape[0], -1))
+
+    targets = torch.randn(*g[0].shape, dtype=torch.complex128)
+
+    print("N is " + str(N))
+    print("shape of points is " + str(points.shape))
+    print("shape of targets is " + str(targets.shape))
+
+    finufft_out = pytorch_finufft.functional.finufft_type2.apply(
+        points,
+        targets,
+    )
+
+    against_torch = torch.fft.fftn(targets)
+
+    abs_errors = torch.abs(finufft_out - against_torch.ravel())
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 4.5e-5 * N ** 1.1
+    assert l_2_error < 6e-5 * N ** 2.1
+    assert l_1_error < 1.2e-4 * N ** 3.2
+
diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py
index 903595c..88a51bf 100644
--- a/tests/test_2d/test_backward_2d.py
+++ b/tests/test_2d/test_backward_2d.py
@@ -202,3 +202,69 @@ def test_t2_backward_CPU_points_y(
     inputs = (points_x, points_y, targets)
 
     assert gradcheck(apply_finufft2d2(fftshift, isign), inputs)
+
+
+
+def check_t2_backward(
+    N: int,
+    modifier: int,
+    fftshift: bool,
+    isign: int,
+    device: str,
+    points_or_targets: bool,
+) -> None:
+    points = torch.rand((2, N + modifier), dtype=torch.float64).to(device) * 2 * np.pi
+    targets = torch.randn(N, N, dtype=torch.complex128).to(device)
+
+    points.requires_grad = points_or_targets
+    targets.requires_grad = not points_or_targets
+
+    inputs = (points, targets)
+
+    def func(points, targets):
+        return pytorch_finufft.functional.finufft_type2.apply(
+            points,
+            targets,
+            None,
+            dict(modeord=int(not fftshift), isign=isign),
+        )
+
+    assert gradcheck(func, inputs, atol=1e-5 * N)
+
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_CPU_values(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cpu", False)
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_points(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cpu", True)
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_cuda_values(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cuda", False)
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_cuda_points(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cuda", True)
+
diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py
index ad6ac50..ba84f45 100644
--- a/tests/test_2d/test_forward_2d.py
+++ b/tests/test_2d/test_forward_2d.py
@@ -9,6 +9,7 @@
 
 # Case generation
 Ns = [
+    3,
     10,
     15,
     75,
@@ -110,3 +111,42 @@ def test_2d_t2_forward_CPU(N: int) -> None:
     assert l_inf_error < 1e-5 * N
     assert l_2_error < 1e-5 * N**2
     assert l_1_error < 1e-5 * N**3
+
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("fftshift", [False, True])
+def test_t2_forward_CPU(N: int, fftshift: bool) -> None:
+    """
+    Test the type 2 NUFFT against the standard FFT by evaluating it on a
+    uniform grid through the functional API.
+    """
+    g = np.mgrid[:N, :N] * 2 * np.pi / N
+    points = torch.from_numpy(g.reshape(2, -1))
+
+    targets = torch.randn(*g[0].shape, dtype=torch.complex128)
+
+    print("N is " + str(N))
+    print("shape of points is " + str(points.shape))
+    print("shape of targets is " + str(targets.shape))
+
+    finufft_out = pytorch_finufft.functional.finufft_type2.apply(
+        points,
+        targets,
+        None,
+        {'modeord': int(not fftshift)},
+    )
+
+    if fftshift:
+        against_torch = torch.fft.fft2(torch.fft.ifftshift(targets))
+    else:
+        against_torch = torch.fft.fft2(targets)
+
+    abs_errors = torch.abs(finufft_out - against_torch.ravel())
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 4.5e-5 * N
+    assert l_2_error < 1e-5 * N ** 2
+    assert l_1_error < 1e-5 * N ** 3
+
diff --git a/tests/test_3d/test_backward_3d.py b/tests/test_3d/test_backward_3d.py
index 8dd41ea..3c63014 100644
--- a/tests/test_3d/test_backward_3d.py
+++ b/tests/test_3d/test_backward_3d.py
@@ -240,3 +240,69 @@ def test_t2_backward_CPU_points_z(
     inputs = (points_x, points_y, points_z, targets)
 
     assert gradcheck(apply_finufft3d2(fftshift, isign), inputs)
+
+
+
+def check_t2_backward(
+    N: int,
+    modifier: int,
+    fftshift: bool,
+    isign: int,
+    device: str,
+    points_or_targets: bool,
+) -> None:
+    points = torch.rand((3, N + modifier), dtype=torch.float64).to(device) * 2 * np.pi
+    targets = torch.randn(N, N, N, dtype=torch.complex128).to(device)
+
+    points.requires_grad = points_or_targets
+    targets.requires_grad = not points_or_targets
+
+    inputs = (points, targets)
+
+    def func(points, targets):
+        return pytorch_finufft.functional.finufft_type2.apply(
+            points,
+            targets,
+            None,
+            dict(modeord=int(not fftshift), isign=isign),
+        )
+
+    assert gradcheck(func, inputs, atol=1e-5 * N)
+
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_CPU_values(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cpu", False)
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_CPU_points(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cpu", True)
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_cuda_values(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cuda", False)
+
+@pytest.mark.parametrize("N", Ns)
+@pytest.mark.parametrize("modifier", length_modifiers)
+@pytest.mark.parametrize("fftshift", [False, True])
+@pytest.mark.parametrize("isign", [-1, 1])
+def test_t2_backward_cuda_points(
+    N: int, modifier: int, fftshift: bool, isign: int
+) -> None:
+    check_t2_backward(N, modifier, fftshift, isign, "cuda", True)
+
diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py
index a5c945e..f4cbd98 100644
--- a/tests/test_3d/test_forward_3d.py
+++ b/tests/test_3d/test_forward_3d.py
@@ -9,6 +9,7 @@
 
 # Case generation
 Ns = [
+    3,
     5,
     10,
     15,
@@ -92,3 +93,36 @@ def test_3d_t2_forward_CPU(N: int) -> None:
     assert l_inf_error < 1e-5 * N**1.5
     assert l_2_error < 1e-5 * N**3
     assert l_1_error < 1e-5 * N**4.5
+
+
+@pytest.mark.parametrize("N", Ns)
+def test_t2_forward_CPU(N: int) -> None:
+    """
+    Test the type 2 NUFFT against the standard FFT by evaluating it on a
+    uniform grid through the functional API.
+    """
+    g = np.mgrid[:N, :N, :N] * 2 * np.pi / N
+    points = torch.from_numpy(g.reshape(g.shape[0], -1))
+
+    targets = torch.randn(*g[0].shape, dtype=torch.complex128)
+
+    print("N is " + str(N))
+    print("shape of points is " + str(points.shape))
+    print("shape of targets is " + str(targets.shape))
+
+    finufft_out = pytorch_finufft.functional.finufft_type2.apply(
+        points,
+        targets,
+    )
+
+    against_torch = torch.fft.fftn(targets)
+
+    abs_errors = torch.abs(finufft_out - against_torch.ravel())
+    l_inf_error = abs_errors.max()
+    l_2_error = torch.sqrt(torch.sum(abs_errors**2))
+    l_1_error = torch.sum(abs_errors)
+
+    assert l_inf_error < 4.5e-5 * N ** 1.1
+    assert l_2_error < 6e-5 * N ** 2.1
+    assert l_1_error < 1.2e-4 * N ** 3.2
+
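Finally, a minimal pure-PyTorch sketch (no FINUFFT required) of the identity that the wrt-points branch of `finufft_type2.backward` relies on: for f(x) = sum_k c_k * exp(i*s*k*x), the derivative df/dx is the same type 2 sum with each coefficient multiplied by i*s*k, which is exactly the `coord_ramps * targets * 1j * isign` factor in the code above. The grid size, evaluation point, and step size below are arbitrary illustrative choices:

```python
# Sketch only: analytic point-derivative of a type 2 sum vs. finite differences.
import torch

torch.manual_seed(0)
N, s = 8, -1                                                     # s plays the role of isign
k = torch.arange(-(N // 2), -(N // 2) + N, dtype=torch.float64)  # centered modes, as in coordinate_ramps
c = torch.randn(N, dtype=torch.complex128)                       # Fourier coefficients ("targets")
x = torch.tensor(0.3, dtype=torch.float64)                       # a single non-uniform point


def type2(coeffs, x):
    return (coeffs * torch.exp(1j * s * k * x)).sum()


analytic = type2(1j * s * k * c, x)   # ramped coefficients give the derivative directly
h = 1e-6
numeric = (type2(c, x + h) - type2(c, x - h)) / (2 * h)
print(torch.abs(analytic - numeric))  # agreement up to finite-difference error
```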