Skip to content

Commit

Permalink
cleaned simulator code
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed May 15, 2023
1 parent c7229cb commit 1413aa4
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 25 deletions.
37 changes: 36 additions & 1 deletion src/cryo_sbi/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ def __call__(self, images: torch.Tensor) -> torch.Tensor:
class MRCtoTensor:
"""
Convert an MRC file to a tensor.
Args:
image_path (str): Path to the MRC file.
Returns:
image (torch.Tensor): Image of shape (n_pixels, n_pixels).
"""

def __init__(self) -> None:
Expand All @@ -227,4 +233,33 @@ def __call__(self, image_path: str) -> torch.Tensor:
return torch.from_numpy(image)


# TODO: add whitening transform
class WhitenImage:
"""
Whiten an image by dividing by the noise PSD.
Args:
noise_psd (torch.Tensor): Noise PSD of shape (n_pixels, n_pixels).
Square root of the noise PSD is used to whiten the image.
Returns:
reconstructed (torch.Tensor): Whiten image.
"""

def __init__(self, noise_psd: torch.Tensor) -> None:
self.noise_psd = noise_psd

def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Whiten an image by dividing by the noise PSD.
Args:
image (torch.Tensor): Image of shape (n_pixels, n_pixels).
Returns:
reconstructed (torch.Tensor): Whiten image.
"""

fft_image = torch.fft.fft2(image)
fft_image = fft_image / torch.sqrt(self.noise_psd)
reconstructed = torch.fft.ifft2(fft_image).real
return reconstructed
71 changes: 64 additions & 7 deletions src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Union, Callable
import torch
import numpy as np
import json
Expand Down Expand Up @@ -27,7 +28,7 @@ class CryoEmSimulator:
add_noise (bool): function which adds noise to images. Defaults to Gaussian noise.
"""

def __init__(self, config_fname):
def __init__(self, config_fname: str, add_noise: Callable = add_noise):
self._load_params(config_fname)
self._load_models()
self.rot_mode = None
Expand All @@ -36,12 +37,29 @@ def __init__(self, config_fname):
self._pad_width = int(np.ceil(self.config["N_PIXELS"] * 0.1)) + 1
self.add_noise = add_noise

def _load_params(self, config_fname):
def _load_params(self, config_fname: str) -> None:
"""
Loads the parameters from the config file into a dictionary.
Args:
config_fname (str): Path to the configuration file.
Returns:
None
"""

config = json.load(open(config_fname))
check_params(config)
self.config = config

def _load_models(self):
def _load_models(self) -> None:
"""
Loads the models from the model file specified in the config file.
Returns:
None
"""
if "hsp90" in self.config["MODEL_FILE"]:
self.models = np.load(self.config["MODEL_FILE"])[:, 0]

Expand All @@ -54,7 +72,13 @@ def _load_models(self):
)
print(self.config["MODEL_FILE"])

def _config_rotations(self):
def _config_rotations(self) -> None:
"""
Configures the rotation mode for the simulator.
Returns:
None
"""
if isinstance(self.config["ROTATIONS"], bool):
if self.config["ROTATIONS"]:
self.rot_mode = "random"
Expand All @@ -68,10 +92,30 @@ def _config_rotations(self):
), "Quaternion shape is not 4. Corrupted file?"

@property
def max_index(self):
def max_index(self) -> int:
"""
Returns the maximum index of the model file.
Returns:
int: Maximum index of the model file.
"""
return len(self.models) - 1

def _simulator_with_quat(self, index, quaternion, seed):
def _simulator_with_quat(
self, index: torch.Tensor, quaternion: np.ndarray, seed: Union[None, int] = None
) -> torch.Tensor:
"""
Simulates an image with a given quaternion.
Args:
index (torch.Tensor): Index of the model to use.
quaternion (np.ndarray): Quaternion to rotate structure.
seed (Union[None, int], optional): Seed for random number generator. Defaults to None.
Returns:
torch.Tensor: Simulated image.
"""

index = int(torch.round(index))

coord = np.copy(self.models[index])
Expand All @@ -98,7 +142,20 @@ def _simulator_with_quat(self, index, quaternion, seed):

return image.to(dtype=torch.float)

def simulator(self, index, seed=None):
def simulator(
self, index: torch.Tensor, seed: Union[None, int] = None
) -> torch.Tensor:
"""
Simulates an image with parameters specified in the config file.
Args:
index (torch.Tensor): Index of the model to use.
seed (Union[None, int], optional): Seed for random number generator. Defaults to None.
Returns:
torch.Tensor: Simulated image.
"""

