Skip to content

Commit

Permalink
added eval_log_prob function
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Oct 20, 2023
1 parent 9c5b59b commit 6225a18
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 3 deletions.
33 changes: 32 additions & 1 deletion src/cryo_sbi/utils/estimator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,40 @@ def evaluate_log_prob(
estimator: torch.nn.Module,
images: torch.Tensor,
theta: torch.Tensor,
batch_size: int = 0,
device: str = "cpu",
) -> torch.Tensor:
pass # TODO implement function to evaluate log prob

# batching images if necessary
if images.shape[0] > batch_size and batch_size > 0:
images = torch.split(images, split_size_or_sections=batch_size, dim=0)
else:
batch_size = images.shape[0]
images = [images]

# theta dimensions [num_eval, num_images, 1]
if theta.ndim == 3:
num_eval = theta.shape[0]
num_images = images.shape[0]
assert theta.shape == torch.Size([num_eval, num_images, 1])

elif theta.ndim == 2:
raise IndexError("theta must have 3 dimensions [num_eval, num_images, 1]")

elif theta.ndim == 1:
theta = theta.reshape(-1, 1, 1).repeat(1, batch_size, 1)

log_probs = []
for image_batch in images:
posterior = estimator.flow(image_batch.to(device))
log_probs.append(
posterior.log_prob(
estimator.standardize(theta.to(device))
)
)

log_probs = torch.cat(log_probs, dim=1)
return log_probs


@torch.no_grad()
Expand Down
22 changes: 21 additions & 1 deletion src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ def cryo_em_simulator(
num_pixels,
pixel_size,
):
"""
Simulates a bacth of cryo-electron microscopy (cryo-EM) images of a set of given coars-grained models.
Args:
models (torch.Tensor): A tensor of coars grained models (num_models, 3, num_beads).
index (torch.Tensor): A tensor of indices to select the models to simulate.
quaternion (torch.Tensor): A tensor of quaternions to rotate the models.
sigma (float): The standard deviation of the Gaussian kernel used to project the density.
shift (torch.Tensor): A tensor of shifts to apply to the models.
defocus (float): The defocus value of the contrast transfer function (CTF).
b_factor (float): The B-factor of the CTF.
amp (float): The amplitude contrast of the CTF.
snr (float): The signal-to-noise ratio of the simulated image.
num_pixels (int): The number of pixels in the simulated image.
pixel_size (float): The size of each pixel in the simulated image.
Returns:
torch.Tensor: A tensor of the simulated cryo-EM image.
"""
models_selected = models[index.round().long().flatten()]
image = project_density(
models_selected,
Expand Down Expand Up @@ -121,7 +140,8 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No
batch_size (int, optional): The batch size to use for simulation. If None, all images are simulated in a single batch.
Returns:
torch.Tensor or tuple: The simulated images as a tensor of shape (num_sim, num_pixels, num_pixels), and optionally the sampled parameters as a tuple of tensors.
torch.Tensor or tuple: The simulated images as a tensor of shape (num_sim, num_pixels, num_pixels),
and optionally the sampled parameters as a tuple of tensors.
"""

parameters = self._priors.sample((num_sim,))
Expand Down
20 changes: 19 additions & 1 deletion tests/test_estimator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from cryo_sbi.utils.estimator_utils import (
sample_posterior,
compute_latent_repr,
load_estimator,
evaluate_log_prob,
load_estimator
)


Expand Down Expand Up @@ -57,6 +58,23 @@ def test_latent_extraction(train_params, num_images, batch_size):
), f"Failed with: num_images: {num_images}, batch_size:{batch_size}"


@pytest.mark.parametrize(
("num_images", "num_eval", "batch_size"),
[(1, 1, 1), (2, 10, 2), (5, 1000, 5), (100, 2, 100)],
)
def test_logprob_eval(train_params, num_images, num_eval, batch_size):
estimator = build_models.build_npe_flow_model(train_params)
estimator.eval()
images = torch.randn((num_images, 128, 128))
theta = torch.linspace(0, 25, num_eval)
samples = evaluate_log_prob(
estimator, images, theta, batch_size=batch_size
)
assert samples.shape == torch.Size(
[num_eval, num_images]
), f"Failed with: num_images: {num_images}, num_eval:{num_eval}, batch_size:{batch_size}"


def test_load_estimator(train_params, train_config_path):
estimator = build_models.build_npe_flow_model(train_params)
torch.save(estimator.state_dict(), "tests/config_files/test_estimator.estimator")
Expand Down

0 comments on commit 6225a18

Please sign in to comment.