Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Transformation and NonTransformation subclass TransformationBase #311

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 109 additions & 83 deletions gufe/transformations/transformation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe

import abc
from typing import Optional, Iterable, Union
import json
import warnings
Expand All @@ -13,7 +14,102 @@
from ..mapping import ComponentMapping


class Transformation(GufeTokenizable):
class TransformationBase(GufeTokenizable):
"""Transformation base class.

"""
def __init__(
self,
protocol: Protocol,
name: Optional[str] = None,
):
self._protocol = protocol
self._name = name

@classmethod
def _defaults(cls):
return super()._defaults()

@property
def name(self) -> Optional[str]:
"""
Optional identifier for the transformation; used as part of its hash.

Set this to a unique value if adding multiple, otherwise identical
transformations to the same :class:`AlchemicalNetwork` to avoid
deduplication.
"""
return self._name

@classmethod
def _from_dict(cls, d: dict):
return cls(**d)

@abc.abstractmethod
def create(
self,
*,
extends: Optional[ProtocolDAGResult] = None,
name: Optional[str] = None,
) -> ProtocolDAG:
"""
Returns a ``ProtocolDAG`` executing this ``Transformation.protocol``.
"""
raise NotImplementedError

def gather(
self, protocol_dag_results: Iterable[ProtocolDAGResult]
) -> ProtocolResult:
"""
Gather multiple ``ProtocolDAGResult`` into a single ``ProtocolResult``.

Parameters
----------
protocol_dag_results : Iterable[ProtocolDAGResult]
The ``ProtocolDAGResult`` objects to assemble aggregate quantities
from.

Returns
-------
ProtocolResult
Aggregated results from many ``ProtocolDAGResult`` objects, all from
a given ``Protocol``.

"""
return self.protocol.gather(protocol_dag_results=protocol_dag_results)

def dump(self, file):
"""Dump this Transformation to a JSON file.

Note that this is not space-efficient: for example, any
``Component`` which is used in both ``ChemicalSystem`` objects will be
represented twice in the JSON output.

Parameters
----------
file : Union[PathLike, FileLike]
a pathlike of filelike to save this transformation to.
"""
with ensure_filelike(file, mode='w') as f:
json.dump(self.to_dict(), f, cls=JSON_HANDLER.encoder,
sort_keys=True)

@classmethod
def load(cls, file):
"""Create a Transformation from a JSON file.

Parameters
----------
file : Union[PathLike, FileLike]
a pathlike or filelike to read this transformation from
"""
with ensure_filelike(file, mode='r') as f:
dct = json.load(f, cls=JSON_HANDLER.decoder)

return cls.from_dict(dct)


class Transformation(TransformationBase):
_stateA: ChemicalSystem
_stateB: ChemicalSystem
_name: Optional[str]
Expand Down Expand Up @@ -56,18 +152,14 @@ def __init__(

self._stateA = stateA
self._stateB = stateB
self._protocol = protocol
self._mapping = mapping
self._name = name

self._protocol = protocol

@classmethod
def _defaults(cls):
return super()._defaults()

def __repr__(self):
return f"{self.__class__.__name__}(stateA={self.stateA}, "\
f"stateB={self.stateB}, protocol={self.protocol})"
attrs = ['name', 'stateA', 'stateB', 'protocol', 'mapping']
content = ", ".join([f"{i}={getattr(self, i)}" for i in attrs])
return f"{self.__class__.__name__}({content})"

@property
def stateA(self) -> ChemicalSystem:
Expand Down Expand Up @@ -95,17 +187,6 @@ def mapping(self) -> Optional[Union[ComponentMapping, list[ComponentMapping]]]:
"""The mappings relevant for this Transformation"""
return self._mapping

@property
def name(self) -> Optional[str]:
"""
Optional identifier for the transformation; used as part of its hash.

Set this to a unique value if adding multiple, otherwise identical
transformations to the same :class:`AlchemicalNetwork` to avoid
deduplication.
"""
return self._name

def _to_dict(self) -> dict:
return {
"stateA": self.stateA,
Expand All @@ -115,10 +196,6 @@ def _to_dict(self) -> dict:
"name": self.name,
}

@classmethod
def _from_dict(cls, d: dict):
return cls(**d)

def create(
self,
*,
Expand All @@ -137,60 +214,8 @@ def create(
transformation_key=self.key,
)

def gather(
self, protocol_dag_results: Iterable[ProtocolDAGResult]
) -> ProtocolResult:
"""
Gather multiple ``ProtocolDAGResult`` into a single ``ProtocolResult``.

Parameters
----------
protocol_dag_results : Iterable[ProtocolDAGResult]
The ``ProtocolDAGResult`` objects to assemble aggregate quantities
from.

Returns
-------
ProtocolResult
Aggregated results from many ``ProtocolDAGResult`` objects, all from
a given ``Protocol``.

"""
return self.protocol.gather(protocol_dag_results=protocol_dag_results)

def dump(self, file):
"""Dump this Transformation to a JSON file.

Note that this is not space-efficient: for example, any
``Component`` which is used in both ``ChemicalSystem`` objects will be
represented twice in the JSON output.

Parameters
----------
file : Union[PathLike, FileLike]
a pathlike of filelike to save this transformation to.
"""
with ensure_filelike(file, mode='w') as f:
json.dump(self.to_dict(), f, cls=JSON_HANDLER.encoder,
sort_keys=True)

@classmethod
def load(cls, file):
"""Create a Transformation from a JSON file.

Parameters
----------
file : Union[PathLike, FileLike]
a pathlike or filelike to read this transformation from
"""
with ensure_filelike(file, mode='r') as f:
dct = json.load(f, cls=JSON_HANDLER.decoder)

return cls.from_dict(dct)


# we subclass `Transformation` here for typing simplicity
class NonTransformation(Transformation):
class NonTransformation(TransformationBase):
"""A non-alchemical edge of an alchemical network.

A "transformation" that performs no transformation at all.
Expand All @@ -211,8 +236,13 @@ def __init__(
):

self._system = system
self._name = name
self._protocol = protocol
self._name = name

def __repr__(self):
attrs = ['name', 'system', 'protocol']
content = ", ".join([f"{i}={getattr(self, i)}" for i in attrs])
return f"{self.__class__.__name__}({content})"

@property
def stateA(self):
Expand Down Expand Up @@ -243,18 +273,14 @@ def _to_dict(self) -> dict:
"name": self.name,
}

@classmethod
def _from_dict(cls, d: dict):
return cls(**d)

def create(
self,
*,
extends: Optional[ProtocolDAGResult] = None,
name: Optional[str] = None,
) -> ProtocolDAG:
"""
Returns a ``ProtocolDAG`` executing this ``Transformation.protocol``.
Returns a ``ProtocolDAG`` executing this ``NonTransformation.protocol``.
"""
return self.protocol.create(
stateA=self.system,
Expand Down
Loading