if self.rot_mode == "random":
quat = gen_quat()
elif self.rot_mode == "list":
Expand Down
30 changes: 23 additions & 7 deletions src/cryo_sbi/wpa_simulator/noise.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Union
import numpy as np
import torch


def circular_mask(n_pixels, radius):
"""Creates a circular mask of radius RADIUS_MASK centered in the image
def circular_mask(n_pixels: int, radius: int) -> torch.Tensor:
"""
Creates a circular mask of radius RADIUS_MASK centered in the image
Args:
n_pixels (int): Number of pixels along image side.
Expand All @@ -20,8 +22,11 @@ def circular_mask(n_pixels, radius):
return mask


def add_noise(image, image_params, seed=None):
"""Adds noise to image
def add_noise(
image: torch.Tensor, image_params: dict, seed: Union[None, int] = None
) -> torch.Tensor:
"""
Adds noise to image
Args:
image (torch.Tensor): Image of shape (n_pixels, n_pixels).
Expand Down Expand Up @@ -57,7 +62,13 @@ def add_noise(image, image_params, seed=None):
return image_noise


def add_colored_noise(image, image_params, seed, noise_intensity=1, noise_scale=1.5):
def add_colored_noise(
image: torch.Tensor,
image_params: dict,
seed: int,
noise_intensity: float = 1,
noise_scale: float = 1.5,
):
"""Adds colored noise to image.
Args:
Expand Down Expand Up @@ -108,12 +119,17 @@ def add_colored_noise(image, image_params, seed, noise_intensity=1, noise_scale=
return image_noise + image


def add_shot_noise(image):
def add_shot_noise(image: torch.Tensor) -> torch.Tensor:
"""Adds shot noise to image"""
raise NotImplementedError


def add_gradient_snr(image, image_params, seed, delta_snr=0.5):
def add_gradient_snr(
image: torch.Tensor,
image_params: dict,
seed: Union[None, int] = None,
delta_snr: float = 0.5,
) -> torch.Tensor:
"""Adds gaussian noise with gradient along x.
Args:
Expand Down
5 changes: 3 additions & 2 deletions src/cryo_sbi/wpa_simulator/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import torch


def gaussian_normalize_image(image):
"""Normalize an image by subtracting the mean and dividing by the standard deviation.
def gaussian_normalize_image(image: torch.Tensor) -> torch.Tensor:
"""
Normalize an image by subtracting the mean and dividing by the standard deviation.
Args:
image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
Expand Down
6 changes: 4 additions & 2 deletions src/cryo_sbi/wpa_simulator/padding.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import torch
from torch.nn.functional import pad
from torch.nn import ConstantPad2d


def pad_image(image, image_params):
"""Pads image with zeros to randomly crop image later.
def pad_image(image: torch.Tensor, image_params: dict) -> torch.Tensor:
"""
Pads image with zeros to randomly crop image later.
Args:
image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
Expand Down
22 changes: 19 additions & 3 deletions src/cryo_sbi/wpa_simulator/shift.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Union
import torch
import numpy as np


def apply_random_shift(padded_image, image_params, seed=None):
"""Applies random shift to image.
def apply_random_shift(
padded_image: torch.Tensor, image_params: dict, seed: Union[None, int] = None
) -> torch.Tensor:
"""
Applies random shift to image.
Args:
padded_image (torch.Tensor): Padded image of shape (n_pixels + 2 * pad_width, n_pixels + 2 * pad_width).
Expand Down Expand Up @@ -35,7 +39,19 @@ def apply_random_shift(padded_image, image_params, seed=None):
return shifted_image


def apply_no_shift(padded_image, image_params):
def apply_no_shift(padded_image: torch.Tensor, image_params: dict) -> torch.Tensor:
"""
Applies no shift to image, i.e. returns the image without padding.
Args:
padded_image (torch.Tensor): Padded image of shape (n_pixels + 2 * pad_width, n_pixels + 2 * pad_width).
With pad_width = int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1.
image_params (dict): Dictionary containing image parameters.
Returns:
shifted_image (torch.Tensor): Shifted image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
"""

pad_width = int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1

low_ind_x = pad_width
Expand Down
7 changes: 4 additions & 3 deletions src/cryo_sbi/wpa_simulator/validate_image_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
def check_params(config):
"""Checks if all necessary parameters are provided.
def check_params(config: dict) -> None:
"""
Checks if all necessary parameters are provided.
Args:
config (dict): Dictionary containing image parameters.
Expand Down Expand Up @@ -28,4 +29,4 @@ def check_params(config):
for key in needed_keys:
assert key in config.keys(), f"Please provide a value for {key}"

return
return None

0 comments on commit 1413aa4

Please sign in to comment.