From 9898e7364d57e4f76b5f067546e894c812eeb58a Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Sat, 14 Dec 2024 20:35:03 +0100 Subject: [PATCH] modified training and model scripts to run with 2d conformational changes --- .../inference/models/estimator_models.py | 4 +- src/cryo_sbi/inference/priors.py | 4 +- src/cryo_sbi/utils/estimator_utils.py | 2 +- src/cryo_sbi/utils/generate_models.py | 68 ------------------- .../wpa_simulator/cryo_em_simulator.py | 6 +- 5 files changed, 8 insertions(+), 76 deletions(-) diff --git a/src/cryo_sbi/inference/models/estimator_models.py b/src/cryo_sbi/inference/models/estimator_models.py index a1ff0fe..d1405dc 100644 --- a/src/cryo_sbi/inference/models/estimator_models.py +++ b/src/cryo_sbi/inference/models/estimator_models.py @@ -17,7 +17,7 @@ class Standardize(nn.Module): """ # Code adapted from :https://github.com/mackelab/sbi/blob/main/sbi/utils/sbiutils.py - def __init__(self, mean: float, std: float) -> None: + def __init__(self, mean: list[float], std: list[float]) -> None: super(Standardize, self).__init__() mean, std = map(torch.as_tensor, (mean, std)) self.mean = mean @@ -94,7 +94,7 @@ def __init__( super().__init__() self.npe = NPE( - 1, + 2, output_embedding_dim, transforms=num_transforms, build=flow, diff --git a/src/cryo_sbi/inference/priors.py b/src/cryo_sbi/inference/priors.py index f8fbb29..d1e195a 100644 --- a/src/cryo_sbi/inference/priors.py +++ b/src/cryo_sbi/inference/priors.py @@ -111,8 +111,8 @@ def get_image_priors( ) index_prior = zuko.distributions.BoxUniform( - lower=torch.tensor([0], dtype=torch.float32, device=device), - upper=torch.tensor([max_index], dtype=torch.float32, device=device), + lower=torch.tensor([0, 0], dtype=torch.float32, device=device), + upper=torch.tensor([max_index, max_index], dtype=torch.float32, device=device), ) quaternion_prior = QuaternionPrior(device) if ( diff --git a/src/cryo_sbi/utils/estimator_utils.py b/src/cryo_sbi/utils/estimator_utils.py index 44523be..ba9adbb 100644 --- 
a/src/cryo_sbi/utils/estimator_utils.py +++ b/src/cryo_sbi/utils/estimator_utils.py @@ -62,7 +62,7 @@ def sample_posterior( device: str = "cpu", ) -> torch.Tensor: """ - Samples from the posterior distribution + Samples from the 2D posterior distribution Args: estimator (torch.nn.Module): The posterior to use for sampling. diff --git a/src/cryo_sbi/utils/generate_models.py b/src/cryo_sbi/utils/generate_models.py index 0f86f69..f2c7406 100644 --- a/src/cryo_sbi/utils/generate_models.py +++ b/src/cryo_sbi/utils/generate_models.py @@ -60,74 +60,6 @@ def pdb_parser(file_formatter, n_pdbs, output_file, start_index=1, **kwargs): return -def traj_parser_(top_file: str, traj_file: str) -> torch.tensor: - """ - Parses a traj file and returns a coarsed grained atomic model of the protein. - The atomic model is a Mx3xN array, where M is the number of frames in the trajectory, - and N is the number of residues in the protein. The first three rows in axis 1 are the x, y, z coordinates of the alpha carbons. - - Parameters - ---------- - top_file : str - The path to the traj file. - - Returns - ------- - atomic_model : torch.tensor - The coarse grained atomic model of the protein. - """ - - ref = mda.Universe(top_file) - ref.atoms.translate(-ref.atoms.center_of_mass()) - - mobile = mda.Universe(top_file, traj_file) - align.AlignTraj(mobile, ref, select="name CA", in_memory=True).run() - - atomic_models = torch.zeros( - (mobile.trajectory.n_frames, 3, mobile.select_atoms("name CA").n_atoms) - ) - - for i in range(mobile.trajectory.n_frames): - mobile.trajectory[i] - - atomic_models[i, 0:3, :] = torch.from_numpy( - mobile.select_atoms("name CA").positions.T - ) - - return atomic_models - - -def traj_parser(top_file: str, traj_file: str, output_file: str) -> None: - """ - Parses a traj file and returns an atomic model of the protein. The atomic model is a Mx5xN array, where M is the number of frames in the trajectory, and N is the number of atoms in the protein. 
The first three rows in axis 1 are the x, y, z coordinates of the atoms. The fourth row is the atomic number of the atoms. The fifth row is the variance of the atoms before the resolution is applied. - - Parameters - ---------- - top_file : str - The path to the topology file. - traj_file : str - The path to the trajectory file. - output_file : str - The path to the output file. Must be a .pt file. - mode : str - The mode of the atomic model. Either "resid" or "all-atom". Resid mode returns a coarse grained atomic model of the protein. All atom mode returns an all atom atomic model of the protein. - - Returns - ------- - None - """ - - atomic_models = traj_parser_(top_file, traj_file) - - if output_file.endswith("pt"): - torch.save(atomic_models, output_file) - - else: - raise ValueError("Model file format not supported. Please use .pt.") - - return - - def models_to_tensor( model_files, output_file, diff --git a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py index efcc7c9..35ee642 100644 --- a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py +++ b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py @@ -43,7 +43,7 @@ def cryo_em_simulator( Returns: torch.Tensor: A tensor of the simulated cryo-EM image. """ - models_selected = models[index.round().long().flatten()] + models_selected = models[index[:, 0].round().long(), index[:, 1].round().long()] image = project_density( models_selected, quaternion, @@ -116,8 +116,8 @@ def _load_models(self) -> None: self._models = models - assert self._models.ndim == 3, "Models are not of shape (models, 3, atoms)." - assert self._models.shape[1] == 3, "Models are not of shape (models, 3, atoms)." + assert self._models.ndim == 4, "Models are not of shape (models_dim1, models_dim2, 3, atoms)." + assert self._models.shape[2] == 3, "Models are not of shape (models, 3, atoms)." @property def max_index(self) -> int: