Skip to content

Commit

Permalink
First draft of a ProcessorPipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
pckroon committed Oct 17, 2023
1 parent 0459973 commit ea1e36f
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 87 deletions.
170 changes: 83 additions & 87 deletions bin/martinize2
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import vermouth
import vermouth.forcefield
from vermouth.file_writer import deferred_open, DeferredFileWriter
from vermouth import DATA_PATH
from vermouth.processors.processor import ProcessorPipeline
from vermouth.dssp import dssp
from vermouth.dssp.dssp import (
AnnotateDSSP,
Expand Down Expand Up @@ -81,30 +82,8 @@ LOGGER = StyleAdapter(LOGGER)
VERSION = "martinize with vermouth {}".format(vermouth.__version__)


def read_system(path, ignore_resnames=(), ignh=None, modelidx=None):
    """
    Load a structure file into a new system.

    The parser is selected from the extension of `path`: "PDB" and "ENT"
    files go through :class:`vermouth.PDBInput`, "GRO" files through
    :class:`vermouth.GROInput`. The comparison is case insensitive.
    The returned system does not have a force field and may not have edges.

    Parameters
    ----------
    path: pathlib.Path
        Location of the structure file; only its suffix is inspected here.
    ignore_resnames
        Passed to the parser as `exclude`.
    ignh
        Passed to the parser as `ignh`.
    modelidx
        Passed to the PDB parser as `modelidx`; not used for GRO files.

    Returns
    -------
    vermouth.System
        The system the selected parser was run on.

    Raises
    ------
    ValueError
        If the extension is none of "PDB", "ENT" or "GRO".
    """
    system = vermouth.System()
    # Drop the leading dot and normalise the case before comparing.
    extension = path.suffix[1:].upper()
    if extension in ("PDB", "ENT"):
        reader = vermouth.PDBInput(
            str(path), exclude=ignore_resnames, ignh=ignh, modelidx=modelidx
        )
    elif extension == "GRO":
        reader = vermouth.GROInput(str(path), exclude=ignore_resnames, ignh=ignh)
    else:
        raise ValueError('Unknown file extension "{}".'.format(extension))
    reader.run_system(system)
    return system


def pdb_to_universal(
pipeline,
system,
delete_unknown=False,
force_field=None,
Expand All @@ -130,62 +109,64 @@ def pdb_to_universal(
canonicalized.force_field = force_field

LOGGER.info("Guessing the bonds.", type="step")
vermouth.MakeBonds(
allow_name=bonds_from_name, allow_dist=bonds_from_dist, fudge=bonds_fudge
).run_system(canonicalized)
vermouth.MergeNucleicStrands().run_system(canonicalized)
pipeline.add(
vermouth.MakeBonds(
allow_name=bonds_from_name, allow_dist=bonds_from_dist, fudge=bonds_fudge
)
)
pipeline.add(vermouth.MergeNucleicStrands())
if write_graph is not None:
vermouth.pdb.write_pdb(
canonicalized, str(write_graph), omit_charges=True, defer_writing=False
canonicalized, write_graph, omit_charges=True, defer_writing=False
)

LOGGER.debug("Annotating required mutations and modifications.", type="step")
vermouth.AnnotateMutMod(modifications, mutations).run_system(canonicalized)
pipeline.add(vermouth.AnnotateMutMod(modifications, mutations))
LOGGER.info("Repairing the graph.", type="step")
vermouth.RepairGraph(delete_unknown=delete_unknown, include_graph=False).run_system(
canonicalized
)
pipeline.add(vermouth.RepairGraph(delete_unknown=delete_unknown, include_graph=False))
if write_repair is not None:
vermouth.pdb.write_pdb(
canonicalized,
str(write_repair),
write_repair,
omit_charges=True,
nan_missing_pos=True,
defer_writing=False,
)
LOGGER.info("Dealing with modifications.", type="step")
vermouth.CanonicalizeModifications().run_system(canonicalized)
pipeline.add(vermouth.CanonicalizeModifications())
if write_canon is not None:
vermouth.pdb.write_pdb(
canonicalized,
str(write_canon),
write_canon,
omit_charges=True,
nan_missing_pos=True,
defer_writing=False,
)
vermouth.AttachMass(attribute="mass").run_system(canonicalized)
pipeline.add(vermouth.AttachMass(attribute="mass"))
return canonicalized


def martinize(system, mappings, to_ff, delete_unknown=False):
def martinize(pipeline, system, mappings, to_ff, delete_unknown=False):
"""
Convert a system from one force field to another at lower resolution.
"""
LOGGER.info("Creating the graph at the target resolution.", type="step")
vermouth.DoMapping(
mappings=mappings,
to_ff=to_ff,
delete_unknown=delete_unknown,
attribute_keep=("cgsecstruct", "chain"),
attribute_must=("resname",),
attribute_stash=("resid",),
).run_system(system)
pipeline.add(
vermouth.DoMapping(
mappings=mappings,
to_ff=to_ff,
delete_unknown=delete_unknown,
attribute_keep=("cgsecstruct", "chain"),
attribute_must=("resname",),
attribute_stash=("resid",),
)
)
LOGGER.info("Averaging the coordinates.", type="step")
vermouth.DoAverageBead(ignore_missing_graphs=True).run_system(system)
pipeline.add(vermouth.DoAverageBead(ignore_missing_graphs=True))
LOGGER.info("Applying the links.", type="step")
vermouth.DoLinks().run_system(system)
pipeline.add(vermouth.DoLinks())
LOGGER.info("Placing the charge dummies.", type="step")
vermouth.LocateChargeDummies().run_system(system)
pipeline.add(vermouth.LocateChargeDummies())
return system


Expand Down Expand Up @@ -861,17 +842,20 @@ def entry():
if not any("nter" in resspec for resspec in resspecs):
args.modifications.append(["nter", "N-ter"])

# Reading the input structure.
# So far, we assume we only go from atomistic to martini. We want the
# input structure to be a clean universal system.
# For now at least, we silently delete molecules with unknown blocks.
system = read_system(
args.inpath,
ignore_resnames=ignore_res,
ignh=args.ignore_h,
modelidx=args.modelidx,
)
system = vermouth.System()
pipeline = ProcessorPipeline(name='martinize2')
to_universal = ProcessorPipeline(name='to universal')

file_extension = args.inpath.suffix.upper()[1:] # We do not keep the dot
if file_extension in ["PDB", "ENT"]:
to_universal.add(vermouth.PDBInput(args.inpath, exclude=ignore_res, ignh=args.ignore_h, modelidx=args.modelidx))
elif file_extension in ["GRO"]:
to_universal.add(vermouth.GROInput(args.inpath, exclude=ignore_res, ignh=args.ignore_h))
else:
raise ValueError('Unknown file extension "{}".'.format(file_extension))

system = pdb_to_universal(
to_universal,
system,
delete_unknown=True,
force_field=known_force_fields[from_ff],
Expand All @@ -884,22 +868,27 @@ def entry():
write_repair=args.write_repair,
write_canon=args.write_canon,
)
pipeline.add(to_universal)

LOGGER.info("Read input.", type="step")
for molecule in system.molecules:
LOGGER.debug("Read molecule {}.", molecule, type="step")

martinize_pipeline = ProcessorPipeline(name='martinize')

target_ff = known_force_fields[args.to_ff]
if args.dssp is not None:
AnnotateDSSP(executable=args.dssp, savedir=".").run_system(system)
AnnotateMartiniSecondaryStructures().run_system(system)
martinize_pipeline.add(AnnotateDSSP(executable=args.dssp, savedir="."))
martinize_pipeline.add(AnnotateMartiniSecondaryStructures())
elif args.ss is not None:
AnnotateResidues(
attribute="secstruct",
sequence=args.ss,
molecule_selector=selectors.is_protein,
).run_system(system)
AnnotateMartiniSecondaryStructures().run_system(system)
martinize_pipeline.add(
AnnotateResidues(
attribute="secstruct",
sequence=args.ss,
molecule_selector=selectors.is_protein,
)
)
martinize_pipeline.add(AnnotateMartiniSecondaryStructures())
elif args.collagen:
if not target_ff.has_feature("collagen"):
LOGGER.warning(
Expand All @@ -908,27 +897,29 @@ def entry():
target_ff.name,
type="missing-feature",
)
AnnotateResidues(
attribute="cgsecstruct",
sequence="F",
molecule_selector=selectors.is_protein,
).run_system(system)
martinize_pipeline.add(
AnnotateResidues(
attribute="cgsecstruct",
sequence="F",
molecule_selector=selectors.is_protein,
)
)
if args.extdih and not target_ff.has_feature("extdih"):
LOGGER.warning(
'The force field "{}" does not define dihedral '
"angles for extended regions of proteins (-extdih).",
target_ff.name,
type="missing-feature",
)
vermouth.SetMoleculeMeta(extdih=args.extdih).run_system(system)
martinize_pipeline.add(vermouth.SetMoleculeMeta(extdih=args.extdih))
if args.scfix and not target_ff.has_feature("scfix"):
LOGGER.warning(
'The force field "{}" does not define angle and '
"torsion for the side chain corrections (-scfix).",
target_ff.name,
type="missing-feature",
)
vermouth.SetMoleculeMeta(scfix=args.scfix).run_system(system)
martinize_pipeline.add(vermouth.SetMoleculeMeta(scfix=args.scfix))

ss_sequence = list(
itertools.chain(
Expand All @@ -941,18 +932,20 @@ def entry():
)

if args.cystein_bridge == "none":
vermouth.RemoveCysteinBridgeEdges().run_system(system)
martinize_pipeline.add(vermouth.RemoveCysteinBridgeEdges())
elif args.cystein_bridge != "auto":
vermouth.AddCysteinBridgesThreshold(args.cystein_bridge).run_system(system)
martinize_pipeline.add(vermouth.AddCysteinBridgesThreshold(args.cystein_bridge))

# Run martinize on the system.
system = martinize(
martinize_pipeline,
system,
mappings=known_mappings,
to_ff=known_force_fields[args.to_ff],
delete_unknown=True,
)

pipeline.add(martinize_pipeline)
postprocessing = ProcessorPipeline(name='post-processing')
# Apply position restraints if required.
if args.posres != "none":
LOGGER.info("Applying position restraints.", type="step")
Expand All @@ -961,7 +954,7 @@ def entry():
"backbone": selectors.select_backbone,
}
node_selector = node_selectors[args.posres]
vermouth.ApplyPosres(node_selector, args.posres_fc).run_system(system)
postprocessing.add(vermouth.ApplyPosres(node_selector, args.posres_fc))

if args.govs_includes:
# The way Virtual Site GoMartini works has to be in sync with
Expand All @@ -976,16 +969,16 @@ def entry():
LOGGER.info(
"The output topology will require files generated by " '"create_goVirt.py".'
)
vermouth.MergeAllMolecules().run_system(system)
vermouth.SetMoleculeMeta(moltype=args.govs_moltype).run_system(system)
vermouth.GoVirtIncludes().run_system(system)
postprocessing.add(vermouth.MergeAllMolecules())
postprocessing.add(vermouth.SetMoleculeMeta(moltype=args.govs_moltype))
postprocessing.add(vermouth.GoVirtIncludes())
defines = ("GO_VIRT",)
else:
# Merge chains if required.
if args.merge_chains:
for chain_set in args.merge_chains:
vermouth.MergeChains(chain_set).run_system(system)
vermouth.NameMolType(deduplicate=not args.keep_duplicate_itp).run_system(system)
postprocessing.add(vermouth.MergeChains(chain_set))
postprocessing.add(vermouth.NameMolType(deduplicate=not args.keep_duplicate_itp))
defines = ()

# Apply a rubber band elastic network if required.
Expand All @@ -994,7 +987,7 @@ def entry():
if args.rb_unit == "molecule":
domain_criterion = vermouth.processors.apply_rubber_band.always_true
elif args.rb_unit == "all":
vermouth.MergeAllMolecules().run_system(system)
postprocessing.add(vermouth.MergeAllMolecules())
domain_criterion = vermouth.processors.apply_rubber_band.always_true
elif args.rb_unit == "chain":
domain_criterion = vermouth.processors.apply_rubber_band.same_chain
Expand Down Expand Up @@ -1037,7 +1030,7 @@ def entry():
domain_criterion=domain_criterion,
res_min_dist=args.res_min_dist,
)
rubber_band_processor.run_system(system)
postprocessing.add(rubber_band_processor)

# Here we need to add the resids from the PDB back if that is needed
if args.resid_handling == "input":
Expand All @@ -1050,7 +1043,7 @@ def entry():
# model, thus we skip the sorting here altogether.
if not args.govs_includes:
LOGGER.info("Sorting atomids", type="step")
vermouth.SortMoleculeAtoms().run_system(system)
postprocessing.add(vermouth.SortMoleculeAtoms())

LOGGER.info("Writing output.", type="step")
for molecule in system.molecules:
Expand Down Expand Up @@ -1081,12 +1074,15 @@ def entry():
"was used for the full system:",
"".join(ss_sequence),
]
pipeline.add(postprocessing)
print(pipeline)
pipeline.run_system(system)

if args.top_path is not None:
write_gmx_topology(system, args.top_path, defines=defines, header=header)

# Write a PDB file.
vermouth.pdb.write_pdb(system, str(args.outpath), omit_charges=True)
vermouth.pdb.write_pdb(system, args.outpath, omit_charges=True)

# TODO: allow ignoring warnings per class/amount (i.e. ignore 2
# inconsistent-data warnings)
Expand All @@ -1103,7 +1099,7 @@ def entry():
sys.exit(2)
else:
DeferredFileWriter().write()
vermouth.Quoter().run_system(system)
vermouth.Quoter().run_system(None)


if __name__ == "__main__":
Expand Down
34 changes: 34 additions & 0 deletions vermouth/processors/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,20 @@
"""
Provides an abstract base class for processors.
"""
from ..log_helpers import StyleAdapter, get_logger

import networkx as nx

LOGGER = StyleAdapter(get_logger(__name__))

class Processor:
"""
An abstract base class for processors. Subclasses must implement a
`run_molecule` method.
"""
def __str__(self):
return self.__class__.__name__

def run_system(self, system):
"""
Process `system`.
Expand Down Expand Up @@ -52,3 +59,30 @@ def run_molecule(self, molecule):
Either the provided molecule, or a brand new one.
"""
raise NotImplementedError


class ProcessorPipeline(nx.DiGraph, Processor):
    """
    A processor that runs a collection of other processors one after another.

    The pipeline is a directed graph. Each node stores one processor under
    its 'processor' attribute, and :meth:`add` links every node already
    present to the newly added one. Since a pipeline is itself a
    :class:`Processor`, a pipeline can be added to another pipeline.

    Parameters
    ----------
    name: str
        Label used by :meth:`__str__`, and by an enclosing pipeline when it
        logs that this pipeline is being run. An empty name falls back to
        the class name.
    """
    def __init__(self, /, name=''):
        super().__init__()
        self.name = name or self.__class__.__name__

    @property
    def processors(self):
        """
        Yield the stored processors following a topological sort of the
        graph. When every node was added through :meth:`add`, this is the
        order in which the processors were added.
        """
        for idx in nx.topological_sort(self):
            yield self.nodes[idx]['processor']

    def add(self, processor):
        """
        Store `processor` in a new node and make it a successor of every
        node added before it.

        Parameters
        ----------
        processor: Processor
            The processor to store. Anything with a `run_system` method can
            be run by :meth:`run_system`.
        """
        # The node count before insertion doubles as the key of the new node.
        new_idx = len(self.nodes)
        self.add_node(new_idx, processor=processor)
        self.add_edges_from((idx, new_idx) for idx in range(new_idx))

    def run_system(self, system):
        """
        Call `run_system` of every stored processor on `system`, in the
        order given by :attr:`processors`.

        Parameters
        ----------
        system
            The object handed, unchanged, to each processor in turn.
        """
        for processor in self.processors:
            name = getattr(processor, 'name', None) or processor.__class__.__name__
            # Hand the name over as a format argument rather than baking it
            # into the message: the logger is a StyleAdapter, which is used
            # with brace-style templates, so braces inside a name must not
            # end up in the template itself.
            LOGGER.info("Running {}", name, type='step')
            processor.run_system(system)

    def __str__(self):
        members = ', '.join(str(processor) for processor in self.processors)
        return "{name}[{members}]".format(name=self.name, members=members)

0 comments on commit ea1e36f

Please sign in to comment.