From 6225a187be7dfce0ddddcb154c286a6e8b005ec3 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Fri, 20 Oct 2023 16:29:39 +0200 Subject: [PATCH] added eval_log_prob function --- src/cryo_sbi/utils/estimator_utils.py | 33 ++++++++++++++++++- .../wpa_simulator/cryo_em_simulator.py | 22 ++++++++++++- tests/test_estimator_utils.py | 20 ++++++++++- 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/src/cryo_sbi/utils/estimator_utils.py b/src/cryo_sbi/utils/estimator_utils.py index cc30f6d..ce565a1 100644 --- a/src/cryo_sbi/utils/estimator_utils.py +++ b/src/cryo_sbi/utils/estimator_utils.py @@ -8,9 +8,40 @@ def evaluate_log_prob( estimator: torch.nn.Module, images: torch.Tensor, theta: torch.Tensor, + batch_size: int = 0, device: str = "cpu", ) -> torch.Tensor: - pass # TODO implement function to evaluate log prob + + # batching images if necessary + if images.shape[0] > batch_size and batch_size > 0: + images = torch.split(images, split_size_or_sections=batch_size, dim=0) + else: + batch_size = images.shape[0] + images = [images] + + # theta dimensions [num_eval, num_images, 1] + if theta.ndim == 3: + num_eval = theta.shape[0] + num_images = images.shape[0] + assert theta.shape == torch.Size([num_eval, num_images, 1]) + + elif theta.ndim == 2: + raise IndexError("theta must have 3 dimensions [num_eval, num_images, 1]") + + elif theta.ndim == 1: + theta = theta.reshape(-1, 1, 1).repeat(1, batch_size, 1) + + log_probs = [] + for image_batch in images: + posterior = estimator.flow(image_batch.to(device)) + log_probs.append( + posterior.log_prob( + estimator.standardize(theta.to(device)) + ) + ) + + log_probs = torch.cat(log_probs, dim=1) + return log_probs @torch.no_grad() diff --git a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py index 51aabd6..742ce37 100644 --- a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py +++ b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py @@ -24,6 +24,25 @@ def cryo_em_simulator( num_pixels, pixel_size, ): + """ + Simulates a bacth of cryo-electron microscopy (cryo-EM) images of a set of given coars-grained models. + + Args: + models (torch.Tensor): A tensor of coars grained models (num_models, 3, num_beads). + index (torch.Tensor): A tensor of indices to select the models to simulate. + quaternion (torch.Tensor): A tensor of quaternions to rotate the models. + sigma (float): The standard deviation of the Gaussian kernel used to project the density. + shift (torch.Tensor): A tensor of shifts to apply to the models. + defocus (float): The defocus value of the contrast transfer function (CTF). + b_factor (float): The B-factor of the CTF. + amp (float): The amplitude contrast of the CTF. + snr (float): The signal-to-noise ratio of the simulated image. + num_pixels (int): The number of pixels in the simulated image. + pixel_size (float): The size of each pixel in the simulated image. + + Returns: + torch.Tensor: A tensor of the simulated cryo-EM image. + """ models_selected = models[index.round().long().flatten()] image = project_density( models_selected, @@ -121,7 +140,8 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No batch_size (int, optional): The batch size to use for simulation. If None, all images are simulated in a single batch. Returns: - torch.Tensor or tuple: The simulated images as a tensor of shape (num_sim, num_pixels, num_pixels), and optionally the sampled parameters as a tuple of tensors. + torch.Tensor or tuple: The simulated images as a tensor of shape (num_sim, num_pixels, num_pixels), + and optionally the sampled parameters as a tuple of tensors. """ parameters = self._priors.sample((num_sim,)) diff --git a/tests/test_estimator_utils.py b/tests/test_estimator_utils.py index 32b14bf..6cbb2f7 100644 --- a/tests/test_estimator_utils.py +++ b/tests/test_estimator_utils.py @@ -10,7 +10,8 @@ from cryo_sbi.utils.estimator_utils import ( sample_posterior, compute_latent_repr, - load_estimator, + evaluate_log_prob, + load_estimator ) @@ -57,6 +58,23 @@ def test_latent_extraction(train_params, num_images, batch_size): ), f"Failed with: num_images: {num_images}, batch_size:{batch_size}" +@pytest.mark.parametrize( + ("num_images", "num_eval", "batch_size"), + [(1, 1, 1), (2, 10, 2), (5, 1000, 5), (100, 2, 100)], +) +def test_logprob_eval(train_params, num_images, num_eval, batch_size): + estimator = build_models.build_npe_flow_model(train_params) + estimator.eval() + images = torch.randn((num_images, 128, 128)) + theta = torch.linspace(0, 25, num_eval) + samples = evaluate_log_prob( + estimator, images, theta, batch_size=batch_size + ) + assert samples.shape == torch.Size( + [num_eval, num_images] + ), f"Failed with: num_images: {num_images}, num_eval:{num_eval}, batch_size:{batch_size}" + + def test_load_estimator(train_params, train_config_path): estimator = build_models.build_npe_flow_model(train_params) torch.save(estimator.state_dict(), "tests/config_files/test_estimator.estimator")