Skip to content

Commit

Permalink
added command line tool for generating model file
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Apr 3, 2024
1 parent f4c35bb commit 4b5a377
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ dependencies = [

[project.scripts]
train_npe_model = "cryo_sbi.inference.command_line_tools:cl_npe_train_no_saving"
model_to_tensor = "cryo_sbi.untils.command_line_tools:cl_model_to_tensor"
28 changes: 28 additions & 0 deletions src/cryo_sbi/utils/command_line_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import argparse
from cryo_sbi.utils.generate_models import models_to_tensor


def cl_models_to_tensor():
cl_parser = argparse.ArgumentParser(
description="Convert models to tensor for cryoSBI"
epilog="pdb-files: The name for the pdbs must contain a {} to be replaced by the index of the pdb file. The index starts at 0. \
For example protein_{}.pdb. trr-files: For .trr files you must provide a topology file.")
cl_parser.add_argument(
"--model_files", action="store", type=str, required=True
)
cl_parser.add_argument(
"--output_file", action="store", type=str, required=True
)
cl_parser.add_argument(
"--n_pdbs", action="store", type=int, required=False, default=None
)
cl_parser.add_argument(
"--top_file", action="store", type=str, required=False, default=None
)
args = cl_parser.parse_args()
models_to_tensor(
model_files=args.model_files,
output_file=args.output_file,
n_pdbs=args.n_pdbs,
top_file=args.top_file
)
41 changes: 41 additions & 0 deletions src/cryo_sbi/utils/generate_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Union
import MDAnalysis as mda
from MDAnalysis.analysis import align
import torch
Expand Down Expand Up @@ -125,3 +126,43 @@ def traj_parser(top_file: str, traj_file: str, output_file: str) -> None:
raise ValueError("Model file format not supported. Please use .pt.")

return


def models_to_tensor(
model_files,
output_file,
n_pdbs: Union[int, None] = None,
top_file: Union[str, None] = None,
):
"""
Converts different model files to a torch tensor.
Parameters
----------
model_files : list
A list of model files to convert to a torch tensor.
output_file : str
The path to the output file. Must be a .pt file.
n_models : int
The number of models to convert to a torch tensor. Just needed for models in pdb files.
top_file : str
The path to the topology file. Just needed for models in trr files.
Returns
-------
None
"""
assert output_file.endswith("pt"), "The output file must be a .pt file."
if model_files.endswith("trr"):
assert top_file is not None, "Please provide a topology file."
assert n_pdbs is None, "The number of pdb files is not needed for trr files."
traj_parser(top_file, model_files, output_file)
elif model_files.endswith("pdb"):
assert n_pdbs is not None, "Please provide the number of pdb files."
assert top_file is None, "The topology file is not needed for pdb files."
pdb_parser(model_files, n_pdbs, output_file)


0 comments on commit 4b5a377

Please sign in to comment.