From a45434700c23eee04fa46e38460592a80588c5a6 Mon Sep 17 00:00:00 2001 From: wiederm Date: Sun, 18 Aug 2024 20:03:26 +0200 Subject: [PATCH] implementation dampled coulomb potential --- modelforge/potential/models.py | 44 ++++++- modelforge/potential/parameters.py | 12 +- modelforge/potential/processing.py | 121 +++++++++++++++--- .../tests/data/potential_defaults/schnet.toml | 6 +- 4 files changed, 157 insertions(+), 26 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 7f7d42f0..bb4cb680 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -782,7 +782,7 @@ class PostProcessing(torch.nn.Module): _SUPPORTED_OPERATIONS = [ "normalize", "from_atom_to_molecule_reduction", - "conserve_integer_charge", + "long_range_electrostatics", "conserve_integer_charge", ] def __init__( @@ -868,6 +868,7 @@ def _initialize_postprocessing( ScaleValues, CalculateAtomicSelfEnergy, ChargeConservation, + LongRangeElectrostaticEnergy, ) for property, operations in postprocessing_parameter.items(): @@ -887,9 +888,32 @@ def _initialize_postprocessing( if property == "per_atom_charge": if operations.get("conserve", False): postprocessing_sequence.append( - ChargeConservation(operations["strategy"]) + ChargeConservation(operations["conserve_strategy"]) ) prostprocessing_sequence_names.append("conserve_charge") + + if operations.get("coulomb_potential", False): + coulomb_potential = operations["coulomb_potential"] + postprocessing_sequence.append( + LongRangeElectrostaticEnergy( + coulomb_potential["electrostatic_strategy"], + coulomb_potential["maximum_interaction_radius"], + ) + ) + prostprocessing_sequence_names.append("coulomb_potential") + + if coulomb_potential.get("from_atom_to_molecule_reduction", False): + postprocessing_sequence.append( + FromAtomToMoleculeReduction( + per_atom_property_name="per_atom_energy", + index_name="atomic_subsystem_indices", + output_name="per_molecule_energy", 
keep_per_atom_property=operations.get( + "keep_per_atom_property", False + ), + ) + ) + elif property == "per_atom_energy": if operations.get("normalize", False): ( @@ -990,6 +1014,10 @@ def forward(self, data: Dict[str, torch.Tensor]): if property in self._registered_properties: self.registered_chained_operations[property](data) + # delete the pairwise property object before returning + if 'pairwise_properties' in data: + del data['pairwise_properties'] + return data @@ -1126,7 +1154,10 @@ def prepare_pairwise_properties(self, data): return self.compute_interacting_pairs.prepare_inputs(data) def _add_addiontal_properties( - self, data, output: Dict[str, torch.Tensor] + self, + data, + output: Dict[str, torch.Tensor], + pairwise_properties: PairListOutputs, ) -> Dict[str, torch.Tensor]: """ Add additional properties to the output dictionary. @@ -1144,6 +1175,7 @@ def _add_addiontal_properties( """ output["per_molecule_charge"] = data.total_charge + output["pairwise_properties"] = pairwise_properties return output def compute(self, data, core_input): @@ -1190,7 +1222,11 @@ def forward(self, input_data: NNPInput) -> Dict[str, torch.Tensor]: input_data, pairwise_properties ) # FIXME: putput and processed_output are currently a dictionary, we really want to change this to a dataclass # perform postprocessing operations - output = self._add_addiontal_properties(input_data, output) + output = self._add_addiontal_properties( + input_data, + output, + pairwise_properties, + ) processed_output = self.postprocessing(output) return processed_output diff --git a/modelforge/potential/parameters.py b/modelforge/potential/parameters.py index eb22a4d8..7164cecb 100644 --- a/modelforge/potential/parameters.py +++ b/modelforge/potential/parameters.py @@ -136,10 +136,20 @@ class PerAtomEnergy(ParametersBase): keep_per_atom_property: bool = False +class CoulomPotential(ParametersBase): + electrostatic_strategy: str = "coulomb" + maximum_interaction_radius: Union[str, unit.Quantity] + + 
converted_units = field_validator( + "maximum_interaction_radius", + )(_convert_str_to_unit) + + class PerAtomCharge(ParametersBase): conserve: bool = False - strategy: str = "default" + conserve_strategy: str = "default" keep_per_atom_property: bool = False + coulomb_potential: Optional[CoulomPotential] = None class ANI2xParameters(ParametersBase): diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index e94f41eb..e0aef7f1 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -2,10 +2,16 @@ This module contains utility functions and classes for processing the output of the potential model. """ +from dataclasses import dataclass, field +from typing import Dict, Iterator, List, Type, Union + import torch -from typing import Dict from openff.units import unit +from modelforge.dataset.utils import _ATOMIC_NUMBER_TO_ELEMENT + +from .models import PairListOutputs + def load_atomic_self_energies(path: str) -> Dict[str, unit.Quantity]: """ @@ -131,13 +137,6 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return data -from dataclasses import dataclass, field -from typing import Dict, Iterator - -from openff.units import unit -from modelforge.dataset.utils import _ATOMIC_NUMBER_TO_ELEMENT - - @dataclass class AtomicSelfEnergies: """ @@ -272,16 +271,13 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return data -from typing import Union - - class ChargeConservation(torch.nn.Module): - def __init__(self, method="physnet"): + def __init__(self, method="default"): super().__init__() self.method = method - if self.method == "physnet": - self.correct_partial_charges = self.physnet_charge_conservation + if self.method == "default": + self.correct_partial_charges = self.default_charge_conservation else: raise ValueError(f"Unknown charge conservation method: {self.method}") @@ -313,7 +309,7 @@ def forward( ) return data - def 
physnet_charge_conservation( + def default_charge_conservation( self, per_atom_charge: torch.Tensor, mol_indices: torch.Tensor, @@ -362,10 +358,6 @@ def physnet_charge_conservation( # Apply the correction to each atom's charge per_atom_charge_corrected = per_atom_charge + correction_factors[mol_indices] - per_molecule_corrected_charge = torch.zeros_like(per_atom_charge).scatter_add_( - 0, mol_indices.long(), per_atom_charge_corrected - ) - return per_atom_charge_corrected @@ -429,3 +421,94 @@ def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: data["ase_tensor"] = ase_tensor return data + + +class LongRangeElectrostaticEnergy(torch.nn.Module): + def __init__(self, strategy: str, cutoff: unit.Quantity): + """ + Computes the long-range electrostatic energy for a molecular system + based on predicted partial charges and pairwise distances between atoms. + + The implementation follows the methodology described in the PhysNet + paper, using a cutoff function to handle long-range interactions. + + Parameters + ---------- + strategy : str + The strategy to be used for computing the long-range electrostatic + energy. + cutoff : unit.Quantity + The cutoff distance beyond which the interactions are not + considered. + + Attributes + ---------- + strategy : str + The strategy for computing long-range interactions. + cutoff_function : nn.Module + The cutoff function applied to the pairwise distances. + """ + super().__init__() + from .utils import CosineAttenuationFunction + + self.strategy = strategy + self.cutoff_function = CosineAttenuationFunction(cutoff) + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass to compute the long-range electrostatic energy. + + This function calculates the long-range electrostatic energy by considering + pairwise Coulomb interactions between atoms, applying a cutoff function to + handle long-range interactions. 
+ + Parameters + ---------- + data : Dict[str, torch.Tensor] + Input data containing the following keys: + - 'per_atom_charge': Tensor of shape (N,) with partial charges for each atom. + - 'pairwise_properties': Object providing `pair_indices` and `d_ij` + (the pair index tensors and the pairwise distances). + + Returns + ------- + Dict[str, torch.Tensor] + The input data dictionary with an additional key 'per_atom_electrostatic_energy' + containing the per-atom long-range electrostatic energy. + """ + per_atom_charge = data["per_atom_charge"] + pairwise_properties = data["pairwise_properties"] + idx_i, idx_j = pairwise_properties.pair_indices + pairwise_distances = pairwise_properties.d_ij + + # Apply the cutoff function to pairwise distances: chi(r) blends + # 1/sqrt(r^2+1) and 1/r, weighted by phi(2r) and its complement + phi_2r = self.cutoff_function(2 * pairwise_distances) + chi_r = phi_2r * (1 / torch.sqrt(pairwise_distances**2 + 1)) + ( + 1 - phi_2r + ) * (1 / pairwise_distances) + + # Compute the Coulomb interaction term; chi_r already carries the + # distance dependence, so it must not be divided by the distance + # again. chi_r is flattened so that (n_pairs,) and (n_pairs, 1) + # distance tensors both yield exactly one value per pair + coulomb_interactions = ( + per_atom_charge[idx_i] * per_atom_charge[idx_j] * chi_r.reshape(-1) + ) + + # Zero out self-interaction terms; the interactions are stored per + # pair, so the mask has to be per pair as well + self_pair_mask = idx_i == idx_j + coulomb_interactions = coulomb_interactions.masked_fill(self_pair_mask, 0) + + # Aggregate the interactions for each atom; the energy of every pair + # is accumulated on its second atom (idx_j) + coulomb_interactions_per_atom = torch.zeros_like(per_atom_charge).scatter_add_( + 0, idx_j.long(), coulomb_interactions + ) + + # NOTE(review): no Coulomb constant or unit conversion is applied + # here; confirm whether the energy scale is handled elsewhere + data["per_atom_electrostatic_energy"] = coulomb_interactions_per_atom + + return data diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index 38f8ddf8..f750b020 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ 
b/modelforge/tests/data/potential_defaults/schnet.toml @@ -21,7 +21,6 @@ properties_to_featurize = ['atomic_number'] maximum_atomic_number = 101 number_of_per_atom_features = 32 - [potential.postprocessing_parameter] [potential.postprocessing_parameter.per_atom_energy] normalize = true @@ -30,8 +29,11 @@ keep_per_atom_property = true [potential.postprocessing_parameter.per_atom_charge] conserve = true -strategy = "physnet" +conserve_strategy = "default" keep_per_atom_property = true +[potential.postprocessing_parameter.per_atom_charge.coulomb_potential] +electrostatic_strategy = "coulomb" +maximum_interaction_radius = "10.0 angstrom" [potential.postprocessing_parameter.general_postprocessing_operation] calculate_molecular_self_energy = true