Skip to content

Commit

Permalink
added clear images into mmd loss
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 20, 2024
1 parent 244dd4a commit f487bce
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
10 changes: 4 additions & 6 deletions src/cryo_sbi/inference/train_npe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def npe_train_no_saving(
saving_frequency: int = 20,
simulation_batch_size: int = 1024,
gamma: float = 1.0,
experimental_particles: Union[str, None] = None,
) -> None:
"""
Train NPE model by simulating training data on the fly.
Expand Down Expand Up @@ -119,7 +118,6 @@ def npe_train_no_saving(
)

loss = NPERobustStatsLoss(estimator, gamma)
experimental_particles = torch.load(experimental_particles, map_location=device)

optimizer = optim.AdamW(
estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001
Expand All @@ -143,7 +141,7 @@ def npe_train_no_saving(
amp,
snr,
) = parameters
images = cryo_em_simulator(
images, clear_images = cryo_em_simulator(
models,
indices.to(device, non_blocking=True),
quaternions.to(device, non_blocking=True),
Expand All @@ -156,17 +154,17 @@ def npe_train_no_saving(
num_pixels,
pixel_size,
)
for _indices, _images in zip(
for _indices, _images, _clear_images in zip(
indices.split(train_config["BATCH_SIZE"]),
images.split(train_config["BATCH_SIZE"]),
clear_images.split(train_config["BATCH_SIZE"]),
):
random_indices = torch.randperm(experimental_particles.size(0))[:train_config["BATCH_SIZE"]]
losses.append(
step(
loss(
_indices.to(device, non_blocking=True),
_images.to(device, non_blocking=True),
experimental_particles[random_indices].to(device, non_blocking=True),
_clear_images.to(device, non_blocking=True),
)
)
)
Expand Down
16 changes: 10 additions & 6 deletions src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ def cryo_em_simulator(
num_pixels,
pixel_size,
)
image = apply_ctf(image, defocus, b_factor, amp, pixel_size)
image = add_noise(image, snr)
clear_image = apply_ctf(image, defocus, b_factor, amp, pixel_size)
image = add_noise(clear_image, snr)
image = gaussian_normalize_image(image)
return image
clear_image = gaussian_normalize_image(clear_image)
return image, clear_image


class CryoEmSimulator:
Expand Down Expand Up @@ -159,23 +160,26 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No
parameters[0] = indices

images = []
clear_images = []
if batch_size is None:
batch_size = num_sim
for i in range(0, num_sim, batch_size):
batch_indices = indices[i : i + batch_size]
batch_parameters = [param[i : i + batch_size] for param in parameters[1:]]
batch_images = cryo_em_simulator(
batch_images, batch_clear_images = cryo_em_simulator(
self._models,
batch_indices,
*batch_parameters,
self._num_pixels,
self._pixel_size,
)
images.append(batch_images.cpu())
clear_images.append(batch_clear_images.cpu())

images = torch.cat(images, dim=0)
clear_images = torch.cat(clear_images, dim=0)

if return_parameters:
return images.cpu(), parameters
return images.cpu(), clear_images.cpu(), parameters
else:
return images.cpu()
return images.cpu(), clear_images.cpu()

0 comments on commit f487bce

Please sign in to comment.