Skip to content

Commit

Permalink
WIP added learning of the kernel. It's learning something, but I'm no…
Browse files Browse the repository at this point in the history
…t entirely sure what - it seems to be learning autocorrelation-related thing rather than the kernel itself, but that may be somehow inherent to the problem setu[
  • Loading branch information
eickenberg committed Oct 18, 2023
1 parent 4b3175b commit b0563f0
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions examples/convolution_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,64 @@ def gaussian_function(x, y, sigma=1):
ax.plot([1, 3], [1, 3])

print(f"Relative difference between fourier convolution and direct convolution {torch.norm(convolved_at_points - convolved_by_hand) / np.linalg.norm(convolved_by_hand)}")



##############################################################
# Now let's see if we can learn the convolution kernel from the input and output point clouds.
# To this end, let's first make a pytorch object that can compute a kernel convolution
# on a point cloud.

class FourierPointConvolution(torch.nn.Module):

def __init__(self, fourier_kernel_shape):
super().__init__()
self.fourier_kernel_shape = fourier_kernel_shape

self.build()

def build(self):
self.register_parameter("fourier_kernel",
torch.nn.Parameter(torch.randn(self.fourier_kernel_shape, dtype=torch.complex128)))
# ^ think about whether we need to scale this init in some better way

def forward(self, points, values):
fourier_transformed_input = pytorch_finufft.functional.finufft_type1.apply(points, values, self.fourier_kernel_shape)
fourier_convolved = fourier_transformed_input * self.fourier_kernel
convolved = pytorch_finufft.functional.finufft_type2.apply(points, fourier_convolved, None, {'isign': 1}).real / np.prod(self.fourier_kernel_shape)
return convolved


##############################################################
# Now we can use this object in a pytorch training loop to learn the kernel from the input and output point clouds.
# We will use the mean squared error as a loss function.

fourier_point_convolution = FourierPointConvolution(shape)
optimizer = torch.optim.AdamW(fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001)

ones = torch.ones(points.shape[1], dtype=torch.complex128)

losses = []
for i in range(25000):
points = np.random.rand(2, N) * 2 * np.pi
torch_points = torch.from_numpy(points)

fourier_points = pytorch_finufft.functional.finufft_type1.apply(torch.from_numpy(points),
ones, shape)
convolved_at_points = pytorch_finufft.functional.finufft_type2.apply(
torch.from_numpy(points), fourier_points * fourier_shifted_gaussian_kernel,
None, {'isign': 1}
).real / np.prod(shape)


optimizer.zero_grad()
convolved = fourier_point_convolution(torch_points, ones)
loss = torch.nn.functional.mse_loss(convolved, convolved_at_points)
losses.append(loss.item())
loss.backward()
optimizer.step()

if i % 100 == 0:
print(f"Loss: {loss.item()}")


0 comments on commit b0563f0

Please sign in to comment.