Skip to content

Commit

Permalink
WIP full correspondence between fourier convolution and convolution b…
Browse files Browse the repository at this point in the history
…y hand
  • Loading branch information
eickenberg committed Oct 13, 2023
1 parent a1e624c commit 9790770
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions examples/convolution_2d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# Everything below is to be formatted in RST with intermittent code, so that it can be rendered in Sphinx-gallery
# This is a message to co-pilot, hopefully you can help me out here


##############################################################
# Convolution in 2D example
# =========================
Expand Down Expand Up @@ -30,7 +26,7 @@ def gaussian_function(x, y, sigma=1):
# the [0, 2*pi] x [0, 2*pi] torus. So let's dimension it accordingly.

shape = (128, 128)
sigma = 0.1
sigma = 0.5
x = np.linspace(-np.pi, np.pi, shape[0], endpoint=False)
y = np.linspace(-np.pi, np.pi, shape[1], endpoint=False)

Expand Down Expand Up @@ -86,3 +82,35 @@ def gaussian_function(x, y, sigma=1):
ax.imshow(convolved_points.real)
ax.scatter(points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=2, c='r')

##############################################################
# We see that the convolution has smeared out the point cloud.
# After a small coordinate change, we can also plot the original points
# on the same plot as the convolved points.


##############################################################
# Next, we invert the Fourier transform on the same points as
# our original point cloud. We will then compare this to direct evaluation
# of the kernel on all pairwise difference vectors between the points.

convolved_at_points = pytorch_finufft.functional.finufft_type2.apply(
torch.from_numpy(points), fourier_points * fourier_shifted_gaussian_kernel,
None, {'isign': 1}
).real / np.prod(shape)

fig, ax = plt.subplots()
ax.imshow(convolved_points.real)
ax.scatter(points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=10 * convolved_at_points, c='r')

##############################################################
# To compute the convolution directly, we need to evaluate the kernel on all pairwise difference vectors between the points.

pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis]
kernel_diff_evals = gaussian_function(*pairwise_diffs, sigma=sigma)
convolved_by_hand = kernel_diff_evals.sum(1)

fig, ax = plt.subplots()
ax.plot(convolved_at_points.numpy(), convolved_by_hand, ".")
ax.plot([1, 3], [1, 3])

print(f"Relative difference between fourier convolution and direct convolution {torch.norm(convolved_at_points - convolved_by_hand) / np.linalg.norm(convolved_by_hand)}")

0 comments on commit 9790770

Please sign in to comment.