Skip to content

Commit

Permalink
implementation dampled coulomb potential
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Aug 18, 2024
1 parent f743d74 commit a454347
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 26 deletions.
44 changes: 40 additions & 4 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ class PostProcessing(torch.nn.Module):
_SUPPORTED_OPERATIONS = [
"normalize",
"from_atom_to_molecule_reduction",
"conserve_integer_charge",
"long_range_electrostatics" "conserve_integer_charge",
]

def __init__(
Expand Down Expand Up @@ -868,6 +868,7 @@ def _initialize_postprocessing(
ScaleValues,
CalculateAtomicSelfEnergy,
ChargeConservation,
LongRangeElectrostaticEnergy,
)

for property, operations in postprocessing_parameter.items():
Expand All @@ -887,9 +888,32 @@ def _initialize_postprocessing(
if property == "per_atom_charge":
if operations.get("conserve", False):
postprocessing_sequence.append(
ChargeConservation(operations["strategy"])
ChargeConservation(operations["conserve_strategy"])
)
prostprocessing_sequence_names.append("conserve_charge")

if operations.get("coulomb_potential", False):
coulomb_potential = operations["coulomb_potential"]
postprocessing_sequence.append(
LongRangeElectrostaticEnergy(
coulomb_potential["electrostatic_strategy"],
coulomb_potential["maximum_interaction_radius"],
)
)
prostprocessing_sequence_names.append("coulomb_potential")

if coulomb_potential.get("from_atom_to_molecule_reduction, False"):
postprocessing_sequence.append(
FromAtomToMoleculeReduction(
per_atom_property_name="per_atom_energy",
index_name="atomic_subsystem_indices",
output_name="per_molecule_energy",
keep_per_atom_property=operations.get(
"keep_per_atom_property", False
),
)
)

elif property == "per_atom_energy":
if operations.get("normalize", False):
(
Expand Down Expand Up @@ -990,6 +1014,10 @@ def forward(self, data: Dict[str, torch.Tensor]):
if property in self._registered_properties:
self.registered_chained_operations[property](data)

# delte pairwise property object before returning
if 'pairwise_properties' in data:
del data['pairwise_properties']

return data


Expand Down Expand Up @@ -1126,7 +1154,10 @@ def prepare_pairwise_properties(self, data):
return self.compute_interacting_pairs.prepare_inputs(data)

def _add_addiontal_properties(
self, data, output: Dict[str, torch.Tensor]
self,
data,
output: Dict[str, torch.Tensor],
pairwise_properties: PairListOutputs,
) -> Dict[str, torch.Tensor]:
"""
Add additional properties to the output dictionary.
Expand All @@ -1144,6 +1175,7 @@ def _add_addiontal_properties(
"""

output["per_molecule_charge"] = data.total_charge
output["pairwise_properties"] = pairwise_properties
return output

def compute(self, data, core_input):
Expand Down Expand Up @@ -1190,7 +1222,11 @@ def forward(self, input_data: NNPInput) -> Dict[str, torch.Tensor]:
input_data, pairwise_properties
) # FIXME: putput and processed_output are currently a dictionary, we really want to change this to a dataclass
# perform postprocessing operations
output = self._add_addiontal_properties(input_data, output)
output = self._add_addiontal_properties(
input_data,
output,
pairwise_properties,
)
processed_output = self.postprocessing(output)
return processed_output

Expand Down
12 changes: 11 additions & 1 deletion modelforge/potential/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,20 @@ class PerAtomEnergy(ParametersBase):
keep_per_atom_property: bool = False


class CoulomPotential(ParametersBase):
electrostatic_strategy: str = "coulomb"
maximum_interaction_radius: Union[str, unit.Quantity]

converted_units = field_validator(
"maximum_interaction_radius",
)(_convert_str_to_unit)


class PerAtomCharge(ParametersBase):
conserve: bool = False
strategy: str = "default"
conserve_strategy: str = "default"
keep_per_atom_property: bool = False
coulomb_potential: Optional[CoulomPotential] = None


class ANI2xParameters(ParametersBase):
Expand Down
121 changes: 102 additions & 19 deletions modelforge/potential/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
This module contains utility functions and classes for processing the output of the potential model.
"""

from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Type, Union

import torch
from typing import Dict
from openff.units import unit

from modelforge.dataset.utils import _ATOMIC_NUMBER_TO_ELEMENT

from .models import PairListOutputs


def load_atomic_self_energies(path: str) -> Dict[str, unit.Quantity]:
"""
Expand Down Expand Up @@ -131,13 +137,6 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return data


from dataclasses import dataclass, field
from typing import Dict, Iterator

from openff.units import unit
from modelforge.dataset.utils import _ATOMIC_NUMBER_TO_ELEMENT


@dataclass
class AtomicSelfEnergies:
"""
Expand Down Expand Up @@ -272,16 +271,13 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return data


from typing import Union


class ChargeConservation(torch.nn.Module):
def __init__(self, method="physnet"):
def __init__(self, method="default"):

super().__init__()
self.method = method
if self.method == "physnet":
self.correct_partial_charges = self.physnet_charge_conservation
if self.method == "default":
self.correct_partial_charges = self.default_charge_conservation
else:
raise ValueError(f"Unknown charge conservation method: {self.method}")

Expand Down Expand Up @@ -313,7 +309,7 @@ def forward(
)
return data

def physnet_charge_conservation(
def default_charge_conservation(
self,
per_atom_charge: torch.Tensor,
mol_indices: torch.Tensor,
Expand Down Expand Up @@ -362,10 +358,6 @@ def physnet_charge_conservation(
# Apply the correction to each atom's charge
per_atom_charge_corrected = per_atom_charge + correction_factors[mol_indices]

per_molecule_corrected_charge = torch.zeros_like(per_atom_charge).scatter_add_(
0, mol_indices.long(), per_atom_charge_corrected
)

return per_atom_charge_corrected


Expand Down Expand Up @@ -429,3 +421,94 @@ def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:

data["ase_tensor"] = ase_tensor
return data


class LongRangeElectrostaticEnergy(torch.nn.Module):
def __init__(self, strategy: str, cutoff: unit.Quantity):
"""
Computes the long-range electrostatic energy for a molecular system
based on predicted partial charges and pairwise distances between atoms.
The implementation follows the methodology described in the PhysNet
paper, using a cutoff function to handle long-range interactions.
Parameters
----------
strategy : str
The strategy to be used for computing the long-range electrostatic
energy.
cutoff : unit.Quantity
The cutoff distance beyond which the interactions are not
considered.
Attributes
----------
strategy : str
The strategy for computing long-range interactions.
cutoff_function : nn.Module
The cutoff function applied to the pairwise distances.
"""
super().__init__()
from .utils import CosineAttenuationFunction

self.strategy = strategy
self.cutoff_function = CosineAttenuationFunction(cutoff)

def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Forward pass to compute the long-range electrostatic energy.
This function calculates the long-range electrostatic energy by considering
pairwise Coulomb interactions between atoms, applying a cutoff function to
handle long-range interactions.
Parameters
----------
data : Dict[str, torch.Tensor]
Input data containing the following keys:
- 'per_atom_charge': Tensor of shape (N,) with partial charges for each atom.
- 'atomic_subsystem_indices': Tensor indicating the molecule each atom belongs to.
- 'pairwise_properties': Object containing pairwise distances and indices.
Returns
-------
Dict[str, torch.Tensor]
The input data dictionary with an additional key 'long_range_electrostatic_energy'
containing the computed long-range electrostatic energy.
"""
per_atom_charge = data["per_atom_charge"]
mol_indices = data["atomic_subsystem_indices"]
pairwise_properties = data["pairwise_properties"]
idx_i, idx_j = pairwise_properties.pair_indices
pairwise_distances = pairwise_properties.d_ij

# Initialize the long-range electrostatic energy
long_range_energy = torch.zeros_like(per_atom_charge)

# Apply the cutoff function to pairwise distances
phi_2r = self.cutoff_function(2 * pairwise_distances)
chi_r = phi_2r * (1 / torch.sqrt(pairwise_distances**2 + 1)) + (
1 - phi_2r
) * (1 / pairwise_distances)

# Compute the Coulomb interaction term
coulomb_interactions = (
(per_atom_charge[idx_i] * per_atom_charge[idx_j])
* chi_r
/ pairwise_distances
)

# Zero out diagonal terms (self-interaction)
mask = torch.eye(
coulomb_interactions.size(0), device=coulomb_interactions.device
).bool()
coulomb_interactions.masked_fill_(mask, 0)

# Aggregate the interactions for each atom
coulomb_interactions_per_atom = torch.zeros_like(per_atom_charge).scatter_add_(
0, idx_j.long(), coulomb_interactions
)

data["per_atom_electrostatic_energy"] = coulomb_interactions_per_atom

return data
6 changes: 4 additions & 2 deletions modelforge/tests/data/potential_defaults/schnet.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ properties_to_featurize = ['atomic_number']
maximum_atomic_number = 101
number_of_per_atom_features = 32


[potential.postprocessing_parameter]
[potential.postprocessing_parameter.per_atom_energy]
normalize = true
Expand All @@ -30,8 +29,11 @@ keep_per_atom_property = true

[potential.postprocessing_parameter.per_atom_charge]
conserve = true
strategy = "physnet"
conserve_strategy = "default"
keep_per_atom_property = true
[potential.postprocessing_parameter.per_atom_charge.coulomb_potential]
electrostatic_strategy = "coulomb"
maximum_interaction_radius = "10.0 angstrom"

[potential.postprocessing_parameter.general_postprocessing_operation]
calculate_molecular_self_energy = true

0 comments on commit a454347

Please sign in to comment.