Skip to content

Commit

Permalink
DOC update example to work with new API
Browse files Browse the repository at this point in the history
  • Loading branch information
eickenberg committed Oct 18, 2023
1 parent 5f71674 commit e794155
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions examples/convolution_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def gaussian_function(x, y, sigma=1):
fourier_shifted_gaussian_kernel = torch.fft.fft2(
torch.from_numpy(shifted_gaussian_kernel)
)
fourier_points = pytorch_finufft.functional.finufft_type1.apply(
fourier_points = pytorch_finufft.functional.finufft_type1(
torch.from_numpy(points), torch.ones(points.shape[1], dtype=torch.complex128), shape
)

Expand Down Expand Up @@ -110,11 +110,10 @@ def gaussian_function(x, y, sigma=1):
# our original point cloud. We will then compare this to direct evaluation
# of the kernel on all pairwise difference vectors between the points.

convolved_at_points = pytorch_finufft.functional.finufft_type2.apply(
convolved_at_points = pytorch_finufft.functional.finufft_type2(
torch.from_numpy(points),
fourier_points * fourier_shifted_gaussian_kernel,
None,
{"isign": 1},
isign=1,
).real / np.prod(shape)

fig, ax = plt.subplots()
Expand Down Expand Up @@ -171,12 +170,14 @@ def build(self):
# ^ think about whether we need to scale this init in some better way

def forward(self, points, values):
fourier_transformed_input = pytorch_finufft.functional.finufft_type1.apply(
fourier_transformed_input = pytorch_finufft.functional.finufft_type1(
points, values, self.fourier_kernel_shape
)
fourier_convolved = fourier_transformed_input * self.fourier_kernel
convolved = pytorch_finufft.functional.finufft_type2.apply(
points, fourier_convolved, None, {"isign": 1}
convolved = pytorch_finufft.functional.finufft_type2(
points,
fourier_convolved,
isign=1,
).real / np.prod(self.fourier_kernel_shape)
return convolved

Expand All @@ -193,20 +194,20 @@ def forward(self, points, values):
ones = torch.ones(points.shape[1], dtype=torch.complex128)

losses = []
for i in range(25000):
for i in range(10000):
# Make new set of points and compute forward model
points = np.random.rand(2, N) * 2 * np.pi
torch_points = torch.from_numpy(points)

fourier_points = pytorch_finufft.functional.finufft_type1.apply(
fourier_points = pytorch_finufft.functional.finufft_type1(
torch.from_numpy(points), ones, shape
)
convolved_at_points = pytorch_finufft.functional.finufft_type2.apply(
convolved_at_points = pytorch_finufft.functional.finufft_type2(
torch.from_numpy(points),
fourier_points * fourier_shifted_gaussian_kernel,
None,
{"isign": 1},
isign=1,
).real / np.prod(shape)

# Learning step
optimizer.zero_grad()
convolved = fourier_point_convolution(torch_points, ones)
loss = torch.nn.functional.mse_loss(convolved, convolved_at_points)
Expand All @@ -215,4 +216,19 @@ def forward(self, points, values):
optimizer.step()

if i % 100 == 0:
print(f"Loss: {loss.item()}")
print(f"Iteration {i:05d}, Loss: {loss.item():1.4f}")


fig, ax = plt.subplots()
ax.plot(losses)
ax.set_ylabel("Loss")
ax.set_xlabel("Iteration")
ax.set_yscale("log")

fig, ax = plt.subplots()
im = ax.imshow(
torch.real(torch.fft.fftshift(fourier_point_convolution.fourier_kernel.data))[
48:80, 48:80
]
)
fig.colorbar(im, ax=ax)

0 comments on commit e794155

Please sign in to comment.