Skip to content

Commit

Permalink
Merge pull request #175 from StollLab/forcefield
Browse files Browse the repository at this point in the history
change how energy funcs work to improve versatility
  • Loading branch information
mtessmer authored Dec 26, 2024
2 parents 7441ffa + 0420197 commit bb1a859
Show file tree
Hide file tree
Showing 16 changed files with 296 additions and 453 deletions.

Large diffs are not rendered by default.

25 changes: 7 additions & 18 deletions src/chilife/RotamerEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .MolSys import MolSys, MolecularSystemBase
from .MolSysIC import MolSysIC

default_energy_func = scoring.ljEnergyFunc()

class RotamerEnsemble:
"""Create new RotamerEnsemble object.
Expand Down Expand Up @@ -160,20 +161,12 @@ def __init__(self, res, site=None, protein=None, chain=None, rotlib=None, **kwar
if self.clash_radius is None:
self.clash_radius = np.linalg.norm(self.clash_ori - self.coords, axis=-1).max() + 5

if isinstance(self.forcefield, str):
self.forcefield = scoring.ForceField(self.forcefield)

elif not isinstance(self.forcefield, scoring.ForceField):
raise RuntimeError('The kwarg `forcefield` must be a string or ForceField object.')

# Parse important indices
self.aln_idx = np.squeeze(np.argwhere(np.isin(self.atom_names, self.aln_atoms)))
self.backbone_idx = np.squeeze(np.argwhere(np.isin(self.atom_names, self.backbone_atoms)))
self.side_chain_idx = np.argwhere(np.isin(self.atom_names, self.backbone_atoms, invert=True)).flatten()

self._graph = ig.Graph(edges=self.bonds)

_, self.irmin_ij, self.ieps_ij, _ = scoring.prep_internal_clash(self)
self.aidx, self.bidx = [list(x) for x in zip(*self.non_bonded)]

# Allocate variables for clash evaluations
Expand All @@ -188,11 +181,6 @@ def __init__(self, res, site=None, protein=None, chain=None, rotlib=None, **kwar
if self.chain not in ('A', None):
self.name += f"_{self.chain}"

# Create arrays of LJ potential params
if len(self.side_chain_idx) > 0:
self.rmin2 = self.forcefield.get_lj_rmin(self.atom_types[self.side_chain_idx])
self.eps = self.forcefield.get_lj_eps(self.atom_types[self.side_chain_idx])

self.update(no_lib=True)

# Store atom information as atom objects
Expand Down Expand Up @@ -450,6 +438,9 @@ def protein_setup(self):
protein_clash_idx = self.protein_tree.query_ball_point(self.clash_ori, self.clash_radius)
self.protein_clash_idx = [idx for idx in protein_clash_idx if idx not in self.clash_ignore_idx]

if hasattr(self.energy_func, 'prepare_system'):
self.energy_func.prepare_system(self)

if self._coords.shape[1] == len(self.clash_ignore_idx):
RMSDs = np.linalg.norm(
self._coords - self.protein.atoms[self.clash_ignore_idx].positions[None, :, :],
Expand Down Expand Up @@ -1320,10 +1311,10 @@ def intra_fit(self):
def get_sasa(self):
"""Calculate the solvent accessible surface area (SASA) of each rotamer in the protein environment."""

atom_radii = self.forcefield.get_lj_rmin(self.atom_types)
atom_radii = self.energy_func.get_lj_rmin(self.atom_types)
if self.protein is not None:
environment_coords = self.protein.atoms[self.protein_clash_idx].positions
environment_radii = self.forcefield.get_lj_rmin(self.protein.atoms[self.protein_clash_idx].types)
environment_radii = self.energy_func.get_lj_rmin(self.protein.atoms[self.protein_clash_idx].types)
else:
environment_coords = np.empty((0, 3))
environment_radii = np.empty(0)
Expand Down Expand Up @@ -1380,20 +1371,18 @@ def assign_defaults(kwargs):
# Default parameters
defaults = {
"protein_tree": None,
"forgive": 1.0,
"temp": 298,
"clash_radius": None,
"_clash_ori_inp": kwargs.pop("clash_ori", "cen"),
"alignment_method": "bisect",
"dihedral_sigmas": 35,
"weighted_sampling": False,
"forcefield": 'charmm',
"eval_clash": True if not kwargs.get('minimize', False) else False,
"use_H": False,
'_match_backbone': True,
"_exclude_nb_interactions": kwargs.pop('exclude_nb_interactions', 3),
"_sample_size": kwargs.pop("sample", False),
"energy_func": get_lj_rep,
"energy_func": default_energy_func,
"_minimize": kwargs.pop('minimize', False),
"min_method": 'L-BFGS-B',
"_do_trim": kwargs.pop('trim', True),
Expand Down
13 changes: 2 additions & 11 deletions src/chilife/SpinLabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from scipy.spatial.distance import cdist

from .scoring import ForceField, get_lj_energy
from .scoring import ljEnergyFunc, get_lj_energy
from .RotamerEnsemble import RotamerEnsemble


Expand Down Expand Up @@ -73,19 +73,12 @@ def from_mmm(cls, label, site=None, protein=None, chain=None, **kwargs):
"I1M": 12.952083029729994 + 4,
}

# Set forcefield
ff = kwargs.pop('forcefield', 'uff')
if isinstance(ff, str):
ff = ForceField(ff)

clash_radius = kwargs.pop("clash_radius", MMM_maxdist.get(label, None))
alignment_method = kwargs.pop("alignment_method", "mmm")
clash_ori = kwargs.pop("clash_ori", "CA")
energy_func = kwargs.pop(
"energy_func", partial(get_lj_energy, cap=np.inf)
)
energy_func = kwargs.pop('energy_func', ljEnergyFunc(get_lj_energy, 'uff', forgive=0.5, cap=np.inf))
use_H = kwargs.pop("use_H", True)
forgive = kwargs.pop("forgive", 0.5)

# Calculate the SpinLabel
SL = SpinLabel(
Expand All @@ -98,8 +91,6 @@ def from_mmm(cls, label, site=None, protein=None, chain=None, **kwargs):
clash_ori=clash_ori,
energy_func=energy_func,
use_H=use_H,
forgive=forgive,
forcefield = ff,
**kwargs,
)

Expand Down
25 changes: 18 additions & 7 deletions src/chilife/chilife.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .pdb_utils import sort_pdb, get_backbone_atoms, get_bb_candidates
from .protein_utils import mutate, guess_mobile_dihedrals
from .MolSysIC import MolSysIC
from .scoring import get_lj_rep, GAS_CONST, reweight_rotamers
from .scoring import GAS_CONST, reweight_rotamers, ljEnergyFunc
from .numba_utils import get_delta_r, normdist
from .SpinLabel import SpinLabel
from .RotamerEnsemble import RotamerEnsemble
Expand Down Expand Up @@ -165,14 +165,22 @@ def pair_dd(*args, r: ArrayLike, sigma: float = 1.0, use_spin_centers: bool = Tr
weights.append(np.outer(weights1, weights2).flatten())

if dependent:
if SL1.forcefield != SL2.forcefield:
raise RuntimeError('At least two labels passed use different forcefield parameters. Make sure all '
'labels use the same forcefields when setting `dependent=True`')
if not isinstance(SL1.energy_func, ljEnergyFunc):
raise RuntimeError('Currently only ljEnergyFunc objects are supported when using dependent=True')

if SL1.energy_func.join_rmin is not SL2.energy_func.join_rmin or \
SL1.energy_func.join_eps is not SL2.energy_func.join_eps:
raise RuntimeError('At least two labels passed use different energy functions. Make sure all '
'labels use the same energy functions when setting `dependent=True`. This does not'
'mean that the energy functions use the same parameters. They have to be the SAME '
'object and satisfy `SL1.energy_func is SL2.energy_func`. This can be achieved be '
'creating an energy function object')


nrot1, nrot2 = len(SL1), len(SL2)
nat1, nat2 = len(SL1.side_chain_idx), len(SL2.side_chain_idx)
join_rmin = SL1.forcefield.get_lj_rmin("join_protocol")[()]
join_eps = SL1.forcefield.get_lj_eps("join_protocol")[()]
join_rmin = SL1.energy_func.join_rmin
join_eps = SL1.energy_func.join_eps

rmin_ij = join_rmin(SL1.rmin2, SL2.rmin2)
eps_ij = join_eps(SL1.eps, SL2.eps)
Expand Down Expand Up @@ -1431,7 +1439,7 @@ def repack(
*spin_labels: RotamerEnsemble,
repetitions: int = 200,
temp: float = 1,
energy_func: Callable = get_lj_rep,
energy_func: Callable = None,
off_rotamer=False,
**kwargs,
) -> Tuple[mda.Universe, ArrayLike]:
Expand Down Expand Up @@ -1467,6 +1475,9 @@ def repack(
temp = np.atleast_1d(temp)
KT = {t: GAS_CONST * t for t in temp}

energy_func = ljEnergyFunc() if energy_func is None else energy_func


repack_radius = kwargs.pop("repack_radius") if "repack_radius" in kwargs else None # Angstroms
if repack_radius is None:
repack_radius = max([SL.clash_radius for SL in spin_labels])
Expand Down
20 changes: 6 additions & 14 deletions src/chilife/dRotamerEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import scipy.optimize as opt
import MDAnalysis as mda

import chilife.scoring as scoring
import chilife.io as io
import chilife.RotamerEnsemble as re
import chilife.scoring as scoring

from chilife.protein_utils import make_mda_uni
from .MolSysIC import MolSysIC

Expand Down Expand Up @@ -135,9 +136,6 @@ def __init__(self, res, sites, protein=None, chain=None, rotlib=None, **kwargs):
self.input_kwargs = kwargs
self.__dict__.update(dassign_defaults(kwargs))

if isinstance(self.forcefield, str):
self.forcefield = scoring.ForceField(self.forcefield)

self.get_lib(rotlib)
self.create_ensembles()

Expand Down Expand Up @@ -173,9 +171,6 @@ def __init__(self, res, sites, protein=None, chain=None, rotlib=None, **kwargs):
if self.clash_radius is None:
self.clash_radius = np.linalg.norm(self.clash_ori - self.coords, axis=-1).max() + 5

self.rmin2 = self.forcefield.get_lj_rmin(self.atom_types[self.side_chain_idx])
self.eps = self.forcefield.get_lj_eps(self.atom_types[self.side_chain_idx])

self.protein_setup()
self.sub_labels = (self.RE1, self.RE2)

Expand All @@ -190,7 +185,6 @@ def weights(self, value):
self.RE2.weights = value

@property

def coords(self):
"""The 3D cartesian coordinates of each atom of each rotamer in the library."""
ovlp = (self.RE1.coords[:, self.cst_idx1] + self.RE2.coords[:, self.cst_idx2]) / 2
Expand Down Expand Up @@ -377,10 +371,9 @@ def protein_setup(self):
idx for idx in protein_clash_idx if idx not in self.clash_ignore_idx
]

_, self.irmin_ij, self.ieps_ij, _ = scoring.prep_internal_clash(self)
_, self.ermin_ij, self.eeps_ij = scoring.prep_external_clash(self)

self.aidx, self.bidx = [list(x) for x in zip(*self.non_bonded)]
if hasattr(self.energy_func, 'prepare_system'):
self.energy_func.prepare_system(self)

if self._minimize:
self.minimize()
Expand Down Expand Up @@ -753,8 +746,6 @@ def dassign_defaults(kwargs):
# Default parameters
defaults = {
"eval_clash": True,
"forcefield": 'charmm',
"forgive": 0.95,
"temp": 298,
"clash_radius": None,
"protein_tree": None,
Expand All @@ -766,7 +757,8 @@ def dassign_defaults(kwargs):
"use_H": False,
"_exclude_nb_interactions": kwargs.pop('exclude_nb_interactions', 3),

"energy_func": scoring.get_lj_energy,
"energy_func": scoring.ljEnergyFunc(scoring.get_lj_energy, 'charmm', forgive=0.95)
,
"_minimize": kwargs.pop('minimize', True),
"min_method": 'L-BFGS-B',
"_do_trim": kwargs.pop('trim', True),
Expand Down
2 changes: 1 addition & 1 deletion src/chilife/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
rotlib_indexes = pickle.load(f)

with open(DATA_DIR / 'BondDefs.pkl', 'rb') as f:
bond_hmax_dict = {key: (val + 0.4 if 'H' in key else val + 0.35) for key, val in pickle.load(f).items()}
bond_hmax_dict = {key: (val + 0.4 if 'H' in key else val + 0.37) for key, val in pickle.load(f).items()}
bond_hmax_dict = defaultdict(lambda: 0, bond_hmax_dict)


Expand Down
4 changes: 2 additions & 2 deletions src/chilife/protein_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def get_sas_res(
MolSys object to measure Solvent Accessible Surfaces (SAS) area of and report the SAS residues.
cutoff : float
Exclude residues from list with SASA below cutoff in angstroms squared.
forcefield : Union[str, chilife.ForceField]
forcefield : Union[str, chilife.ljParams]
Forcefield to use for defining atom radii when calculating solvent accessibility.
Returns
Expand All @@ -672,7 +672,7 @@ def get_sas_res(
"""
if isinstance(forcefield, str):
forcefield = scoring.ForceField(forcefield)
forcefield = scoring.ljEnergyFunc(forcefield)

environment_coords = protein.atoms.positions
environment_radii = forcefield.get_lj_rmin(protein.atoms.types)
Expand Down
Loading

0 comments on commit bb1a859

Please sign in to comment.