From 4b5a377946ace9103bfb9bc79fa5f559c8aa5770 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Wed, 3 Apr 2024 11:01:16 +0200 Subject: [PATCH] added command line tool for generating model file --- pyproject.toml | 1 + src/cryo_sbi/utils/command_line_tools.py | 28 ++++++++++++++++ src/cryo_sbi/utils/generate_models.py | 41 ++++++++++++++++++++++++ 3 files changed, 70 insertions(+) create mode 100644 src/cryo_sbi/utils/command_line_tools.py diff --git a/pyproject.toml b/pyproject.toml index bfae890..8da6f78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,3 +28,4 @@ dependencies = [ [project.scripts] train_npe_model = "cryo_sbi.inference.command_line_tools:cl_npe_train_no_saving" +model_to_tensor = "cryo_sbi.untils.command_line_tools:cl_model_to_tensor" diff --git a/src/cryo_sbi/utils/command_line_tools.py b/src/cryo_sbi/utils/command_line_tools.py new file mode 100644 index 0000000..bbb3c06 --- /dev/null +++ b/src/cryo_sbi/utils/command_line_tools.py @@ -0,0 +1,28 @@ +import argparse +from cryo_sbi.utils.generate_models import models_to_tensor + + +def cl_models_to_tensor(): + cl_parser = argparse.ArgumentParser( + description="Convert models to tensor for cryoSBI" + epilog="pdb-files: The name for the pdbs must contain a {} to be replaced by the index of the pdb file. The index starts at 0. \ + For example protein_{}.pdb. trr-files: For .trr files you must provide a topology file.") + cl_parser.add_argument( + "--model_files", action="store", type=str, required=True + ) + cl_parser.add_argument( + "--output_file", action="store", type=str, required=True + ) + cl_parser.add_argument( + "--n_pdbs", action="store", type=int, required=False, default=None + ) + cl_parser.add_argument( + "--top_file", action="store", type=str, required=False, default=None + ) + args = cl_parser.parse_args() + models_to_tensor( + model_files=args.model_files, + output_file=args.output_file, + n_pdbs=args.n_pdbs, + top_file=args.top_file + ) \ No newline at end of file diff --git a/src/cryo_sbi/utils/generate_models.py b/src/cryo_sbi/utils/generate_models.py index 221e084..1067053 100644 --- a/src/cryo_sbi/utils/generate_models.py +++ b/src/cryo_sbi/utils/generate_models.py @@ -1,3 +1,4 @@ +from typing import Union import MDAnalysis as mda from MDAnalysis.analysis import align import torch @@ -125,3 +126,43 @@ def traj_parser(top_file: str, traj_file: str, output_file: str) -> None: raise ValueError("Model file format not supported. Please use .pt.") return + + +def models_to_tensor( + model_files, + output_file, + n_pdbs: Union[int, None] = None, + top_file: Union[str, None] = None, + ): + """ + Converts different model files to a torch tensor. + + Parameters + ---------- + model_files : list + A list of model files to convert to a torch tensor. + + output_file : str + The path to the output file. Must be a .pt file. + + n_models : int + The number of models to convert to a torch tensor. Just needed for models in pdb files. + + top_file : str + The path to the topology file. Just needed for models in trr files. + + Returns + ------- + None + """ + assert output_file.endswith("pt"), "The output file must be a .pt file." + if model_files.endswith("trr"): + assert top_file is not None, "Please provide a topology file." + assert n_pdbs is None, "The number of pdb files is not needed for trr files." + traj_parser(top_file, model_files, output_file) + elif model_files.endswith("pdb"): + assert n_pdbs is not None, "Please provide the number of pdb files." + assert top_file is None, "The topology file is not needed for pdb files." + pdb_parser(model_files, n_pdbs, output_file) + +