Skip to content

Commit

Permalink
Merge pull request #235 from choderalab/dev-electrostatics
Browse files Browse the repository at this point in the history
Long range electrostatics
  • Loading branch information
wiederm authored Sep 17, 2024
2 parents 18c41d6 + 5e4c0df commit 56672f3
Show file tree
Hide file tree
Showing 12 changed files with 457 additions and 136 deletions.
44 changes: 40 additions & 4 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ class PostProcessing(torch.nn.Module):
_SUPPORTED_OPERATIONS = [
"normalize",
"from_atom_to_molecule_reduction",
"conserve_integer_charge",
"long_range_electrostatics" "conserve_integer_charge",
]

def __init__(
Expand Down Expand Up @@ -882,6 +882,7 @@ def _initialize_postprocessing(
ScaleValues,
CalculateAtomicSelfEnergy,
ChargeConservation,
LongRangeElectrostaticEnergy,
)

for property, operations in postprocessing_parameter.items():
Expand All @@ -901,9 +902,32 @@ def _initialize_postprocessing(
if property == "per_atom_charge":
if operations.get("conserve", False):
postprocessing_sequence.append(
ChargeConservation(operations["strategy"])
ChargeConservation(operations["conserve_strategy"])
)
prostprocessing_sequence_names.append("conserve_charge")

if operations.get("coulomb_potential", False):
coulomb_potential = operations["coulomb_potential"]
postprocessing_sequence.append(
LongRangeElectrostaticEnergy(
coulomb_potential["electrostatic_strategy"],
coulomb_potential["maximum_interaction_radius"],
)
)
prostprocessing_sequence_names.append("coulomb_potential")

if coulomb_potential.get("from_atom_to_molecule_reduction, False"):
postprocessing_sequence.append(
FromAtomToMoleculeReduction(
per_atom_property_name="per_atom_electrostatic_energy",
index_name="atomic_subsystem_indices",
output_name="per_molecule_electrostatic_energy",
keep_per_atom_property=operations.get(
"keep_per_atom_property", False
),
)
)

elif property == "per_atom_energy":
if operations.get("normalize", False):
(
Expand Down Expand Up @@ -1004,6 +1028,10 @@ def forward(self, data: Dict[str, torch.Tensor]):
if property in self._registered_properties:
self.registered_chained_operations[property](data)

# delte pairwise property object before returning
if "pairwise_properties" in data:
del data["pairwise_properties"]

return data


Expand Down Expand Up @@ -1163,7 +1191,10 @@ def prepare_pairwise_properties(self, data):
return self.compute_interacting_pairs.prepare_inputs(data)

def _add_addiontal_properties(
self, data, output: Dict[str, torch.Tensor]
self,
data,
output: Dict[str, torch.Tensor],
pairwise_properties: PairListOutputs,
) -> Dict[str, torch.Tensor]:
"""
Add additional properties to the output dictionary.
Expand All @@ -1181,6 +1212,7 @@ def _add_addiontal_properties(
"""

output["per_molecule_charge"] = data.total_charge
output["pairwise_properties"] = pairwise_properties
return output

def compute(self, data, core_input):
Expand Down Expand Up @@ -1227,7 +1259,11 @@ def forward(self, input_data: NNPInput) -> Dict[str, torch.Tensor]:
input_data, pairwise_properties
) # FIXME: putput and processed_output are currently a dictionary, we really want to change this to a dataclass
# perform postprocessing operations
output = self._add_addiontal_properties(input_data, output)
output = self._add_addiontal_properties(
input_data,
output,
pairwise_properties,
)
processed_output = self.postprocessing(output)
return processed_output

Expand Down
36 changes: 21 additions & 15 deletions modelforge/potential/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,20 @@

from __future__ import annotations

from enum import Enum
from typing import List, Optional, Type, Union

import torch
from openff.units import unit
from pydantic import (
BaseModel,
field_validator,
ConfigDict,
model_validator,
computed_field,
field_validator,
model_validator,
)
from openff.units import unit
from typing import Union, List, Optional, Type
from modelforge.utils.units import _convert_str_to_unit
from enum import Enum

import torch


# needed to typecast to torch.nn.Module

"""
This module contains pydantic models for storing the parameters of
"""
from modelforge.utils.units import _convert_str_to_unit


class CaseInsensitiveEnum(str, Enum):
Expand Down Expand Up @@ -136,10 +130,22 @@ class PerAtomEnergy(ParametersBase):
keep_per_atom_property: bool = False


class CoulomPotential(ParametersBase):
electrostatic_strategy: str = "coulomb"
maximum_interaction_radius: Union[str, unit.Quantity]
from_atom_to_molecule_reduction: bool = False
keep_per_atom_property: bool = False

converted_units = field_validator(
"maximum_interaction_radius",
)(_convert_str_to_unit)


class PerAtomCharge(ParametersBase):
conserve: bool = False
strategy: str = "default"
conserve_strategy: str = "default"
keep_per_atom_property: bool = False
coulomb_potential: Optional[CoulomPotential] = None


class ANI2xParameters(ParametersBase):
Expand Down
Loading

0 comments on commit 56672f3

Please sign in to comment.