Skip to content

Commit

Permalink
added some docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Apr 17, 2023
1 parent 994ec58 commit ac20143
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 2 deletions.
19 changes: 19 additions & 0 deletions src/cryo_sbi/inference/NPE_train_without_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ def npe_train_no_saving(
model_state_dict=None,
n_workers=1,
):
"""Train NPE model without saving the model
Args:
image_config (str): path to image config file
train_config (str): path to train config file
epochs (int): number of epochs
estimator_file (str): path to estimator file
loss_file (str): path to loss file
train_from_checkpoint (bool, optional): train from checkpoint. Defaults to False.
model_state_dict (str, optional): path to model state dict. Defaults to None.
n_workers (int, optional): number of workers. Defaults to 1.
Raises:
Warning: No model state dict specified! --model_state_dict is empty
Returns:
None
"""

cryo_simulator = CryoEmSimulator(image_config)

train_config = json.load(open(train_config))
Expand Down
9 changes: 8 additions & 1 deletion src/cryo_sbi/inference/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@


def get_uniform_prior_1d(max_index):
"""Return uniform prior in 1d from 0 to 19"""
"""Return uniform prior in 1d from 0 to 19
Args:
max_index (int): max index of the 1d prior
Returns:
zuko.distributions.BoxUniform: prior
"""

assert isinstance(max_index, int), "max_index is no INT"

Expand Down
40 changes: 39 additions & 1 deletion src/cryo_sbi/utils/estimator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@


def sample_posterior(estimator, images, num_samples, batch_size=100, device="cpu"):
"""Samples from the posterior distribution
Args:
estimator (torch.nn.Module): The posterior to use for sampling.
images (torch.Tensor): The images used to condition the posterio.
num_samples (int): The number of samples to draw
batch_size (int, optional): The batch size to use. Defaults to 100.
device (str, optional): The device to use. Defaults to "cpu".
Returns:
torch.Tensor: The samples
"""

theta_samples = []

if images.shape[0] > batch_size and batch_size > 0:
Expand All @@ -23,6 +36,18 @@ def sample_posterior(estimator, images, num_samples, batch_size=100, device="cpu


def compute_latent_repr(estimator, images, batch_size=100, device="cpu"):
"""Computes the latent representation of images.
Args:
estimator (torch.nn.Module): Posterior model for which to compute the latent representation.
images (torch.Tensor): The images to compute the latent representation for.
batch_size (int, optional): The batch size to use. Defaults to 100.
device (str, optional): The device to use. Defaults to "cpu".
Returns:
torch.Tensor: The latent representation of the images.
"""

latent_space_samples = []

if images.shape[0] > batch_size and batch_size > 0:
Expand All @@ -42,9 +67,22 @@ def compute_latent_repr(estimator, images, batch_size=100, device="cpu"):


def load_estimator(config_file_path, estimator_path, device="cpu"):
"""Loads a trained estimator.
Args:
config_file_path (str): Path to the config file used to train the estimator.
estimator_path (str): Path to the estimator.
device (str, optional): The device to use. Defaults to "cpu".
Returns:
torch.nn.Module: The loaded estimator.
"""

train_config = json.load(open(config_file_path))
estimator = build_models.build_npe_flow_model(train_config)
estimator.load_state_dict(torch.load(estimator_path))
estimator.load_state_dict(
torch.load(estimator_path, map_location=torch.device(device))
)
estimator.to(device)
estimator.eval()

Expand Down
13 changes: 13 additions & 0 deletions src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@


class CryoEmSimulator:
"""Simulator for cryo-EM images.
Args:
config_fname (str): Path to the configuration file.
Attributes:
config (dict): Configuration parameters.
models (np.ndarray): The models to use for image generation.
rot_mode (str): The rotation mode to use. Can be "random", "list" or None.
quaternions (np.ndarray): The quaternions to use for image generation.
add_noise (bool): function which adds noise to images. Defaults to Gaussian noise.
"""

def __init__(self, config_fname):
self._load_params(config_fname)
self._load_models()
Expand Down

0 comments on commit ac20143

Please sign in to comment.