diff --git a/.buildinfo b/.buildinfo new file mode 100644 index 0000000..2ca8d08 --- /dev/null +++ b/.buildinfo @@ -0,0 +1,4 @@ +# Sphinx build info version 1 +# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. +config: 8f9da24b5a9f7f395c81a6ce297fc0c7 +tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000..e69de29 diff --git a/_downloads/2db4d350e416e3adf54e890f00153125/convolution_2d.py b/_downloads/2db4d350e416e3adf54e890f00153125/convolution_2d.py new file mode 100644 index 0000000..d746c5b --- /dev/null +++ b/_downloads/2db4d350e416e3adf54e890f00153125/convolution_2d.py @@ -0,0 +1,235 @@ +""" +Convolution in 2D +================= +""" + + +####################################################################################### +# Import packages +# --------------- +# +# First, we import the packages we need for this example. + +import matplotlib.pyplot as plt +import numpy as np +import torch + +import pytorch_finufft + +####################################################################################### +# Let's create a Gaussian convolutional filter as a function of x,y + + +def gaussian_function(x, y, sigma=1): + return np.exp(-(x**2 + y**2) / (2 * sigma**2)) + + +####################################################################################### +# Let's visualize this filter kernel. We will be using it to convolve with points +# living on the $[0, 2*\pi] \times [0, 2*\pi]$ torus. So let's dimension it accordingly. + +shape = (128, 128) +sigma = 0.5 +x = np.linspace(-np.pi, np.pi, shape[0], endpoint=False) +y = np.linspace(-np.pi, np.pi, shape[1], endpoint=False) + +gaussian_kernel = gaussian_function(x[:, np.newaxis], y, sigma=sigma) + +fig, ax = plt.subplots() +_ = ax.imshow(gaussian_kernel) + +####################################################################################### +# In order for the kernel to not shift the signal, we need to place its mass at 0. +# To do this, we ifftshift the kernel + +shifted_gaussian_kernel = np.fft.ifftshift(gaussian_kernel) + +fig, ax = plt.subplots() +_ = ax.imshow(shifted_gaussian_kernel) + + +####################################################################################### +# Now let's create a point cloud on the torus that we can convolve with our filter + +N = 20 +points = np.random.rand(2, N) * 2 * np.pi + +fig, ax = plt.subplots() +ax.set_xlim(0, 2 * np.pi) +ax.set_ylim(0, 2 * np.pi) +ax.set_aspect("equal") +_ = ax.scatter(points[0], points[1], s=1) + + +####################################################################################### +# Now we can convolve the point cloud with the filter kernel. +# To do this, we Fourier-transform both the point cloud and the filter kernel, +# multiply them together, and then inverse Fourier-transform the result. +# First we need to convert all data to torch tensors + +fourier_shifted_gaussian_kernel = torch.fft.fft2( + torch.from_numpy(shifted_gaussian_kernel) +) +fourier_points = pytorch_finufft.functional.finufft_type1( + torch.from_numpy(points), torch.ones(points.shape[1], dtype=torch.complex128), shape +) + +fig, axs = plt.subplots(1, 3) +axs[0].imshow(fourier_shifted_gaussian_kernel.real) +axs[1].imshow(fourier_points.real, vmin=-10, vmax=10) +_ = axs[2].imshow( + ( + fourier_points + * fourier_shifted_gaussian_kernel + / fourier_shifted_gaussian_kernel[0, 0] + ).real, + vmin=-10, + vmax=10, +) + +####################################################################################### +# We now have two possibilities: Invert the Fourier transform on a grid, or on a point +# cloud. We'll first invert the Fourier transform on a grid in order to be able to +# visualize the effect of the convolution. + +convolved_points = torch.fft.ifft2(fourier_points * fourier_shifted_gaussian_kernel) + +fig, ax = plt.subplots() +ax.imshow(convolved_points.real) +_ = ax.scatter( + points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=2, c="r" +) + +####################################################################################### +# We see that the convolution has smeared out the point cloud. +# After a small coordinate change, we can also plot the original points +# on the same plot as the convolved points. + + +####################################################################################### +# Next, we invert the Fourier transform on the same points as +# our original point cloud. We will then compare this to direct evaluation +# of the kernel on all pairwise difference vectors between the points. + +convolved_at_points = pytorch_finufft.functional.finufft_type2( + torch.from_numpy(points), + fourier_points * fourier_shifted_gaussian_kernel, + isign=1, +).real / np.prod(shape) + +fig, ax = plt.subplots() +ax.imshow(convolved_points.real) +_ = ax.scatter( + points[1] / 2 / np.pi * shape[0], + points[0] / 2 / np.pi * shape[1], + s=10 * convolved_at_points, + c="r", +) + +####################################################################################### +# To compute the convolution directly, we need to evaluate the kernel on all pairwise +# difference vectors between the points. Note the points that will be off the diagonal. +# These will be due to the periodic boundary conditions of the convolution. + +pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis] +kernel_diff_evals = gaussian_function(*pairwise_diffs, sigma=sigma) +convolved_by_hand = kernel_diff_evals.sum(1) + +fig, ax = plt.subplots() +ax.plot(convolved_at_points.numpy(), convolved_by_hand, ".") +ax.plot([1, 3], [1, 3]) + +relative_difference = torch.norm( + convolved_at_points - convolved_by_hand +) / np.linalg.norm(convolved_by_hand) +print( + "Relative difference between fourier convolution and direct convolution " + f"{relative_difference}" +) + + +####################################################################################### +# Now let's see if we can learn the convolution kernel from the input and output point +# clouds. To this end, let's first make a pytorch object that can compute a kernel +# convolution on a point cloud. + + +class FourierPointConvolution(torch.nn.Module): + def __init__(self, fourier_kernel_shape): + super().__init__() + self.fourier_kernel_shape = fourier_kernel_shape + + self.build() + + def build(self): + self.register_parameter( + "fourier_kernel", + torch.nn.Parameter( + torch.randn(self.fourier_kernel_shape, dtype=torch.complex128) + ), + ) + # ^ think about whether we need to scale this init in some better way + + def forward(self, points, values): + fourier_transformed_input = pytorch_finufft.functional.finufft_type1( + points, values, self.fourier_kernel_shape + ) + fourier_convolved = fourier_transformed_input * self.fourier_kernel + convolved = pytorch_finufft.functional.finufft_type2( + points, + fourier_convolved, + isign=1, + ).real / np.prod(self.fourier_kernel_shape) + return convolved + + +####################################################################################### +# Now we can use this object in a pytorch training loop to learn the kernel from the +# input and output point clouds. We will use the mean squared error as a loss function. + +fourier_point_convolution = FourierPointConvolution(shape) +optimizer = torch.optim.AdamW( + fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001 +) + +ones = torch.ones(points.shape[1], dtype=torch.complex128) + +losses = [] +for i in range(10000): + # Make new set of points and compute forward model + points = np.random.rand(2, N) * 2 * np.pi + torch_points = torch.from_numpy(points) + fourier_points = pytorch_finufft.functional.finufft_type1( + torch.from_numpy(points), ones, shape + ) + convolved_at_points = pytorch_finufft.functional.finufft_type2( + torch.from_numpy(points), + fourier_points * fourier_shifted_gaussian_kernel, + isign=1, + ).real / np.prod(shape) + + # Learning step + optimizer.zero_grad() + convolved = fourier_point_convolution(torch_points, ones) + loss = torch.nn.functional.mse_loss(convolved, convolved_at_points) + losses.append(loss.item()) + loss.backward() + optimizer.step() + + if i % 100 == 0: + print(f"Iteration {i:05d}, Loss: {loss.item():1.4f}") + + +fig, ax = plt.subplots() +ax.plot(losses) +ax.set_ylabel("Loss") +ax.set_xlabel("Iteration") +ax.set_yscale("log") + +fig, ax = plt.subplots() +im = ax.imshow( + torch.real(torch.fft.fftshift(fourier_point_convolution.fourier_kernel.data))[ + 48:80, 48:80 + ] +) +_ = fig.colorbar(im, ax=ax) diff --git a/_downloads/55cd874c9c3339d444a806fb4ce70cc6/convolution_2d.ipynb b/_downloads/55cd874c9c3339d444a806fb4ce70cc6/convolution_2d.ipynb new file mode 100644 index 0000000..f43d87b --- /dev/null +++ b/_downloads/55cd874c9c3339d444a806fb4ce70cc6/convolution_2d.ipynb @@ -0,0 +1,237 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Convolution in 2D\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import packages\n\nFirst, we import the packages we need for this example.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\nimport numpy as np\nimport torch\n\nimport pytorch_finufft" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a Gaussian convolutional filter as a function of x,y\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def gaussian_function(x, y, sigma=1):\n return np.exp(-(x**2 + y**2) / (2 * sigma**2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's visualize this filter kernel. We will be using it to convolve with points\nliving on the $[0, 2*\\pi] \\times [0, 2*\\pi]$ torus. So let's dimension it accordingly.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "shape = (128, 128)\nsigma = 0.5\nx = np.linspace(-np.pi, np.pi, shape[0], endpoint=False)\ny = np.linspace(-np.pi, np.pi, shape[1], endpoint=False)\n\ngaussian_kernel = gaussian_function(x[:, np.newaxis], y, sigma=sigma)\n\nfig, ax = plt.subplots()\n_ = ax.imshow(gaussian_kernel)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order for the kernel to not shift the signal, we need to place its mass at 0.\nTo do this, we ifftshift the kernel\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "shifted_gaussian_kernel = np.fft.ifftshift(gaussian_kernel)\n\nfig, ax = plt.subplots()\n_ = ax.imshow(shifted_gaussian_kernel)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's create a point cloud on the torus that we can convolve with our filter\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "N = 20\npoints = np.random.rand(2, N) * 2 * np.pi\n\nfig, ax = plt.subplots()\nax.set_xlim(0, 2 * np.pi)\nax.set_ylim(0, 2 * np.pi)\nax.set_aspect(\"equal\")\n_ = ax.scatter(points[0], points[1], s=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can convolve the point cloud with the filter kernel.\nTo do this, we Fourier-transform both the point cloud and the filter kernel,\nmultiply them together, and then inverse Fourier-transform the result.\nFirst we need to convert all data to torch tensors\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fourier_shifted_gaussian_kernel = torch.fft.fft2(\n torch.from_numpy(shifted_gaussian_kernel)\n)\nfourier_points = pytorch_finufft.functional.finufft_type1(\n torch.from_numpy(points), torch.ones(points.shape[1], dtype=torch.complex128), shape\n)\n\nfig, axs = plt.subplots(1, 3)\naxs[0].imshow(fourier_shifted_gaussian_kernel.real)\naxs[1].imshow(fourier_points.real, vmin=-10, vmax=10)\n_ = axs[2].imshow(\n (\n fourier_points\n * fourier_shifted_gaussian_kernel\n / fourier_shifted_gaussian_kernel[0, 0]\n ).real,\n vmin=-10,\n vmax=10,\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now have two possibilities: Invert the Fourier transform on a grid, or on a point\ncloud. We'll first invert the Fourier transform on a grid in order to be able to\nvisualize the effect of the convolution.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "convolved_points = torch.fft.ifft2(fourier_points * fourier_shifted_gaussian_kernel)\n\nfig, ax = plt.subplots()\nax.imshow(convolved_points.real)\n_ = ax.scatter(\n points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=2, c=\"r\"\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that the convolution has smeared out the point cloud.\nAfter a small coordinate change, we can also plot the original points\non the same plot as the convolved points.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we invert the Fourier transform on the same points as\nour original point cloud. We will then compare this to direct evaluation\nof the kernel on all pairwise difference vectors between the points.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "convolved_at_points = pytorch_finufft.functional.finufft_type2(\n torch.from_numpy(points),\n fourier_points * fourier_shifted_gaussian_kernel,\n isign=1,\n).real / np.prod(shape)\n\nfig, ax = plt.subplots()\nax.imshow(convolved_points.real)\n_ = ax.scatter(\n points[1] / 2 / np.pi * shape[0],\n points[0] / 2 / np.pi * shape[1],\n s=10 * convolved_at_points,\n c=\"r\",\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To compute the convolution directly, we need to evaluate the kernel on all pairwise\ndifference vectors between the points. Note the points that will be off the diagonal.\nThese will be due to the periodic boundary conditions of the convolution.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis]\nkernel_diff_evals = gaussian_function(*pairwise_diffs, sigma=sigma)\nconvolved_by_hand = kernel_diff_evals.sum(1)\n\nfig, ax = plt.subplots()\nax.plot(convolved_at_points.numpy(), convolved_by_hand, \".\")\nax.plot([1, 3], [1, 3])\n\nrelative_difference = torch.norm(\n convolved_at_points - convolved_by_hand\n) / np.linalg.norm(convolved_by_hand)\nprint(\n \"Relative difference between fourier convolution and direct convolution \"\n f\"{relative_difference}\"\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's see if we can learn the convolution kernel from the input and output point\nclouds. To this end, let's first make a pytorch object that can compute a kernel\nconvolution on a point cloud.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class FourierPointConvolution(torch.nn.Module):\n def __init__(self, fourier_kernel_shape):\n super().__init__()\n self.fourier_kernel_shape = fourier_kernel_shape\n\n self.build()\n\n def build(self):\n self.register_parameter(\n \"fourier_kernel\",\n torch.nn.Parameter(\n torch.randn(self.fourier_kernel_shape, dtype=torch.complex128)\n ),\n )\n # ^ think about whether we need to scale this init in some better way\n\n def forward(self, points, values):\n fourier_transformed_input = pytorch_finufft.functional.finufft_type1(\n points, values, self.fourier_kernel_shape\n )\n fourier_convolved = fourier_transformed_input * self.fourier_kernel\n convolved = pytorch_finufft.functional.finufft_type2(\n points,\n fourier_convolved,\n isign=1,\n ).real / np.prod(self.fourier_kernel_shape)\n return convolved" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can use this object in a pytorch training loop to learn the kernel from the\ninput and output point clouds. We will use the mean squared error as a loss function.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fourier_point_convolution = FourierPointConvolution(shape)\noptimizer = torch.optim.AdamW(\n fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001\n)\n\nones = torch.ones(points.shape[1], dtype=torch.complex128)\n\nlosses = []\nfor i in range(10000):\n # Make new set of points and compute forward model\n points = np.random.rand(2, N) * 2 * np.pi\n torch_points = torch.from_numpy(points)\n fourier_points = pytorch_finufft.functional.finufft_type1(\n torch.from_numpy(points), ones, shape\n )\n convolved_at_points = pytorch_finufft.functional.finufft_type2(\n torch.from_numpy(points),\n fourier_points * fourier_shifted_gaussian_kernel,\n isign=1,\n ).real / np.prod(shape)\n\n # Learning step\n optimizer.zero_grad()\n convolved = fourier_point_convolution(torch_points, ones)\n loss = torch.nn.functional.mse_loss(convolved, convolved_at_points)\n losses.append(loss.item())\n loss.backward()\n optimizer.step()\n\n if i % 100 == 0:\n print(f\"Iteration {i:05d}, Loss: {loss.item():1.4f}\")\n\n\nfig, ax = plt.subplots()\nax.plot(losses)\nax.set_ylabel(\"Loss\")\nax.set_xlabel(\"Iteration\")\nax.set_yscale(\"log\")\n\nfig, ax = plt.subplots()\nim = ax.imshow(\n torch.real(torch.fft.fftshift(fourier_point_convolution.fourier_kernel.data))[\n 48:80, 48:80\n ]\n)\n_ = fig.colorbar(im, ax=ax)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/_images/sphx_glr_convolution_2d_001.png b/_images/sphx_glr_convolution_2d_001.png new file mode 100644 index 0000000..5dd9804 Binary files /dev/null and b/_images/sphx_glr_convolution_2d_001.png differ diff --git a/_images/sphx_glr_convolution_2d_002.png b/_images/sphx_glr_convolution_2d_002.png new file mode 100644 index 0000000..851878d Binary files /dev/null and b/_images/sphx_glr_convolution_2d_002.png differ diff --git a/_images/sphx_glr_convolution_2d_003.png b/_images/sphx_glr_convolution_2d_003.png new file mode 100644 index 0000000..87d5bfe Binary files /dev/null and b/_images/sphx_glr_convolution_2d_003.png differ diff --git a/_images/sphx_glr_convolution_2d_004.png b/_images/sphx_glr_convolution_2d_004.png new file mode 100644 index 0000000..a2f88fa Binary files /dev/null and b/_images/sphx_glr_convolution_2d_004.png differ diff --git a/_images/sphx_glr_convolution_2d_005.png b/_images/sphx_glr_convolution_2d_005.png new file mode 100644 index 0000000..a09f227 Binary files /dev/null and b/_images/sphx_glr_convolution_2d_005.png differ diff --git a/_images/sphx_glr_convolution_2d_006.png b/_images/sphx_glr_convolution_2d_006.png new file mode 100644 index 0000000..f682f83 Binary files /dev/null and b/_images/sphx_glr_convolution_2d_006.png differ diff --git a/_images/sphx_glr_convolution_2d_007.png b/_images/sphx_glr_convolution_2d_007.png new file mode 100644 index 0000000..54ccd18 Binary files /dev/null and b/_images/sphx_glr_convolution_2d_007.png differ diff --git a/_images/sphx_glr_convolution_2d_008.png b/_images/sphx_glr_convolution_2d_008.png new file mode 100644 index 0000000..611730f Binary files /dev/null and b/_images/sphx_glr_convolution_2d_008.png differ diff --git a/_images/sphx_glr_convolution_2d_009.png b/_images/sphx_glr_convolution_2d_009.png new file mode 100644 index 0000000..bd9ec25 Binary files /dev/null and b/_images/sphx_glr_convolution_2d_009.png differ diff --git a/_images/sphx_glr_convolution_2d_thumb.png b/_images/sphx_glr_convolution_2d_thumb.png new file mode 100644 index 0000000..2969af0 Binary files /dev/null and b/_images/sphx_glr_convolution_2d_thumb.png differ diff --git a/_modules/index.html b/_modules/index.html new file mode 100644 index 0000000..3e4be73 --- /dev/null +++ b/_modules/index.html @@ -0,0 +1,426 @@ + + + + + + +
+ + +
+"""
+Implementations of the corresponding Autograd functions
+"""
+
+import warnings
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+
+try:
+ import finufft
+
+ FINUFFT_AVAIL = True
+except ImportError:
+ FINUFFT_AVAIL = False
+
+try:
+ import cufinufft
+
+ if cufinufft.__version__.startswith("1."):
+ warnings.warn("pytorch-finufft does not support cufinufft v1.x.x")
+ else:
+ CUFINUFFT_AVAIL = True
+except ImportError:
+ CUFINUFFT_AVAIL = False
+
+if not (FINUFFT_AVAIL or CUFINUFFT_AVAIL):
+ raise ImportError(
+ "No FINUFFT implementation available. "
+ "Install either finufft or cufinufft and ensure they are importable."
+ )
+
+import pytorch_finufft.checks as checks
+
+newaxis = None
+
+
+def get_nufft_func(
+ dim: int, nufft_type: int, device_type: str
+) -> Callable[..., torch.Tensor]:
+ if device_type == "cuda":
+ if not CUFINUFFT_AVAIL:
+ raise RuntimeError("CUDA device requested but cufinufft failed to import")
+ return getattr(cufinufft, f"nufft{dim}d{nufft_type}") # type: ignore
+
+ if not FINUFFT_AVAIL:
+ raise RuntimeError("CPU device requested but finufft failed to import")
+ # CPU needs extra work to go to/from torch and numpy
+ finufft_func = getattr(finufft, f"nufft{dim}d{nufft_type}")
+
+ def f(*args, **kwargs):
+ new_args = [arg for arg in args]
+ for i in range(len(new_args)):
+ if isinstance(new_args[i], torch.Tensor):
+ new_args[i] = new_args[i].data.numpy()
+
+ return torch.from_numpy(finufft_func(*new_args, **kwargs))
+
+ return f
+
+
+def coordinate_ramps(shape, device):
+ start_points = -(torch.tensor(shape, device=device) // 2)
+ end_points = start_points + torch.tensor(shape, device=device)
+ coord_ramps = torch.stack(
+ torch.meshgrid(
+ *(
+ torch.arange(start, end, device=device)
+ for start, end in zip(start_points, end_points)
+ ),
+ indexing="ij",
+ )
+ )
+
+ return coord_ramps
+
+
+class FinufftType1(torch.autograd.Function):
+ @staticmethod
+ def forward( # type: ignore[override]
+ ctx: Any,
+ points: torch.Tensor,
+ values: torch.Tensor,
+ output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
+ finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
+ ) -> torch.Tensor:
+ """
+ Evaluates the Type 1 NUFFT on the inputs.
+ """
+
+ checks.check_devices(values, points)
+ checks.check_dtypes(values, points, "Values")
+ checks.check_sizes_t1(values, points)
+ points = torch.atleast_2d(points)
+ ndim = points.shape[0]
+ checks.check_output_shape(ndim, output_shape)
+
+ ctx.save_for_backward(points, values)
+
+ if finufftkwargs is None:
+ finufftkwargs = {}
+ else: # copy to avoid mutating caller's dictionary
+ finufftkwargs = {k: v for k, v in finufftkwargs.items()}
+ ctx.isign = finufftkwargs.pop("isign", -1) # note: FINUFFT default is 1
+ ctx.mode_ordering = finufftkwargs.pop(
+ "modeord", 1
+ ) # note: FINUFFT default is 0
+ ctx.finufftkwargs = finufftkwargs
+
+ nufft_func = get_nufft_func(ndim, 1, points.device.type)
+ finufft_out = nufft_func(
+ *points, values, output_shape, isign=ctx.isign, **ctx.finufftkwargs
+ )
+
+ # because modeord is missing from cufinufft
+ if ctx.mode_ordering:
+ finufft_out = torch.fft.ifftshift(finufft_out)
+
+ return finufft_out
+
+ @staticmethod
+ def backward( # type: ignore[override]
+ ctx: Any, grad_output: torch.Tensor
+ ) -> Tuple[Union[torch.Tensor, None], ...]:
+ """
+ Implements derivatives wrt. each argument in the forward method.
+
+ Parameters
+ ----------
+ ctx : Any
+ Pytorch context object.
+ grad_output : torch.Tensor
+ Backpass gradient output
+
+ Returns
+ -------
+ Tuple[Union[torch.Tensor, None], ...]
+ A tuple of derivatives wrt. each argument in the forward method
+ """
+ _i_sign = -1 * ctx.isign
+ _mode_ordering = ctx.mode_ordering
+ finufftkwargs = ctx.finufftkwargs
+
+ points, values = ctx.saved_tensors
+ device = points.device
+
+ grads_points = None
+ grad_values = None
+
+ ndim = points.shape[0]
+
+ nufft_func = get_nufft_func(ndim, 2, device.type)
+
+ if any(ctx.needs_input_grad) and _mode_ordering:
+ grad_output = torch.fft.fftshift(grad_output)
+
+ if ctx.needs_input_grad[0]:
+ # wrt points
+ coord_ramps = coordinate_ramps(grad_output.shape, device)
+
+ # we can't batch in 1d case so we squeeze and fix up the ouput later
+ ramped_grad_output = (
+ coord_ramps * grad_output[newaxis] * 1j * _i_sign
+ ).squeeze()
+ backprop_ramp = nufft_func(
+ *points, ramped_grad_output, isign=_i_sign, **finufftkwargs
+ )
+ grads_points = torch.atleast_2d((backprop_ramp.conj() * values).real)
+
+ if ctx.needs_input_grad[1]:
+ grad_values = nufft_func(
+ *points,
+ grad_output,
+ isign=_i_sign,
+ **finufftkwargs,
+ )
+
+ return (
+ grads_points,
+ grad_values,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class FinufftType2(torch.autograd.Function):
+ """
+ FINUFFT 2D problem type 2
+ """
+
+ @staticmethod
+ def forward( # type: ignore[override]
+ ctx: Any,
+ points: torch.Tensor,
+ targets: torch.Tensor,
+ finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
+ ) -> torch.Tensor:
+ """
+ Evaluates the Type 2 NUFFT on the inputs.
+
+ NOTE: By default, the ordering is set to match that of Pytorch,
+ Numpy, and Scipy's FFT APIs. To match the mode ordering
+ native to FINUFFT, add {'modeord': 0} to finufftkwargs.
+
+ Parameters
+ ----------
+ ctx : Any
+ Pytorch context objecy
+ points : torch.Tensor, shape=(ndim, num_points)
+ The non-uniform points x
+ targets : torch.Tensor
+ The values on the input grid
+ out : Optional[torch.Tensor], optional
+ Array to take the result in-place, by default None
+ finufftkwargs : Dict[str, Union[int, float]]
+ Additional arguments will be passed into FINUFFT. See
+ https://finufft.readthedocs.io/en/latest/python.html.
+
+ Returns
+ -------
+ torch.Tensor
+ The Fourier transform of the targets grid evaluated at the points `points`
+
+ Raises
+ ------
+
+ """
+ checks.check_devices(targets, points)
+ checks.check_dtypes(targets, points, "Targets")
+ checks.check_sizes_t2(targets, points)
+
+ if finufftkwargs is None:
+ finufftkwargs = dict()
+ finufftkwargs = {k: v for k, v in finufftkwargs.items()}
+ _mode_ordering = finufftkwargs.pop(
+ "modeord", 1
+ ) # not finufft default, but corresponds to pytorch default
+ _i_sign = finufftkwargs.pop(
+ "isign", -1
+ ) # isign=-1 is finufft default for type 2
+
+ points = torch.atleast_2d(points)
+ if _mode_ordering:
+ targets = torch.fft.fftshift(targets)
+
+ ctx.save_for_backward(points, targets)
+
+ ctx.isign = _i_sign
+ ctx.mode_ordering = _mode_ordering
+ ctx.finufftkwargs = finufftkwargs
+
+ nufft_func = get_nufft_func(points.shape[0], 2, points.device.type)
+
+ finufft_out = nufft_func(
+ *points,
+ targets,
+ isign=_i_sign,
+ **finufftkwargs,
+ )
+
+ return finufft_out
+
+ @staticmethod
+ def backward( # type: ignore[override]
+ ctx: Any, grad_output: torch.Tensor
+ ) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], None, None, None,]:
+ """
+ Implements derivatives wrt. each argument in the forward method.
+
+ Parameters
+ ----------
+ ctx : Any
+ Pytorch context object
+ grad_output : torch.Tensor
+ Backpass gradient output.
+
+ Returns
+ -------
+ Tuple[ Union[torch.Tensor, None], ...]
+ A tuple of derivatives wrt. each argument in the forward method
+ """
+ _i_sign = ctx.isign
+ _mode_ordering = ctx.mode_ordering
+ finufftkwargs = ctx.finufftkwargs
+
+ points, targets = ctx.saved_tensors
+ device = points.device
+
+ grad_points = grad_targets = None
+ ndim = points.shape[0]
+
+ if ctx.needs_input_grad[0]:
+ coord_ramps = coordinate_ramps(targets.shape, device=device)
+ ramped_targets = coord_ramps * targets[newaxis] * 1j * _i_sign
+ nufft_func = get_nufft_func(ndim, 2, points.device.type)
+
+ grad_points = nufft_func(
+ *points,
+ ramped_targets.squeeze(),
+ isign=_i_sign,
+ **finufftkwargs,
+ ).conj() # Why can't this be replaced with a flipped isign
+
+ grad_points = grad_points * grad_output
+ grad_points = torch.atleast_2d(grad_points.real)
+
+ if ctx.needs_input_grad[1]:
+ # wrt. targets
+ nufft_func = get_nufft_func(ndim, 1, points.device.type)
+
+ grad_targets = nufft_func(
+ *points,
+ grad_output,
+ targets.shape,
+ isign=-_i_sign,
+ **finufftkwargs,
+ )
+
+ if _mode_ordering:
+ grad_targets = torch.fft.ifftshift(grad_targets)
+
+ return (
+ grad_points,
+ grad_targets,
+ None,
+ None,
+ None,
+ )
+
+
+
+[docs]
+def finufft_type1(
+ points: torch.Tensor,
+ values: torch.Tensor,
+ output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
+ **finufftkwargs: Union[int, float],
+) -> torch.Tensor:
+ """
+ Evaluates the Type 1 (nonuniform-to-uniform) NUFFT on the inputs.
+
+ This is a wrapper around :func:`finufft.nufft1d1`, :func:`finufft.nufft2d1`, and
+ :func:`finufft.nufft3d1` on CPU, and :func:`cufinufft.nufft1d1`,
+ :func:`cufinufft.nufft2d1`, and :func:`cufinufft.nufft3d1` on GPU.
+
+ Parameters
+ ----------
+ points : torch.Tensor
+ DxN tensor of locations of the non-uniform points.
+ Points must lie in the range ``[-3pi, 3pi]``.
+ values : torch.Tensor
+ Length N complex-valued tensor of values at the non-uniform points
+ output_shape : int | tuple(int, ...)
+ Requested output shape of Fourier modes. Must be a tuple of length D or
+ an integer (1D only).
+ **finufftkwargs : int | float
+ Additional keyword arguments are forwarded to the underlying
+ FINUFFT functions. A few notable options are
+
+ - ``eps``: precision requested (default: ``1e-6``)
+ - ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``)
+ - ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``)
+
+ Returns
+ -------
+ torch.Tensor
+ Tensor with shape ``output_shape`` containing the Fourier
+ transform of the values.
+ """
+ res: torch.Tensor = FinufftType1.apply(points, values, output_shape, finufftkwargs)
+ return res
+
+
+
+
+[docs]
+def finufft_type2(
+ points: torch.Tensor,
+ targets: torch.Tensor,
+ **finufftkwargs: Union[int, float],
+) -> torch.Tensor:
+ """
+ Evaluates the Type 2 (uniform-to-nonuniform) NUFFT on the inputs.
+
+ This is a wrapper around :func:`finufft.nufft1d2`, :func:`finufft.nufft2d2`, and
+ :func:`finufft.nufft3d2` on CPU, and :func:`cufinufft.nufft1d2`,
+ :func:`cufinufft.nufft2d2`, and :func:`cufinufft.nufft3d2` on GPU.
+
+ Parameters
+ ----------
+ points : torch.Tensor
+ DxN tensor of locations of the non-uniform points.
+ Points must lie in the range ``[-3pi, 3pi]``.
+ targets : torch.Tensor
+ D-dimensional complex-valued tensor of Fourier modes to evaluate at the points
+ **finufftkwargs : int | float
+ Additional keyword arguments are forwarded to the underlying
+ FINUFFT functions. A few notable options are
+
+ - ``eps``: precision requested (default: ``1e-6``)
+ - ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``)
+ - ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``)
+
+ Returns
+ -------
+ torch.Tensor
+ A DxN tensor of values at the non-uniform points.
+ """
+ res: torch.Tensor = FinufftType2.apply(points, targets, finufftkwargs)
+ return res
+
+