diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index 479959f..ca42977 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -71,13 +71,13 @@ def coordinate_ramps(shape, device): return coord_ramps -class finufft_type1(torch.autograd.Function): +class FinufftType1(torch.autograd.Function): @staticmethod def forward( # type: ignore[override] ctx: Any, points: torch.Tensor, values: torch.Tensor, - output_shape: Union[int, Tuple[int, int], Tuple[int, int, int]], + output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]], finufftkwargs: Optional[Dict[str, Union[int, float]]] = None, ) -> torch.Tensor: """ @@ -181,7 +181,7 @@ def backward( # type: ignore[override] ) -class finufft_type2(torch.autograd.Function): +class FinufftType2(torch.autograd.Function): """ FINUFFT 2D problem type 2 """ @@ -324,3 +324,80 @@ def backward( # type: ignore[override] None, None, ) + + +def finufft_type1( + points: torch.Tensor, + values: torch.Tensor, + output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]], + **finufftkwargs: Union[int, float], +) -> torch.Tensor: + """ + Evaluates the Type 1 (nonuniform-to-uniform) NUFFT on the inputs. + + This is a wrapper around :func:`finufft.nufft1d1`, :func:`finufft.nufft2d1`, and + :func:`finufft.nufft3d1` on CPU, and :func:`cufinufft.nufft1d1`, + :func:`cufinufft.nufft2d1`, and :func:`cufinufft.nufft3d1` on GPU. + + Parameters + ---------- + points : torch.Tensor + DxN tensor of locations of the non-uniform points. + Points must lie in the range ``[-3pi, 3pi]``. + values : torch.Tensor + Length N complex-valued tensor of values at the non-uniform points + output_shape : int | tuple(int, ...) + Requested output shape of Fourier modes. Must be a tuple of length D or + an integer (1D only). + **finufftkwargs : int | float + Additional keyword arguments are forwarded to the underlying + FINUFFT functions. A few notable options are + + - ``eps``: precision requested (default: ``1e-6``) + - ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``) + - ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``) + + Returns + ------- + torch.Tensor + Tensor with shape ``output_shape`` containing the Fourier + transform of the values. + """ + res: torch.Tensor = FinufftType1.apply(points, values, output_shape, finufftkwargs) + return res + + +def finufft_type2( + points: torch.Tensor, + targets: torch.Tensor, + **finufftkwargs: Union[int, float], +) -> torch.Tensor: + """ + Evaluates the Type 2 (uniform-to-nonuniform) NUFFT on the inputs. + + This is a wrapper around :func:`finufft.nufft1d2`, :func:`finufft.nufft2d2`, and + :func:`finufft.nufft3d2` on CPU, and :func:`cufinufft.nufft1d2`, + :func:`cufinufft.nufft2d2`, and :func:`cufinufft.nufft3d2` on GPU. + + Parameters + ---------- + points : torch.Tensor + DxN tensor of locations of the non-uniform points. + Points must lie in the range ``[-3pi, 3pi]``. + targets : torch.Tensor + D-dimensional complex-valued tensor of Fourier modes to evaluate at the points + **finufftkwargs : int | float + Additional keyword arguments are forwarded to the underlying + FINUFFT functions. A few notable options are + + - ``eps``: precision requested (default: ``1e-6``) + - ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``) + - ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``) + + Returns + ------- + torch.Tensor + A DxN tensor of values at the non-uniform points. + """ + res: torch.Tensor = FinufftType2.apply(points, targets, finufftkwargs) + return res diff --git a/tests/test_1d/test_backward_1d.py b/tests/test_1d/test_backward_1d.py index 9753ad3..7e6da9c 100644 --- a/tests/test_1d/test_backward_1d.py +++ b/tests/test_1d/test_backward_1d.py @@ -43,11 +43,12 @@ def check_t1_backward( inputs = (points, values) def func(points, values): - return pytorch_finufft.functional.finufft_type1.apply( + return pytorch_finufft.functional.finufft_type1( points, values, (N + modifier,), - dict(modeord=int(not fftshift), isign=isign), + modeord=int(not fftshift), + isign=isign, ) assert gradcheck(func, inputs, eps=1e-8, atol=1e-5 * N) @@ -110,10 +111,11 @@ def check_t2_backward( inputs = (points, targets) def func(points, targets): - return pytorch_finufft.functional.finufft_type2.apply( + return pytorch_finufft.functional.finufft_type2( points, targets, - dict(modeord=int(not fftshift), isign=isign), + modeord=int(not fftshift), + isign=isign, ) assert gradcheck(func, inputs, eps=1e-8, atol=1.5e-3 * N) diff --git a/tests/test_1d/test_forward_1d.py b/tests/test_1d/test_forward_1d.py index a57e537..aa2bc19 100644 --- a/tests/test_1d/test_forward_1d.py +++ b/tests/test_1d/test_forward_1d.py @@ -36,7 +36,7 @@ def check_t1_forward(N: int, device: str) -> None: print("shape of points is " + str(points.shape)) print("shape of values is " + str(values.shape)) - finufft_out = pytorch_finufft.functional.finufft_type1.apply( + finufft_out = pytorch_finufft.functional.finufft_type1( points, values, (N,), @@ -89,7 +89,7 @@ def test_t2_forward_CPU(N: int) -> None: print("shape of points is " + str(points.shape)) print("shape of targets is " + str(targets.shape)) - finufft_out = pytorch_finufft.functional.finufft_type2.apply( + finufft_out = pytorch_finufft.functional.finufft_type2( points, targets, ) diff --git a/tests/test_2d/test_backward_2d.py b/tests/test_2d/test_backward_2d.py index 687ce67..d47b285 100644 --- a/tests/test_2d/test_backward_2d.py +++ b/tests/test_2d/test_backward_2d.py @@ -54,11 +54,8 @@ def check_t1_backward( inputs = (points, values) def func(points, values): - return pytorch_finufft.functional.finufft_type1.apply( - points, - values, - (N, N + modifier), - dict(modeord=int(not fftshift), isign=isign), + return pytorch_finufft.functional.finufft_type1( + points, values, (N, N + modifier), modeord=int(not fftshift), isign=isign ) assert gradcheck(func, inputs, atol=1e-5 * N) @@ -121,10 +118,11 @@ def check_t2_backward( inputs = (points, targets) def func(points, targets): - return pytorch_finufft.functional.finufft_type2.apply( + return pytorch_finufft.functional.finufft_type2( points, targets, - dict(modeord=int(not fftshift), isign=isign), + modeord=int(not fftshift), + isign=isign, ) assert gradcheck(func, inputs, atol=1e-5 * N) diff --git a/tests/test_2d/test_forward_2d.py b/tests/test_2d/test_forward_2d.py index c8f7640..1107f88 100644 --- a/tests/test_2d/test_forward_2d.py +++ b/tests/test_2d/test_forward_2d.py @@ -35,7 +35,7 @@ def check_t1_forward(N: int, device: str) -> None: print("shape of points is " + str(points.shape)) print("shape of values is " + str(values.shape)) - finufft_out = pytorch_finufft.functional.finufft_type1.apply( + finufft_out = pytorch_finufft.functional.finufft_type1( points, values, (N, N), @@ -79,10 +79,10 @@ def test_t2_forward_CPU(N: int, fftshift: bool) -> None: print("shape of points is " + str(points.shape)) print("shape of targets is " + str(targets.shape)) - finufft_out = pytorch_finufft.functional.finufft_type2.apply( + finufft_out = pytorch_finufft.functional.finufft_type2( points, targets, - {"modeord": int(not fftshift)}, + modeord=int(not fftshift), ) if fftshift: diff --git a/tests/test_3d/test_backward_3d.py b/tests/test_3d/test_backward_3d.py index 3eb324d..acc1088 100644 --- a/tests/test_3d/test_backward_3d.py +++ b/tests/test_3d/test_backward_3d.py @@ -46,11 +46,12 @@ def check_t1_backward( inputs = (points, values) def func(points, values): - return pytorch_finufft.functional.finufft_type1.apply( + return pytorch_finufft.functional.finufft_type1( points, values, (N, N + modifier, N + 2 * modifier), - dict(modeord=int(not fftshift), isign=isign), + modeord=int(not fftshift), + isign=isign, ) assert gradcheck(func, inputs, eps=1e-8, atol=1e-5 * N) @@ -118,10 +119,8 @@ def check_t2_backward( inputs = (points, targets) def func(points, targets): - return pytorch_finufft.functional.finufft_type2.apply( - points, - targets, - dict(modeord=int(not fftshift), isign=isign), + return pytorch_finufft.functional.finufft_type2( + points, targets, modeord=int(not fftshift), isign=isign ) assert gradcheck(func, inputs, atol=1e-5 * N) diff --git a/tests/test_3d/test_forward_3d.py b/tests/test_3d/test_forward_3d.py index 6ac24ae..d663075 100644 --- a/tests/test_3d/test_forward_3d.py +++ b/tests/test_3d/test_forward_3d.py @@ -34,7 +34,7 @@ def check_t1_forward(N: int, device: str) -> None: print("shape of points is " + str(points.shape)) print("shape of values is " + str(values.shape)) - finufft_out = pytorch_finufft.functional.finufft_type1.apply( + finufft_out = pytorch_finufft.functional.finufft_type1( points, values, (N, N, N), @@ -77,7 +77,7 @@ def test_t2_forward_CPU(N: int) -> None: print("shape of points is " + str(points.shape)) print("shape of targets is " + str(targets.shape)) - finufft_out = pytorch_finufft.functional.finufft_type2.apply( + finufft_out = pytorch_finufft.functional.finufft_type2( points, targets, ) diff --git a/tests/test_errors.py b/tests/test_errors.py index 93421fb..4ab3d72 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -15,7 +15,7 @@ def test_t1_mismatch_device_cuda_cpu() -> None: values = torch.randn(10, dtype=torch.complex128).to("cuda:0") with pytest.raises(ValueError, match="Some tensors are not on the same device"): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10)) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="require multiple GPUs") @@ -24,7 +24,7 @@ def test_t1_mismatch_cuda_index() -> None: values = torch.randn(10, dtype=torch.complex128).to("cuda:1") with pytest.raises(ValueError, match="Some tensors are not on the same device"): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10)) def test_t2_mismatch_device_cuda_cpu() -> None: @@ -33,7 +33,7 @@ def test_t2_mismatch_device_cuda_cpu() -> None: targets = torch.randn(*g[0].shape, dtype=torch.complex128).to("cuda:0") with pytest.raises(ValueError, match="Some tensors are not on the same device"): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="require multiple GPUs") @@ -43,7 +43,7 @@ def test_t2_mismatch_cuda_index() -> None: targets = torch.randn(*g[0].shape, dtype=torch.complex128).to("cuda:1") with pytest.raises(ValueError, match="Some tensors are not on the same device"): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) # dtypes @@ -57,7 +57,7 @@ def test_t1_non_complex_values() -> None: TypeError, match="Values must have a dtype of torch.complex64 or torch.complex128", ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10)) def test_t1_half_complex_values() -> None: @@ -71,7 +71,7 @@ def test_t1_half_complex_values() -> None: TypeError, match="Values must have a dtype of torch.complex64 or torch.complex128", ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10)) def test_t1_non_real_points() -> None: @@ -83,7 +83,7 @@ def test_t1_non_real_points() -> None: match="Points must have a dtype of torch.float64 as values has " "a dtype of torch.complex128", ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10)) def test_t1_mismatch_precision() -> None: @@ -95,7 +95,7 @@ def test_t1_mismatch_precision() -> None: match="Points must have a dtype of torch.float64 as values has " "a dtype of torch.complex128", ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10)) points = torch.rand((2, 10), dtype=torch.float64) values = torch.randn(10, dtype=torch.complex64) @@ -105,7 +105,7 @@ def test_t1_mismatch_precision() -> None: match="Points must have a dtype of torch.float32 as values has " "a dtype of torch.complex64", ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10)) def test_t2_non_complex_targets() -> None: @@ -117,7 +117,7 @@ def test_t2_non_complex_targets() -> None: TypeError, match="Targets must have a dtype of torch.complex64 or torch.complex128", ): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) def test_t2_half_complex_targets() -> None: @@ -132,7 +132,7 @@ def test_t2_half_complex_targets() -> None: TypeError, match="Targets must have a dtype of torch.complex64 or torch.complex128", ): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) def test_t2_non_real_points() -> None: @@ -145,7 +145,7 @@ def test_t2_non_real_points() -> None: match="Points must have a dtype of torch.float64 as targets has " "a dtype of torch.complex128", ): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) def test_t2_mismatch_precision() -> None: @@ -158,7 +158,7 @@ def test_t2_mismatch_precision() -> None: match="Points must have a dtype of torch.float64 as targets has " "a dtype of torch.complex128", ): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) points = points.to(torch.float64) targets = targets.to(torch.complex64) @@ -168,7 +168,7 @@ def test_t2_mismatch_precision() -> None: match="Points must have a dtype of torch.float32 as targets has " "a dtype of torch.complex64", ): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) # sizes @@ -181,14 +181,14 @@ def test_t1_wrong_length() -> None: with pytest.raises( ValueError, match="The same number of points and values must be supplied" ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10,)) + pytorch_finufft.functional.finufft_type1(points, values, (10,)) points = torch.rand((3, 10), dtype=torch.float64) with pytest.raises( ValueError, match="The same number of points and values must be supplied" ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10,)) + pytorch_finufft.functional.finufft_type1(points, values, (10,)) def test_t1_points_4d() -> None: @@ -196,7 +196,7 @@ def test_t1_points_4d() -> None: values = torch.randn(10, dtype=torch.complex128) with pytest.raises(ValueError, match="Points can be at most 3d, got"): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10)) def test_t1_too_many_points_dims() -> None: @@ -204,7 +204,7 @@ def test_t1_too_many_points_dims() -> None: values = torch.randn(10, dtype=torch.complex128) with pytest.raises(ValueError, match="The points tensor must be 1d or 2d"): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10)) def test_t1_wrong_output_dims() -> None: @@ -214,17 +214,17 @@ def test_t1_wrong_output_dims() -> None: with pytest.raises( ValueError, match="output_shape must be of length 2 for 2d NUFFT" ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, 10, 10)) + pytorch_finufft.functional.finufft_type1(points, values, (10, 10, 10)) with pytest.raises( ValueError, match="output_shape must be of length 2 for 2d NUFFT" ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10,)) + pytorch_finufft.functional.finufft_type1(points, values, (10,)) with pytest.raises( ValueError, match="output_shape must be a tuple of length 2 for 2d NUFFT" ): - pytorch_finufft.functional.finufft_type1.apply(points, values, 10) + pytorch_finufft.functional.finufft_type1(points, values, 10) def test_t1_negative_output_dims() -> None: @@ -234,13 +234,13 @@ def test_t1_negative_output_dims() -> None: with pytest.raises( ValueError, match="Got output_shape that was not positive integer" ): - pytorch_finufft.functional.finufft_type1.apply(points, values, 0) + pytorch_finufft.functional.finufft_type1(points, values, 0) points = torch.rand((2, 10), dtype=torch.float64) with pytest.raises( ValueError, match="Got output_shape that was not positive integer" ): - pytorch_finufft.functional.finufft_type1.apply(points, values, (10, -2)) + pytorch_finufft.functional.finufft_type1(points, values, (10, -2)) def test_t2_points_4d() -> None: @@ -249,7 +249,7 @@ def test_t2_points_4d() -> None: targets = torch.randn(*g[0].shape, dtype=torch.complex128) with pytest.raises(ValueError, match="Points can be at most 3d, got"): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) def test_t2_too_many_points_dims() -> None: @@ -258,7 +258,7 @@ def test_t2_too_many_points_dims() -> None: targets = torch.randn(*g[0].shape, dtype=torch.complex128) with pytest.raises(ValueError, match="The points tensor must be 1d or 2d"): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) def test_t2_mismatch_dims() -> None: @@ -269,7 +269,7 @@ def test_t2_mismatch_dims() -> None: with pytest.raises( ValueError, match="For type 2 3d FINUFFT, targets must be a 3d tensor" ): - pytorch_finufft.functional.finufft_type2.apply(points, targets) + pytorch_finufft.functional.finufft_type2(points, targets) # dependencies @@ -281,11 +281,11 @@ def test_finufft_not_installed(): values = torch.randn(10, dtype=torch.complex128).to("cuda") with pytest.raises(RuntimeError, match="cufinufft failed to import"): - pytorch_finufft.functional.finufft_type1.apply(points, values, 10) + pytorch_finufft.functional.finufft_type1(points, values, 10) elif not pytorch_finufft.functional.FINUFFT_AVAIL: points = torch.rand(10, dtype=torch.float64).to("cpu") values = torch.randn(10, dtype=torch.complex128).to("cpu") with pytest.raises(RuntimeError, match="finufft failed to import"): - pytorch_finufft.functional.finufft_type1.apply(points, values, 10) + pytorch_finufft.functional.finufft_type1(points, values, 10)