Skip to content

Commit

Permalink
Add gather --allow-partial, improve related error message (#588)
Browse files Browse the repository at this point in the history
* improve error message

* Add --allow-partial to gather

* use it for the other call as well

* ignore coverage on line that should never happen

* Make it possible to do both RHFE and RBFE
  • Loading branch information
dwhswenson authored Oct 30, 2023
1 parent 844d5de commit 2b762c6
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 16 deletions.
86 changes: 72 additions & 14 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from openfecli import OFECommandPlugin
from openfecli.clicktypes import HyphenAwareChoice
import pathlib
import warnings


def _get_column(val):
Expand Down Expand Up @@ -85,40 +86,90 @@ def legacy_get_type(res_fn):
return 'complex'


def _get_ddgs(legs):
def _generate_bad_legs_error_message(set_vals, ligpair):
expected_rbfe = {'complex', 'solvent'}
expected_rhfe = {'solvent', 'vacuum'}
maybe_rhfe = bool(set_vals & expected_rhfe)
maybe_rbfe = bool(set_vals & expected_rbfe)
if maybe_rhfe and not maybe_rbfe:
msg = (
"This appears to be an RHFE calculation, but we're "
f"missing {expected_rhfe - set_vals} runs for the "
f"edge with ligands {ligpair}."
)
elif maybe_rbfe and not maybe_rhfe:
msg = (
"This appears to be an RBFE calculation, but we're "
f"missing {expected_rbfe - set_vals} runs for the "
f"edge with ligands {ligpair}."
)
elif maybe_rbfe and maybe_rhfe:
msg = (
"Unable to determine whether this is an RBFE "
f"or an RHFE calculation. Found legs {set_vals} "
f"for ligands {ligpair}. Those ligands are missing one "
f"of: {(expected_rhfe | expected_rbfe) - set_vals}."
)
else: # -no-cov-
# this should never happen
msg = (
"Something went very wrong while determining the type "
f"of RFE calculation. For the ligand pair {ligpair}, "
f"we found legs labelled {set_vals}. We expected either "
f"{expected_rhfe} or {expected_rbfe}."
)

msg += (
"\n\nYou can force partial gathering of results, without "
"problematic edges, by using the --allow-partial flag of the gather "
"command. Note that this may cause problems with predicting "
"absolute free energies from the relative free energies."
)
return msg


def _get_ddgs(legs, error_on_missing=True):
import numpy as np
DDGs = []
for ligpair, vals in sorted(legs.items()):
set_vals = set(vals)
DDGbind = None
DDGhyd = None
bind_unc = None
hyd_unc = None

if 'complex' in vals and 'solvent' in vals:
do_rbfe = (len(set_vals & {'complex', 'solvent'}) == 2)
do_rhfe = (len(set_vals & {'vacuum', 'solvent'}) == 2)

if do_rbfe:
DG1_mag, DG1_unc = vals['complex']
DG2_mag, DG2_unc = vals['solvent']
if not ((DG1_mag is None) or (DG2_mag is None)):
# DDG(2,1)bind = DG(1->2)complex - DG(1->2)solvent
DDGbind = (DG1_mag - DG2_mag).m
bind_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))
elif 'solvent' in vals and 'vacuum' in vals:

if do_rhfe:
DG1_mag, DG1_unc = vals['solvent']
DG2_mag, DG2_unc = vals['vacuum']
if not ((DG1_mag is None) or (DG2_mag is None)):
DDGhyd = (DG1_mag - DG2_mag).m
hyd_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))
else:
raise RuntimeError("Unable to determine type of RFE calculation "
f"for edges with labels {list(vals)} for "
f"ligands {ligpair}")

if not do_rbfe and not do_rhfe:
msg = _generate_bad_legs_error_message(set_vals, ligpair)
if error_on_missing:
raise RuntimeError(msg)
else:
warnings.warn(msg)

DDGs.append((*ligpair, DDGbind, bind_unc, DDGhyd, hyd_unc))

return DDGs


