From 5f7167432b307c26d9fe9afb4dc3ef93ca3bb4e5 Mon Sep 17 00:00:00 2001
From: Michael Eickenberg
Date: Wed, 18 Oct 2023 11:27:50 -0400
Subject: [PATCH] COSM formatting

---
 examples/convolution_2d.py | 154 +++++++++++++++++++++++--------------
 1 file changed, 97 insertions(+), 57 deletions(-)

diff --git a/examples/convolution_2d.py b/examples/convolution_2d.py
index df7cd6b..fece9ad 100644
--- a/examples/convolution_2d.py
+++ b/examples/convolution_2d.py
@@ -1,29 +1,31 @@
-##############################################################
+#######################################################################################
 # Convolution in 2D example
 # =========================
-##############################################################
+#######################################################################################
 
 # Import packages
 # ---------------
 #
 # First, we import the packages we need for this example.
 
+import matplotlib.pyplot as plt
 import numpy as np
 import torch
+
 import pytorch_finufft
-import matplotlib.pyplot as plt
 
-##############################################################
+#######################################################################################
 # Let's create a Gaussian convolutional filter as a function of x,y
 
+
 def gaussian_function(x, y, sigma=1):
     return np.exp(-(x**2 + y**2) / (2 * sigma**2))
 
 
-##############################################################
-# Let's visualize this filter kernel. We will be using it to convolve with points living on
-# the [0, 2*pi] x [0, 2*pi] torus. So let's dimension it accordingly.
+#######################################################################################
+# Let's visualize this filter kernel. We will be using it to convolve with points
+# living on the [0, 2*pi] x [0, 2*pi] torus. So let's dimension it accordingly.
 
 shape = (128, 128)
 sigma = 0.5
@@ -35,7 +37,7 @@ def gaussian_function(x, y, sigma=1):
 fig, ax = plt.subplots()
 ax.imshow(gaussian_kernel)
 
-##############################################################
+#######################################################################################
 # In order for the kernel to not shift the signal, we need to place its mass at 0
 # To do this, we ifftshift the kernel
 
@@ -45,7 +47,7 @@ def gaussian_function(x, y, sigma=1):
 fig, ax = plt.subplots()
 ax.imshow(shifted_gaussian_kernel)
 
-##############################################################
+#######################################################################################
 # Now let's create a point cloud on the torus that we can convolve with our filter
 
 N = 20
@@ -54,56 +56,80 @@ def gaussian_function(x, y, sigma=1):
 fig, ax = plt.subplots()
 ax.set_xlim(0, 2 * np.pi)
 ax.set_ylim(0, 2 * np.pi)
-ax.set_aspect('equal')
+ax.set_aspect("equal")
 ax.scatter(points[0], points[1], s=1)
 
-##############################################################
+#######################################################################################
 # Now we can convolve the point cloud with the filter kernel.
 # To do this, we Fourier-transform both the point cloud and the filter kernel,
 # multiply them together, and then inverse Fourier-transform the result.
 # First we need to convert all data to torch tensors
-fourier_shifted_gaussian_kernel = torch.fft.fft2(torch.from_numpy(shifted_gaussian_kernel))
-fourier_points = pytorch_finufft.functional.finufft_type1.apply(torch.from_numpy(points), torch.ones(points.shape[1], dtype=torch.complex128), shape)
+fourier_shifted_gaussian_kernel = torch.fft.fft2(
+    torch.from_numpy(shifted_gaussian_kernel)
+)
+fourier_points = pytorch_finufft.functional.finufft_type1.apply(
+    torch.from_numpy(points), torch.ones(points.shape[1], dtype=torch.complex128), shape
+)
 
 fig, axs = plt.subplots(1, 3)
 axs[0].imshow(fourier_shifted_gaussian_kernel.real)
 axs[1].imshow(fourier_points.real, vmin=-10, vmax=10)
-axs[2].imshow((fourier_points * fourier_shifted_gaussian_kernel / fourier_shifted_gaussian_kernel[0, 0]).real, vmin=-10, vmax=10)
-
-
-##############################################################
-# We now have two possibilities: Invert the Fourier transform on a grid, or on a point cloud.
-# We'll first invert the Fourier transform on a grid in order to be able to visualize the effect of the convolution.
+axs[2].imshow(
+    (
+        fourier_points
+        * fourier_shifted_gaussian_kernel
+        / fourier_shifted_gaussian_kernel[0, 0]
+    ).real,
+    vmin=-10,
+    vmax=10,
+)
+
+
+#######################################################################################
+# We now have two possibilities: Invert the Fourier transform on a grid, or on a point
+# cloud. We'll first invert the Fourier transform on a grid in order to be able to
+# visualize the effect of the convolution.
 
 convolved_points = torch.fft.ifft2(fourier_points * fourier_shifted_gaussian_kernel)
 
 fig, ax = plt.subplots()
 ax.imshow(convolved_points.real)
-ax.scatter(points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=2, c='r')
+ax.scatter(
+    points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=2, c="r"
+)
 
-##############################################################
+#######################################################################################
 # We see that the convolution has smeared out the point cloud.
 # After a small coordinate change, we can also plot the original points
 # on the same plot as the convolved points.
 
-##############################################################
-# Next, we invert the Fourier transform on the same points as 
+#######################################################################################
+# Next, we invert the Fourier transform on the same points as
 # our original point cloud. We will then compare this to direct evaluation
 # of the kernel on all pairwise difference vectors between the points.
 
 convolved_at_points = pytorch_finufft.functional.finufft_type2.apply(
-    torch.from_numpy(points), fourier_points * fourier_shifted_gaussian_kernel,
-    None, {'isign': 1}
-    ).real / np.prod(shape)
+    torch.from_numpy(points),
+    fourier_points * fourier_shifted_gaussian_kernel,
+    None,
+    {"isign": 1},
+).real / np.prod(shape)
 
 fig, ax = plt.subplots()
 ax.imshow(convolved_points.real)
-ax.scatter(points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=10 * convolved_at_points, c='r')
-
-##############################################################
-# To compute the convolution directly, we need to evaluate the kernel on all pairwise difference vectors between the points. 
+ax.scatter(
+    points[1] / 2 / np.pi * shape[0],
+    points[0] / 2 / np.pi * shape[1],
+    s=10 * convolved_at_points,
+    c="r",
+)
+
+#######################################################################################
+# To compute the convolution directly, we need to evaluate the kernel on all pairwise
+# difference vectors between the points. Note that some points will lie off the
+# diagonal; this is due to the periodic boundary conditions of the convolution.
 
 pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis]
 kernel_diff_evals = gaussian_function(*pairwise_diffs, sigma=sigma)
@@ -113,41 +139,56 @@ def gaussian_function(x, y, sigma=1):
 ax.plot(convolved_at_points.numpy(), convolved_by_hand, ".")
 ax.plot([1, 3], [1, 3])
 
-print(f"Relative difference between fourier convolution and direct convolution {torch.norm(convolved_at_points - convolved_by_hand) / np.linalg.norm(convolved_by_hand)}")
+relative_difference = torch.norm(
+    convolved_at_points - convolved_by_hand
+) / np.linalg.norm(convolved_by_hand)
+print(
+    "Relative difference between Fourier convolution and direct convolution "
+    f"{relative_difference}"
+)
 
+#######################################################################################
+# Now let's see if we can learn the convolution kernel from the input and output point
+# clouds. To this end, let's first make a PyTorch object that can compute a kernel
+# convolution on a point cloud.
 
-##############################################################
-# Now let's see if we can learn the convolution kernel from the input and output point clouds.
-# To this end, let's first make a pytorch object that can compute a kernel convolution
-# on a point cloud.
 
 class FourierPointConvolution(torch.nn.Module):
-
     def __init__(self, fourier_kernel_shape):
         super().__init__()
         self.fourier_kernel_shape = fourier_kernel_shape
 
-        self.build()        
+        self.build()
 
     def build(self):
-        self.register_parameter("fourier_kernel",
-            torch.nn.Parameter(torch.randn(self.fourier_kernel_shape, dtype=torch.complex128)))
+        self.register_parameter(
+            "fourier_kernel",
+            torch.nn.Parameter(
+                torch.randn(self.fourier_kernel_shape, dtype=torch.complex128)
+            ),
+        )
         # ^ think about whether we need to scale this init in some better way
-        
+
    def forward(self, points, values):
-        fourier_transformed_input = pytorch_finufft.functional.finufft_type1.apply(points, values, self.fourier_kernel_shape)
+        fourier_transformed_input = pytorch_finufft.functional.finufft_type1.apply(
+            points, values, self.fourier_kernel_shape
+        )
        fourier_convolved = fourier_transformed_input * self.fourier_kernel
-        convolved = pytorch_finufft.functional.finufft_type2.apply(points, fourier_convolved, None, {'isign': 1}).real / np.prod(self.fourier_kernel_shape)
+        convolved = pytorch_finufft.functional.finufft_type2.apply(
+            points, fourier_convolved, None, {"isign": 1}
+        ).real / np.prod(self.fourier_kernel_shape)
 
         return convolved
 
-
-##############################################################
-# Now we can use this object in a pytorch training loop to learn the kernel from the input and output point clouds.
-# We will use the mean squared error as a loss function.
+
+#######################################################################################
+# Now we can use this object in a PyTorch training loop to learn the kernel from the
+# input and output point clouds. We will use the mean squared error as a loss function.
 
 fourier_point_convolution = FourierPointConvolution(shape)
-optimizer = torch.optim.AdamW(fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001)
+optimizer = torch.optim.AdamW(
+    fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001
+)
 
 ones = torch.ones(points.shape[1], dtype=torch.complex128)
 
@@ -156,14 +197,16 @@ def forward(self, points, values):
     points = np.random.rand(2, N) * 2 * np.pi
     torch_points = torch.from_numpy(points)
 
-    fourier_points = pytorch_finufft.functional.finufft_type1.apply(torch.from_numpy(points),
-        ones, shape)
+    fourier_points = pytorch_finufft.functional.finufft_type1.apply(
+        torch.from_numpy(points), ones, shape
+    )
     convolved_at_points = pytorch_finufft.functional.finufft_type2.apply(
-        torch.from_numpy(points), fourier_points * fourier_shifted_gaussian_kernel,
-        None, {'isign': 1}
-        ).real / np.prod(shape)
-
-
+        torch.from_numpy(points),
+        fourier_points * fourier_shifted_gaussian_kernel,
+        None,
+        {"isign": 1},
+    ).real / np.prod(shape)
+
     optimizer.zero_grad()
     convolved = fourier_point_convolution(torch_points, ones)
     loss = torch.nn.functional.mse_loss(convolved, convolved_at_points)
@@ -173,6 +216,3 @@ def forward(self, points, values):
 
     if i % 100 == 0:
         print(f"Loss: {loss.item()}")
-
-
-
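For quick reference, the round trip this example leans on is: a type 1 NUFFT from the nonuniform points onto a uniform Fourier grid, a pointwise multiplication by the kernel's Fourier transform, and a type 2 NUFFT back onto the points. Below is a minimal standalone sketch of that pattern (not part of the diff above); it mirrors the pytorch_finufft calls as they appear in the example, while the grid size, sigma, and the random toy point cloud are illustrative placeholders.

import numpy as np
import torch

import pytorch_finufft

# Placeholder sizes for this sketch; the example itself uses shape=(128, 128), sigma=0.5
shape = (64, 64)
sigma = 0.5

# Gaussian kernel on the grid with its mass shifted to the origin, as in the example
grid = np.linspace(-np.pi, np.pi, shape[0], endpoint=False)
kernel = np.exp(-(grid[:, None] ** 2 + grid[None, :] ** 2) / (2 * sigma**2))
fourier_kernel = torch.fft.fft2(torch.from_numpy(np.fft.ifftshift(kernel)))

# Toy point cloud on the [0, 2*pi] x [0, 2*pi] torus, with unit weights
points = torch.rand(2, 50, dtype=torch.float64) * 2 * np.pi
weights = torch.ones(points.shape[1], dtype=torch.complex128)

# Type 1 NUFFT: nonuniform points -> uniform Fourier grid
fourier_points = pytorch_finufft.functional.finufft_type1.apply(points, weights, shape)

# Multiply in Fourier space, then type 2 NUFFT back onto the same nonuniform points
smoothed = pytorch_finufft.functional.finufft_type2.apply(
    points, fourier_points * fourier_kernel, None, {"isign": 1}
).real / np.prod(shape)

print(smoothed.shape)  # one smoothed value per input point

With shape=(128, 128) and the example's own shifted Gaussian kernel, this should reproduce the values the example stores in convolved_at_points.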