diff --git a/examples/convolution_2d.py b/examples/convolution_2d.py index 7910fe2..38ee033 100644 --- a/examples/convolution_2d.py +++ b/examples/convolution_2d.py @@ -114,3 +114,64 @@ def gaussian_function(x, y, sigma=1): ax.plot([1, 3], [1, 3]) print(f"Relative difference between fourier convolution and direct convolution {torch.norm(convolved_at_points - convolved_by_hand) / np.linalg.norm(convolved_by_hand)}") + + + +############################################################## +# Now let's see if we can learn the convolution kernel from the input and output point clouds. +# To this end, let's first make a pytorch object that can compute a kernel convolution +# on a point cloud. + +class FourierPointConvolution(torch.nn.Module): + + def __init__(self, fourier_kernel_shape): + super().__init__() + self.fourier_kernel_shape = fourier_kernel_shape + + self.build() + + def build(self): + self.register_parameter("fourier_kernel", + torch.nn.Parameter(torch.randn(self.fourier_kernel_shape, dtype=torch.complex128))) + # ^ think about whether we need to scale this init in some better way + + def forward(self, points, values): + fourier_transformed_input = pytorch_finufft.functional.finufft_type1.apply(points, values, self.fourier_kernel_shape) + fourier_convolved = fourier_transformed_input * self.fourier_kernel + convolved = pytorch_finufft.functional.finufft_type2.apply(points, fourier_convolved, None, {'isign': 1}).real / np.prod(self.fourier_kernel_shape) + return convolved + + +############################################################## +# Now we can use this object in a pytorch training loop to learn the kernel from the input and output point clouds. +# We will use the mean squared error as a loss function. + +fourier_point_convolution = FourierPointConvolution(shape) +optimizer = torch.optim.AdamW(fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001) + +ones = torch.ones(points.shape[1], dtype=torch.complex128) + +losses = [] +for i in range(25000): + points = np.random.rand(2, N) * 2 * np.pi + torch_points = torch.from_numpy(points) + + fourier_points = pytorch_finufft.functional.finufft_type1.apply(torch.from_numpy(points), + ones, shape) + convolved_at_points = pytorch_finufft.functional.finufft_type2.apply( + torch.from_numpy(points), fourier_points * fourier_shifted_gaussian_kernel, + None, {'isign': 1} + ).real / np.prod(shape) + + + optimizer.zero_grad() + convolved = fourier_point_convolution(torch_points, ones) + loss = torch.nn.functional.mse_loss(convolved, convolved_at_points) + losses.append(loss.item()) + loss.backward() + optimizer.step() + + if i % 100 == 0: + print(f"Loss: {loss.item()}") + +