Skip to content

Commit

Permalink
modified training and model scripts to run with 2d confromational cha…
Browse files Browse the repository at this point in the history
…nges
  • Loading branch information
Dingel321 committed Dec 14, 2024
1 parent d145cf2 commit 9898e73
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 76 deletions.
4 changes: 2 additions & 2 deletions src/cryo_sbi/inference/models/estimator_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Standardize(nn.Module):
"""

# Code adapted from :https://github.com/mackelab/sbi/blob/main/sbi/utils/sbiutils.py
def __init__(self, mean: float, std: float) -> None:
def __init__(self, mean: list[float], std: list[float]) -> None:
super(Standardize, self).__init__()
mean, std = map(torch.as_tensor, (mean, std))
self.mean = mean
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(
super().__init__()

self.npe = NPE(
1,
2,
output_embedding_dim,
transforms=num_transforms,
build=flow,
Expand Down
4 changes: 2 additions & 2 deletions src/cryo_sbi/inference/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def get_image_priors(
)

index_prior = zuko.distributions.BoxUniform(
lower=torch.tensor([0], dtype=torch.float32, device=device),
upper=torch.tensor([max_index], dtype=torch.float32, device=device),
lower=torch.tensor([0, 0], dtype=torch.float32, device=device),
upper=torch.tensor([max_index, max_index], dtype=torch.float32, device=device),
)
quaternion_prior = QuaternionPrior(device)
if (
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_sbi/utils/estimator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def sample_posterior(
device: str = "cpu",
) -> torch.Tensor:
"""
Samples from the posterior distribution
Samples from the 2D posterior distribution
Args:
estimator (torch.nn.Module): The posterior to use for sampling.
Expand Down
68 changes: 0 additions & 68 deletions src/cryo_sbi/utils/generate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,74 +60,6 @@ def pdb_parser(file_formatter, n_pdbs, output_file, start_index=1, **kwargs):
return


def traj_parser_(top_file: str, traj_file: str) -> torch.tensor:
"""
Parses a traj file and returns a coarsed grained atomic model of the protein.
The atomic model is a Mx3xN array, where M is the number of frames in the trajectory,
and N is the number of residues in the protein. The first three rows in axis 1 are the x, y, z coordinates of the alpha carbons.
Parameters
----------
top_file : str
The path to the traj file.
Returns
-------
atomic_model : torch.tensor
The coarse grained atomic model of the protein.
"""

ref = mda.Universe(top_file)
ref.atoms.translate(-ref.atoms.center_of_mass())

mobile = mda.Universe(top_file, traj_file)
align.AlignTraj(mobile, ref, select="name CA", in_memory=True).run()

atomic_models = torch.zeros(
(mobile.trajectory.n_frames, 3, mobile.select_atoms("name CA").n_atoms)
)

for i in range(mobile.trajectory.n_frames):
mobile.trajectory[i]

atomic_models[i, 0:3, :] = torch.from_numpy(
mobile.select_atoms("name CA").positions.T
)

return atomic_models


def traj_parser(top_file: str, traj_file: str, output_file: str) -> None:
"""
Parses a traj file and returns an atomic model of the protein. The atomic model is a Mx5xN array, where M is the number of frames in the trajectory, and N is the number of atoms in the protein. The first three rows in axis 1 are the x, y, z coordinates of the atoms. The fourth row is the atomic number of the atoms. The fifth row is the variance of the atoms before the resolution is applied.
Parameters
----------
top_file : str
The path to the topology file.
traj_file : str
The path to the trajectory file.
output_file : str
The path to the output file. Must be a .pt file.
mode : str
The mode of the atomic model. Either "resid" or "all-atom". Resid mode returns a coarse grained atomic model of the protein. All atom mode returns an all atom atomic model of the protein.
Returns
-------
None
"""

atomic_models = traj_parser_(top_file, traj_file)

if output_file.endswith("pt"):
torch.save(atomic_models, output_file)

else:
raise ValueError("Model file format not supported. Please use .pt.")

return


def models_to_tensor(
model_files,
output_file,
Expand Down
6 changes: 3 additions & 3 deletions src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def cryo_em_simulator(
Returns:
torch.Tensor: A tensor of the simulated cryo-EM image.
"""
models_selected = models[index.round().long().flatten()]
models_selected = models[index[:, 0].round().long(), index[:, 1].round().long()]
image = project_density(
models_selected,
quaternion,
Expand Down Expand Up @@ -116,8 +116,8 @@ def _load_models(self) -> None:

self._models = models

assert self._models.ndim == 3, "Models are not of shape (models, 3, atoms)."
assert self._models.shape[1] == 3, "Models are not of shape (models, 3, atoms)."
assert self._models.ndim == 4, "Models are not of shape (models_dim1, models_dim2, 3, atoms)."
assert self._models.shape[2] == 3, "Models are not of shape (models, 3, atoms)."

@property
def max_index(self) -> int:
Expand Down

0 comments on commit 9898e73

Please sign in to comment.