def _write_ddg(legs, writer):
DDGs = _get_ddgs(legs)
def _write_ddg(legs, writer, allow_partial):
DDGs = _get_ddgs(legs, error_on_missing=not allow_partial)
writer.writerow(["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)",
"uncertainty (kcal/mol)"])
for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs:
Expand All @@ -133,7 +184,7 @@ def _write_ddg(legs, writer):
writer.writerow([ligA, ligB, DDGhyd, hyd_unc])


def _write_dg_raw(legs, writer):
def _write_dg_raw(legs, writer, allow_partial):
writer.writerow(["leg", "ligand_i", "ligand_j", "DG(i->j) (kcal/mol)",
"uncertainty (kcal/mol)"])
for ligpair, vals in sorted(legs.items()):
Expand All @@ -146,11 +197,11 @@ def _write_dg_raw(legs, writer):
writer.writerow([simtype, *ligpair, m, u])


def _write_dg_mle(legs, writer):
def _write_dg_mle(legs, writer, allow_partial):
import networkx as nx
import numpy as np
from cinnabar.stats import mle
DDGs = _get_ddgs(legs)
DDGs = _get_ddgs(legs, error_on_missing=not allow_partial)
MLEs = []
# 4b) perform MLE
g = nx.DiGraph()
Expand Down Expand Up @@ -219,7 +270,14 @@ def _write_dg_mle(legs, writer):
@click.option('output', '-o',
type=click.File(mode='w'),
default='-')
def gather(rootdir, output, report):
@click.option(
'--allow-partial', is_flag=True, default=False,
help=(
"Do not raise errors is results are missing parts for some edges. "
"(Skip those edges and issue warning instead.)"
)
)
def gather(rootdir, output, report, allow_partial):
"""Gather simulation result jsons of relative calculations to a tsv file
This walks ROOTDIR recursively and finds all result JSON files from the
Expand Down Expand Up @@ -287,7 +345,7 @@ def gather(rootdir, output, report):
'ddg': _write_ddg,
'dg-raw': _write_dg_raw,
}[report.lower()]
writing_func(legs, writer)
writing_func(legs, writer, allow_partial)


PLUGIN = OFECommandPlugin(
Expand Down
30 changes: 28 additions & 2 deletions openfecli/tests/commands/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import pytest

from openfecli.commands.gather import (
gather, format_estimate_uncertainty, _get_column
gather, format_estimate_uncertainty, _get_column,
_generate_bad_legs_error_message,
)

@pytest.mark.parametrize('est,unc,unc_prec,est_str,unc_str', [
Expand Down Expand Up @@ -108,6 +109,21 @@ def test_gather(results_dir, report):
assert set(expected.split(b'\n')) == actual_lines


@pytest.mark.parametrize('include', ['complex', 'solvent', 'vacuum'])
def test_generate_bad_legs_error_message(include):
expected = {
'complex': ("appears to be an RBFE", "missing {'solvent'}"),
'vacuum': ("appears to be an RHFE", "missing {'solvent'}"),
'solvent': ("whether this is an RBFE or an RHFE",
"'complex'", "'solvent'"),
}[include]
set_vals = {include}
ligpair = {'lig1', 'lig2'}
msg = _generate_bad_legs_error_message(set_vals, ligpair)
for string in expected:
assert string in msg


def test_missing_leg_error(results_dir):
file_to_remove = "easy_rbfe_lig_ejm_31_complex_lig_ejm_42_complex.json"
(pathlib.Path("results") / file_to_remove).unlink()
Expand All @@ -116,6 +132,16 @@ def test_missing_leg_error(results_dir):
result = runner.invoke(gather, ['results'] + ['-o', '-'])
assert result.exit_code == 1
assert isinstance(result.exception, RuntimeError)
assert "labels ['solvent']" in str(result.exception)
assert "Unable to determine" in str(result.exception)
assert "'lig_ejm_31'" in str(result.exception)
assert "'lig_ejm_42'" in str(result.exception)


def test_missing_leg_allow_partial(results_dir):
file_to_remove = "easy_rbfe_lig_ejm_31_complex_lig_ejm_42_complex.json"
(pathlib.Path("results") / file_to_remove).unlink()

runner = CliRunner()
result = runner.invoke(gather,
['results'] + ['--allow-partial', '-o', '-'])
assert result.exit_code == 0

0 comments on commit 2b762c6

Please sign in to comment.