From fb4a68823ec32baf625a31335ab889be891a836e Mon Sep 17 00:00:00 2001 From: David Dotson Date: Wed, 24 Apr 2024 22:55:42 -0700 Subject: [PATCH 1/2] Make `Transformation` and `NonTransformation` subclass `TransformationBase` In some cases it can be awkward for `NonTransformation` to be a subclass of `Transformation`, such as in `alchemiscale`, for cases where `NonTransformation` should be handled very differently. Switching to a shared, abstract base class for `Transformation` and `NonTransformation` simplifies this. --- gufe/transformations/transformation.py | 189 ++++++++++++++----------- 1 file changed, 106 insertions(+), 83 deletions(-) diff --git a/gufe/transformations/transformation.py b/gufe/transformations/transformation.py index 58a898be..25562cc8 100644 --- a/gufe/transformations/transformation.py +++ b/gufe/transformations/transformation.py @@ -1,6 +1,7 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe +import abc from typing import Optional, Iterable, Union import json import warnings @@ -13,7 +14,95 @@ from ..mapping import ComponentMapping -class Transformation(GufeTokenizable): +class TransformationBase(GufeTokenizable): + """Transformation base class. + + """ + + @classmethod + def _defaults(cls): + return super()._defaults() + + @property + def name(self) -> Optional[str]: + """ + Optional identifier for the transformation; used as part of its hash. + + Set this to a unique value if adding multiple, otherwise identical + transformations to the same :class:`AlchemicalNetwork` to avoid + deduplication. + """ + return self._name + + @classmethod + def _from_dict(cls, d: dict): + return cls(**d) + + @abc.abstractmethod + def create( + self, + *, + extends: Optional[ProtocolDAGResult] = None, + name: Optional[str] = None, + ) -> ProtocolDAG: + """ + Returns a ``ProtocolDAG`` executing this ``Transformation.protocol``. + """ + raise NotImplementedError + + def gather( + self, protocol_dag_results: Iterable[ProtocolDAGResult] + ) -> ProtocolResult: + """ + Gather multiple ``ProtocolDAGResult`` into a single ``ProtocolResult``. + + Parameters + ---------- + protocol_dag_results : Iterable[ProtocolDAGResult] + The ``ProtocolDAGResult`` objects to assemble aggregate quantities + from. + + Returns + ------- + ProtocolResult + Aggregated results from many ``ProtocolDAGResult`` objects, all from + a given ``Protocol``. + + """ + return self.protocol.gather(protocol_dag_results=protocol_dag_results) + + def dump(self, file): + """Dump this Transformation to a JSON file. + + Note that this is not space-efficient: for example, any + ``Component`` which is used in both ``ChemicalSystem`` objects will be + represented twice in the JSON output. + + Parameters + ---------- + file : Union[PathLike, FileLike] + a pathlike of filelike to save this transformation to. + """ + with ensure_filelike(file, mode='w') as f: + json.dump(self.to_dict(), f, cls=JSON_HANDLER.encoder, + sort_keys=True) + + @classmethod + def load(cls, file): + """Create a Transformation from a JSON file. + + Parameters + ---------- + file : Union[PathLike, FileLike] + a pathlike or filelike to read this transformation from + """ + with ensure_filelike(file, mode='r') as f: + dct = json.load(f, cls=JSON_HANDLER.decoder) + + return cls.from_dict(dct) + + +class Transformation(TransformationBase): _stateA: ChemicalSystem _stateB: ChemicalSystem _name: Optional[str] @@ -56,18 +145,14 @@ def __init__( self._stateA = stateA self._stateB = stateB + self._protocol = protocol self._mapping = mapping self._name = name - self._protocol = protocol - - @classmethod - def _defaults(cls): - return super()._defaults() - def __repr__(self): - return f"{self.__class__.__name__}(stateA={self.stateA}, "\ - f"stateB={self.stateB}, protocol={self.protocol})" + attrs = ['name', 'stateA', 'stateB', 'protocol', 'mapping'] + content = ", ".join([f"{i}={getattr(self, i)}" for i in attrs]) + return f"{self.__class__.__name__}({content})" @property def stateA(self) -> ChemicalSystem: @@ -95,17 +180,6 @@ def mapping(self) -> Optional[Union[ComponentMapping, list[ComponentMapping]]]: """The mappings relevant for this Transformation""" return self._mapping - @property - def name(self) -> Optional[str]: - """ - Optional identifier for the transformation; used as part of its hash. - - Set this to a unique value if adding multiple, otherwise identical - transformations to the same :class:`AlchemicalNetwork` to avoid - deduplication. - """ - return self._name - def _to_dict(self) -> dict: return { "stateA": self.stateA, @@ -115,10 +189,6 @@ def _to_dict(self) -> dict: "name": self.name, } - @classmethod - def _from_dict(cls, d: dict): - return cls(**d) - def create( self, *, @@ -137,60 +207,8 @@ def create( transformation_key=self.key, ) - def gather( - self, protocol_dag_results: Iterable[ProtocolDAGResult] - ) -> ProtocolResult: - """ - Gather multiple ``ProtocolDAGResult`` into a single ``ProtocolResult``. - - Parameters - ---------- - protocol_dag_results : Iterable[ProtocolDAGResult] - The ``ProtocolDAGResult`` objects to assemble aggregate quantities - from. - - Returns - ------- - ProtocolResult - Aggregated results from many ``ProtocolDAGResult`` objects, all from - a given ``Protocol``. - - """ - return self.protocol.gather(protocol_dag_results=protocol_dag_results) - - def dump(self, file): - """Dump this Transformation to a JSON file. - - Note that this is not space-efficient: for example, any - ``Component`` which is used in both ``ChemicalSystem`` objects will be - represented twice in the JSON output. - - Parameters - ---------- - file : Union[PathLike, FileLike] - a pathlike of filelike to save this transformation to. - """ - with ensure_filelike(file, mode='w') as f: - json.dump(self.to_dict(), f, cls=JSON_HANDLER.encoder, - sort_keys=True) - - @classmethod - def load(cls, file): - """Create a Transformation from a JSON file. - - Parameters - ---------- - file : Union[PathLike, FileLike] - a pathlike or filelike to read this transformation from - """ - with ensure_filelike(file, mode='r') as f: - dct = json.load(f, cls=JSON_HANDLER.decoder) - - return cls.from_dict(dct) - -# we subclass `Transformation` here for typing simplicity -class NonTransformation(Transformation): +class NonTransformation(TransformationBase): """A non-alchemical edge of an alchemical network. A "transformation" that performs no transformation at all. @@ -211,8 +229,17 @@ def __init__( ): self._system = system - self._name = name self._protocol = protocol + self._name = name + + def __repr__(self): + return f"{self.__class__.__name__}(system={self.system}, "\ + f"protocol={self.protocol})" + + def __repr__(self): + attrs = ['name', 'system', 'protocol'] + content = ", ".join([f"{i}={getattr(self, i)}" for i in attrs]) + return f"{self.__class__.__name__}({content})" @property def stateA(self): @@ -243,10 +270,6 @@ def _to_dict(self) -> dict: "name": self.name, } - @classmethod - def _from_dict(cls, d: dict): - return cls(**d) - def create( self, *, @@ -254,7 +277,7 @@ def create( name: Optional[str] = None, ) -> ProtocolDAG: """ - Returns a ``ProtocolDAG`` executing this ``Transformation.protocol``. + Returns a ``ProtocolDAG`` executing this ``NonTransformation.protocol``. """ return self.protocol.create( stateA=self.system, From 0366d0e6f7b2bda1e95c966b036547d607847dbc Mon Sep 17 00:00:00 2001 From: David Dotson Date: Wed, 24 Apr 2024 23:02:54 -0700 Subject: [PATCH 2/2] Address mypy complaints --- gufe/transformations/transformation.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/gufe/transformations/transformation.py b/gufe/transformations/transformation.py index 25562cc8..e018fdc5 100644 --- a/gufe/transformations/transformation.py +++ b/gufe/transformations/transformation.py @@ -18,6 +18,13 @@ class TransformationBase(GufeTokenizable): """Transformation base class. """ + def __init__( + self, + protocol: Protocol, + name: Optional[str] = None, + ): + self._protocol = protocol + self._name = name @classmethod def _defaults(cls): @@ -232,10 +239,6 @@ def __init__( self._protocol = protocol self._name = name - def __repr__(self): - return f"{self.__class__.__name__}(system={self.system}, "\ - f"protocol={self.protocol})" - def __repr__(self): attrs = ['name', 'system', 'protocol'] content = ", ".join([f"{i}={getattr(self, i)}" for i in attrs])