Skip to content

Commit

Permalink
COSM formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
eickenberg committed Oct 18, 2023
1 parent e8143d5 commit 5f71674
Showing 1 changed file with 97 additions and 57 deletions.
154 changes: 97 additions & 57 deletions examples/convolution_2d.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
##############################################################
#######################################################################################
# Convolution in 2D example
# =========================


##############################################################
#######################################################################################
# Import packages
# ---------------
#
# First, we import the packages we need for this example.

import matplotlib.pyplot as plt
import numpy as np
import torch

import pytorch_finufft
import matplotlib.pyplot as plt

##############################################################
#######################################################################################
# Let's create a Gaussian convolutional filter as a function of x,y


def gaussian_function(x, y, sigma=1):
return np.exp(-(x**2 + y**2) / (2 * sigma**2))


##############################################################
# Let's visualize this filter kernel. We will be using it to convolve with points living on
# the [0, 2*pi] x [0, 2*pi] torus. So let's dimension it accordingly.
#######################################################################################
# Let's visualize this filter kernel. We will be using it to convolve with points
# living on the [0, 2*pi] x [0, 2*pi] torus. So let's dimension it accordingly.

shape = (128, 128)
sigma = 0.5
Expand All @@ -35,7 +37,7 @@ def gaussian_function(x, y, sigma=1):
fig, ax = plt.subplots()
ax.imshow(gaussian_kernel)

##############################################################
#######################################################################################
# In order for the kernel to not shift the signal, we need to place its mass at 0
# To do this, we ifftshift the kernel

Expand All @@ -45,7 +47,7 @@ def gaussian_function(x, y, sigma=1):
ax.imshow(shifted_gaussian_kernel)


##############################################################
#######################################################################################
# Now let's create a point cloud on the torus that we can convolve with our filter

N = 20
Expand All @@ -54,56 +56,80 @@ def gaussian_function(x, y, sigma=1):
fig, ax = plt.subplots()
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(0, 2 * np.pi)
ax.set_aspect('equal')
ax.set_aspect("equal")
ax.scatter(points[0], points[1], s=1)

##############################################################
#######################################################################################
# Now we can convolve the point cloud with the filter kernel.
# To do this, we Fourier-transform both the point cloud and the filter kernel,
# multiply them together, and then inverse Fourier-transform the result.
# First we need to convert all data to torch tensors

fourier_shifted_gaussian_kernel = torch.fft.fft2(torch.from_numpy(shifted_gaussian_kernel))
fourier_points = pytorch_finufft.functional.finufft_type1.apply(torch.from_numpy(points), torch.ones(points.shape[1], dtype=torch.complex128), shape)
fourier_shifted_gaussian_kernel = torch.fft.fft2(
torch.from_numpy(shifted_gaussian_kernel)
)
fourier_points = pytorch_finufft.functional.finufft_type1.apply(
torch.from_numpy(points), torch.ones(points.shape[1], dtype=torch.complex128), shape
)

fig, axs = plt.subplots(1, 3)
axs[0].imshow(fourier_shifted_gaussian_kernel.real)
axs[1].imshow(fourier_points.real, vmin=-10, vmax=10)
axs[2].imshow((fourier_points * fourier_shifted_gaussian_kernel / fourier_shifted_gaussian_kernel[0, 0]).real, vmin=-10, vmax=10)


##############################################################
# We now have two possibilities: Invert the Fourier transform on a grid, or on a point cloud.
# We'll first invert the Fourier transform on a grid in order to be able to visualize the effect of the convolution.
axs[2].imshow(
(
fourier_points
* fourier_shifted_gaussian_kernel
/ fourier_shifted_gaussian_kernel[0, 0]
).real,
vmin=-10,
vmax=10,
)


#######################################################################################
# We now have two possibilities: Invert the Fourier transform on a grid, or on a point
# cloud. We'll first invert the Fourier transform on a grid in order to be able to
# visualize the effect of the convolution.

convolved_points = torch.fft.ifft2(fourier_points * fourier_shifted_gaussian_kernel)

fig, ax = plt.subplots()
ax.imshow(convolved_points.real)
ax.scatter(points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=2, c='r')
ax.scatter(
points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=2, c="r"
)

##############################################################
#######################################################################################
# We see that the convolution has smeared out the point cloud.
# After a small coordinate change, we can also plot the original points
# on the same plot as the convolved points.


##############################################################
# Next, we invert the Fourier transform on the same points as
#######################################################################################
# Next, we invert the Fourier transform on the same points as
# our original point cloud. We will then compare this to direct evaluation
# of the kernel on all pairwise difference vectors between the points.

convolved_at_points = pytorch_finufft.functional.finufft_type2.apply(
torch.from_numpy(points), fourier_points * fourier_shifted_gaussian_kernel,
None, {'isign': 1}
).real / np.prod(shape)
torch.from_numpy(points),
fourier_points * fourier_shifted_gaussian_kernel,
None,
{"isign": 1},
).real / np.prod(shape)

