diff --git a/examples/convolution_2d.py b/examples/convolution_2d.py index 655b3bb..7910fe2 100644 --- a/examples/convolution_2d.py +++ b/examples/convolution_2d.py @@ -1,7 +1,3 @@ -# Everything below is to be formatted in RST with intermittent code, so that it can be rendered in Sphinx-gallery -# This is a message to co-pilot, hopefully you can help me out here - - ############################################################## # Convolution in 2D example # ========================= @@ -30,7 +26,7 @@ def gaussian_function(x, y, sigma=1): # the [0, 2*pi] x [0, 2*pi] torus. So let's dimension it accordingly. shape = (128, 128) -sigma = 0.1 +sigma = 0.5 x = np.linspace(-np.pi, np.pi, shape[0], endpoint=False) y = np.linspace(-np.pi, np.pi, shape[1], endpoint=False) @@ -86,3 +82,35 @@ def gaussian_function(x, y, sigma=1): ax.imshow(convolved_points.real) ax.scatter(points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=2, c='r') +############################################################## +# We see that the convolution has smeared out the point cloud. +# After a small coordinate change, we can also plot the original points +# on the same plot as the convolved points. + + +############################################################## +# Next, we invert the Fourier transform on the same points as +# our original point cloud. We will then compare this to direct evaluation +# of the kernel on all pairwise difference vectors between the points. + +convolved_at_points = pytorch_finufft.functional.finufft_type2.apply( + torch.from_numpy(points), fourier_points * fourier_shifted_gaussian_kernel, + None, {'isign': 1} + ).real / np.prod(shape) + +fig, ax = plt.subplots() +ax.imshow(convolved_points.real) +ax.scatter(points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=10 * convolved_at_points, c='r') + +############################################################## +# To compute the convolution directly, we need to evaluate the kernel on all pairwise difference vectors between the points. + +pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis] +kernel_diff_evals = gaussian_function(*pairwise_diffs, sigma=sigma) +convolved_by_hand = kernel_diff_evals.sum(1) + +fig, ax = plt.subplots() +ax.plot(convolved_at_points.numpy(), convolved_by_hand, ".") +ax.plot([1, 3], [1, 3]) + +print(f"Relative difference between fourier convolution and direct convolution {torch.norm(convolved_at_points - convolved_by_hand) / np.linalg.norm(convolved_by_hand)}")