Skip to content

Commit

Permalink
Merge pull request #90 from OpenFreeEnergy/multi_map
Browse files Browse the repository at this point in the history
Allow multiple mappings
  • Loading branch information
jthorton authored Sep 25, 2024
2 parents 75f966e + cc8735c commit 4b9f7bf
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
19 changes: 14 additions & 5 deletions feflow/protocols/nonequilibrium_cycling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Adapted from perses: https://github.com/choderalab/perses/blob/protocol-neqcyc/perses/protocols/nonequilibrium_cycling.py

from typing import Optional, List, Dict, Any
from typing import Optional, Any, Union
from collections.abc import Iterable
from itertools import chain

Expand Down Expand Up @@ -988,15 +988,24 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[dict[str, ComponentMapping]] = None,
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None,
extends: Optional[ProtocolDAGResult] = None,
) -> list[ProtocolUnit]:
# Handle parameters
if mapping is None:
raise ValueError("`mapping` is required for this Protocol")
from openfe.protocols.openmm_rfe.equil_rfe_methods import (
_validate_alchemical_components,
)
from openfe.protocols.openmm_utils import system_validation

if extends:
raise NotImplementedError("Can't extend simulations yet")

# Do manual validation until it is part of the protocol
# Get alchemical components & validate them + mapping
alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
# raise an error if we have more than one mapping
_validate_alchemical_components(alchem_comps, mapping)
mapping = mapping[0] if isinstance(mapping, list) else mapping # type: ignore

# inputs to `ProtocolUnit.__init__` should either be `Gufe` objects
# or JSON-serializable objects
num_cycles = self.settings.num_cycles
Expand Down
4 changes: 2 additions & 2 deletions feflow/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def mapping_toluene_toluene(toluene):
i: i for i in range(len(toluene.to_rdkit().GetAtoms()))
}
mapping_obj = LigandAtomMapping(
componentA=toluene,
componentB=toluene,
componentA=toluene.copy_with_replacements(name="toluene_a"),
componentB=toluene.copy_with_replacements(name="toluene_b"),
componentA_to_componentB=mapping_toluene_to_toluene,
)
return mapping_obj
Expand Down
41 changes: 37 additions & 4 deletions feflow/tests/test_nonequilibrium_cycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,13 @@ def test_create_execute_gather(
],
)
def test_create_execute_gather_toluene_to_toluene(
self, protocol, toluene_vacuum_system, mapping_toluene_toluene, tmpdir, request
self,
protocol,
toluene_vacuum_system,
mapping_toluene_toluene,
tmpdir,
request,
toluene,
):
"""
Perform 20 independent simulations of the NEQ cycling protocol for the toluene to toluene
Expand All @@ -328,10 +334,16 @@ def test_create_execute_gather_toluene_to_toluene(
import numpy as np

protocol = request.getfixturevalue(protocol)

# rename the components
toluene_state_a = toluene_vacuum_system.copy_with_replacements(
components={"ligand": toluene.copy_with_replacements(name="toluene_a")}
)
toluene_state_b = toluene_vacuum_system.copy_with_replacements(
components={"ligand": toluene.copy_with_replacements(name="toluene_b")}
)
dag = protocol.create(
stateA=toluene_vacuum_system,
stateB=toluene_vacuum_system,
stateA=toluene_state_a,
stateB=toluene_state_b,
name="Toluene vacuum transformation",
mapping=mapping_toluene_toluene,
)
Expand Down Expand Up @@ -495,6 +507,27 @@ def test_failing_partial_charge_assign(

execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch)

def test_error_with_multiple_mappings(
self,
protocol_short,
benzene_vacuum_system,
toluene_vacuum_system,
mapping_benzene_toluene,
):
"""
Make sure that when a list of mappings is passed that an error is raised.
"""

with pytest.raises(
ValueError, match="A single LigandAtomMapping is expected for this Protocol"
):
_ = protocol_short.create(
stateA=benzene_vacuum_system,
stateB=toluene_vacuum_system,
name="Test protocol",
mapping=[mapping_benzene_toluene, mapping_benzene_toluene],
)


class TestSetupUnit:
def test_setup_user_charges(
Expand Down

0 comments on commit 4b9f7bf

Please sign in to comment.