fig, ax = plt.subplots()
ax.imshow(convolved_points.real)
ax.scatter(points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=10 * convolved_at_points, c='r')

##############################################################
# To compute the convolution directly, we need to evaluate the kernel on all pairwise difference vectors between the points.
ax.scatter(
points[1] / 2 / np.pi * shape[0],
points[0] / 2 / np.pi * shape[1],
s=10 * convolved_at_points,
c="r",
)

#######################################################################################
# To compute the convolution directly, we need to evaluate the kernel on all pairwise
# difference vectors between the points. Note the points that will be off the diagonal.
# These will be due to the periodic boundary conditions of the convolution.

pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis]
kernel_diff_evals = gaussian_function(*pairwise_diffs, sigma=sigma)
Expand All @@ -113,41 +139,56 @@ def gaussian_function(x, y, sigma=1):
ax.plot(convolved_at_points.numpy(), convolved_by_hand, ".")
ax.plot([1, 3], [1, 3])

print(f"Relative difference between fourier convolution and direct convolution {torch.norm(convolved_at_points - convolved_by_hand) / np.linalg.norm(convolved_by_hand)}")
relative_difference = torch.norm(
convolved_at_points - convolved_by_hand
) / np.linalg.norm(convolved_by_hand)
print(
"Relative difference between fourier convolution and direct convolution "
f"{relative_difference}"
)


#######################################################################################
# Now let's see if we can learn the convolution kernel from the input and output point
# clouds. To this end, let's first make a pytorch object that can compute a kernel
# convolution on a point cloud.

##############################################################
# Now let's see if we can learn the convolution kernel from the input and output point clouds.
# To this end, let's first make a pytorch object that can compute a kernel convolution
# on a point cloud.

class FourierPointConvolution(torch.nn.Module):

def __init__(self, fourier_kernel_shape):
super().__init__()
self.fourier_kernel_shape = fourier_kernel_shape

self.build()
self.build()

def build(self):
self.register_parameter("fourier_kernel",
torch.nn.Parameter(torch.randn(self.fourier_kernel_shape, dtype=torch.complex128)))
self.register_parameter(
"fourier_kernel",
torch.nn.Parameter(
torch.randn(self.fourier_kernel_shape, dtype=torch.complex128)
),
)
# ^ think about whether we need to scale this init in some better way

def forward(self, points, values):
fourier_transformed_input = pytorch_finufft.functional.finufft_type1.apply(points, values, self.fourier_kernel_shape)
fourier_transformed_input = pytorch_finufft.functional.finufft_type1.apply(
points, values, self.fourier_kernel_shape
)
fourier_convolved = fourier_transformed_input * self.fourier_kernel
convolved = pytorch_finufft.functional.finufft_type2.apply(points, fourier_convolved, None, {'isign': 1}).real / np.prod(self.fourier_kernel_shape)
convolved = pytorch_finufft.functional.finufft_type2.apply(
points, fourier_convolved, None, {"isign": 1}
).real / np.prod(self.fourier_kernel_shape)
return convolved


##############################################################
# Now we can use this object in a pytorch training loop to learn the kernel from the input and output point clouds.
# We will use the mean squared error as a loss function.

#######################################################################################
# Now we can use this object in a pytorch training loop to learn the kernel from the
# input and output point clouds. We will use the mean squared error as a loss function.

fourier_point_convolution = FourierPointConvolution(shape)
optimizer = torch.optim.AdamW(fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001)
optimizer = torch.optim.AdamW(
fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001
)

ones = torch.ones(points.shape[1], dtype=torch.complex128)

Expand All @@ -156,14 +197,16 @@ def forward(self, points, values):
points = np.random.rand(2, N) * 2 * np.pi
torch_points = torch.from_numpy(points)

fourier_points = pytorch_finufft.functional.finufft_type1.apply(torch.from_numpy(points),
ones, shape)
fourier_points = pytorch_finufft.functional.finufft_type1.apply(
torch.from_numpy(points), ones, shape
)
convolved_at_points = pytorch_finufft.functional.finufft_type2.apply(
torch.from_numpy(points), fourier_points * fourier_shifted_gaussian_kernel,
None, {'isign': 1}
).real / np.prod(shape)


torch.from_numpy(points),
fourier_points * fourier_shifted_gaussian_kernel,
None,
{"isign": 1},
).real / np.prod(shape)

optimizer.zero_grad()
convolved = fourier_point_convolution(torch_points, ones)
loss = torch.nn.functional.mse_loss(convolved, convolved_at_points)
Expand All @@ -173,6 +216,3 @@ def forward(self, points, values):

if i % 100 == 0:
print(f"Loss: {loss.item()}")



0 comments on commit 5f71674

Please sign in to comment.