diff --git a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py index d2bdd32..b097a1e 100644 --- a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py +++ b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py @@ -25,8 +25,9 @@ def cryo_em_simulator( snr, num_pixels, pixel_size, - noise=True, ctf=True, + noise=True, + normalize=True, ): """ Simulates a bacth of cryo-electron microscopy (cryo-EM) images of a set of given coars-grained models. @@ -60,7 +61,8 @@ def cryo_em_simulator( image = apply_ctf(image, defocus, b_factor, amp, pixel_size) if noise: image = add_noise(image, snr) - image = gaussian_normalize_image(image) + if normalize: + image = gaussian_normalize_image(image) return image @@ -148,7 +150,7 @@ def max_index(self) -> int: int: Maximum index of the model file. """ return len(self._models) - 1 - + def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=None, noise=True, ctf=True): """ Simulate cryo-EM images using the specified models and prior distributions. @@ -201,34 +203,73 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No return images.cpu(), parameters else: return images.cpu() + + + def simulate_with_micrograph_noise(self, num_sim, micrographs, indices=None, return_parameters=False, parameters=None, batch_size=None, ctf=True, snr=0.0001): + """ + Simulate cryo-EM images using the specified models and prior distributions. + + Args: + num_sim (int): The number of images to simulate. + indices (torch.Tensor, optional): The indices of the images to simulate. If None, all images are simulated. + return_parameters (bool, optional): Whether to return the sampled parameters used for simulation. + batch_size (int, optional): The batch size to use for simulation. If None, all images are simulated in a single batch. - def simulate_with_micrograph_noise(self, num_sim, micrographs, indices=None, return_parameters=False, batch_size=None, ctf=True, snr=0.0001): - self._init_micrograph_loader(micrographs, self._config["N_PIXELS"], num_noise_samples=num_sim) - images_and_maybe_params = self.simulate( - num_sim=num_sim, - indices=indices, - return_parameters=return_parameters, - batch_size=batch_size, + Returns: + torch.Tensor or tuple: The simulated images as a tensor of shape (num_sim, num_pixels, num_pixels), + and optionally the sampled parameters as a tuple of tensors. + """ + if parameters is None: + parameters = self._priors.sample((num_sim,)) + + indices = parameters[0] if indices is None else indices + if indices is not None: + assert isinstance( + indices, torch.Tensor + ), "Indices are not a torch.tensor, converting to torch.tensor." + assert ( + indices.dtype == torch.float32 + ), "Indices are not a torch.float32, converting to torch.float32." + assert ( + indices.ndim == 2 + ), "Indices are not a 2D tensor, converting to 2D tensor. With shape (batch_size, 1)." + parameters[0] = indices + + images = [] + if batch_size is None: + batch_size = num_sim + + self._init_micrograph_loader(micrographs, self._config["N_PIXELS"], num_noise_samples=batch_size) + + for i in range(0, num_sim, batch_size): + batch_indices = indices[i : i + batch_size] + batch_parameters = [param[i : i + batch_size] for param in parameters[1:]] + batch_images = cryo_em_simulator( + self._models, + batch_indices, + *batch_parameters, + self._num_pixels, + self._pixel_size, noise=False, - ctf=ctf + ctf=ctf, + normalize=False ) - if return_parameters: - images, parameters = images_and_maybe_params - else: - images = images_and_maybe_params - print("finished simulating images, drawing noise samples...") + noise_power = get_snr(batch_images, batch_parameters[-1]) + noise_samples = [] for noise_sample in self._micrograph_loader: noise_samples.append(noise_sample) - noise_samples = torch.cat(noise_samples, dim=0) - - print("finished drawing noise samples, adding noise to images...") - noise_samples = noise_samples / snr + noise_samples = torch.cat(noise_samples, dim=0).to(self._device) + noise_samples = noise_samples * noise_power + batch_images = batch_images + noise_samples + batch_images = gaussian_normalize_image(batch_images) - images = images + noise_samples - images = gaussian_normalize_image(images) + images.append(batch_images.cpu()) + images = torch.cat(images, dim=0) - return images - + if return_parameters: + return images.cpu(), parameters + else: + return images.cpu()