Skip to content

Commit

Permalink
add atom selection argument to pdb parser
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jul 31, 2024
1 parent a60c974 commit 0c31d76
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/cryo_sbi/utils/generate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch


def pdb_parser_(fname: str) -> torch.tensor:
def pdb_parser_(fname: str, atom_selection: str = "name CA") -> torch.tensor:
"""
Parses a pdb file and returns a coarse-grained atomic model of the protein.
The atomic model is a 5xN array, where N is the number of residues in the protein.
Expand All @@ -24,12 +24,12 @@ def pdb_parser_(fname: str) -> torch.tensor:
univ = mda.Universe(fname)
univ.atoms.translate(-univ.atoms.center_of_mass())

model = torch.from_numpy(univ.select_atoms("name CA").positions.T)
model = torch.from_numpy(univ.select_atoms(atom_selection).positions.T)

return model


def pdb_parser(file_formatter, n_pdbs, output_file, start_index=1):
def pdb_parser(file_formatter, n_pdbs, output_file, start_index=1, **kwargs):
"""
Parses multiple pdb files and returns a coarse-grained model of the protein. The atomic model is a 5xN array, where N is the number of atoms or residues in the protein. The first three rows are the x, y, z coordinates of the atoms or residues. The fourth row is the atomic number of the atoms or the density of the residues. The fifth row is the variance of the atoms or residues, which is the resolution of the cryo-EM map divided by pi squared.
Expand All @@ -45,11 +45,11 @@ def pdb_parser(file_formatter, n_pdbs, output_file, start_index=1):
The mode of the atomic model. Either "resid" or "all atom". Resid mode returns a coarse grained atomic model of the protein. All atom mode returns an all atom atomic model of the protein.
"""

models = pdb_parser_(file_formatter.format(start_index))
models = pdb_parser_(file_formatter.format(start_index), **kwargs)
models = torch.zeros((n_pdbs, *models.shape))

for i in range(0, n_pdbs):
models[i] = pdb_parser_(file_formatter.format(start_index + i))
models[i] = pdb_parser_(file_formatter.format(start_index + i), **kwargs)

if output_file.endswith("pt"):
torch.save(models, output_file)
Expand Down

0 comments on commit 0c31d76

Please sign in to comment.