Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Charge conservation #234

Merged
merged 5 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions modelforge/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,11 @@ def as_jax_namedtuple(self) -> NamedTuple:
"""Export the dataclass fields and values as a named tuple.
Convert pytorch tensors to jax arrays."""

from dataclasses import dataclass, fields
from dataclasses import fields
import collections
from modelforge.utils.io import import_

convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax
# from pytorch2jax.pytorch2jax import convert_to_jax

NNPInputTuple = collections.namedtuple(
"NNPInputTuple", [field.name for field in fields(self)]
Expand Down
49 changes: 44 additions & 5 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,8 +774,16 @@ class PostProcessing(torch.nn.Module):
to compute per-molecule properties from per-atom properties.
"""

_SUPPORTED_PROPERTIES = ["per_atom_energy", "general_postprocessing_operation"]
_SUPPORTED_OPERATIONS = ["normalize", "from_atom_to_molecule_reduction"]
_SUPPORTED_PROPERTIES = [
"per_atom_energy",
"per_atom_charge",
"general_postprocessing_operation",
]
_SUPPORTED_OPERATIONS = [
"normalize",
"from_atom_to_molecule_reduction",
"conserve_integer_charge",
]

def __init__(
self,
Expand Down Expand Up @@ -859,6 +867,7 @@ def _initialize_postprocessing(
FromAtomToMoleculeReduction,
ScaleValues,
CalculateAtomicSelfEnergy,
ChargeConservation,
)

for property, operations in postprocessing_parameter.items():
Expand All @@ -875,7 +884,13 @@ def _initialize_postprocessing(
prostprocessing_sequence_names = []

# for each property parse the requested operations
if property == "per_atom_energy":
if property == "per_atom_charge":
if operations.get("conserve", False):
postprocessing_sequence.append(
ChargeConservation(operations["strategy"])
)
prostprocessing_sequence_names.append("conserve_charge")
elif property == "per_atom_energy":
if operations.get("normalize", False):
(
mean,
Expand Down Expand Up @@ -1110,6 +1125,27 @@ def prepare_pairwise_properties(self, data):
self.compute_interacting_pairs._input_checks(data)
return self.compute_interacting_pairs.prepare_inputs(data)

def _add_addiontal_properties(
self, data, output: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
Add additional properties to the output dictionary.

Parameters
----------
data : Union[NNPInput, NamedTuple]
The input data.
output: Dict[str, torch.Tensor]
The output dictionary to add properties to.

Returns
-------
Dict[str, torch.Tensor]
"""

output["per_molecule_charge"] = data.total_charge
return output

def compute(self, data, core_input):
"""
Compute the core model's output.
Expand All @@ -1128,7 +1164,7 @@ def compute(self, data, core_input):
"""
return self.core_module(data, core_input)

def forward(self, input_data: NNPInput):
def forward(self, input_data: NNPInput) -> Dict[str, torch.Tensor]:
"""
Executes the forward pass of the model.

Expand All @@ -1150,8 +1186,11 @@ def forward(self, input_data: NNPInput):
# compute all interacting pairs with distances
pairwise_properties = self.prepare_pairwise_properties(input_data)
# prepare the input for the forward pass
output = self.compute(input_data, pairwise_properties)
output = self.compute(
input_data, pairwise_properties
) # FIXME: putput and processed_output are currently a dictionary, we really want to change this to a dataclass
# perform postprocessing operations
output = self._add_addiontal_properties(input_data, output)
processed_output = self.postprocessing(output)
return processed_output

Expand Down
17 changes: 12 additions & 5 deletions modelforge/potential/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ class PerAtomEnergy(ParametersBase):
keep_per_atom_property: bool = False


class PerAtomCharge(ParametersBase):
conserve: bool = False
strategy: str = "default"
keep_per_atom_property: bool = False


class ANI2xParameters(ParametersBase):
class CoreParameter(ParametersBase):
angle_sections: int
Expand All @@ -157,6 +163,7 @@ class CoreParameter(ParametersBase):

class PostProcessingParameter(ParametersBase):
per_atom_energy: PerAtomEnergy = PerAtomEnergy()
per_atom_charge: PerAtomCharge = PerAtomCharge()
general_postprocessing_operation: GeneralPostProcessingOperation = (
GeneralPostProcessingOperation()
)
Expand Down Expand Up @@ -189,6 +196,7 @@ class Featurization(ParametersBase):

class PostProcessingParameter(ParametersBase):
per_atom_energy: PerAtomEnergy = PerAtomEnergy()
per_atom_charge: PerAtomCharge = PerAtomCharge()
general_postprocessing_operation: GeneralPostProcessingOperation = (
GeneralPostProcessingOperation()
)
Expand All @@ -201,11 +209,6 @@ class PostProcessingParameter(ParametersBase):

class TensorNetParameters(ParametersBase):
class CoreParameter(ParametersBase):
# class Featurization(ParametersBase):
# properties_to_featurize: List[str]
# max_Z: int
# number_of_per_atom_features: int

number_of_per_atom_features: int
number_of_interaction_layers: int
number_of_radial_basis_functions: int
Expand All @@ -222,6 +225,7 @@ class CoreParameter(ParametersBase):

class PostProcessingParameter(ParametersBase):
per_atom_energy: PerAtomEnergy = PerAtomEnergy()
per_atom_charge: PerAtomCharge = PerAtomCharge()
general_postprocessing_operation: GeneralPostProcessingOperation = (
GeneralPostProcessingOperation()
)
Expand Down Expand Up @@ -254,6 +258,7 @@ class Featurization(ParametersBase):

class PostProcessingParameter(ParametersBase):
per_atom_energy: PerAtomEnergy = PerAtomEnergy()
per_atom_charge: PerAtomCharge = PerAtomCharge()
general_postprocessing_operation: GeneralPostProcessingOperation = (
GeneralPostProcessingOperation()
)
Expand Down Expand Up @@ -285,6 +290,7 @@ class Featurization(ParametersBase):

class PostProcessingParameter(ParametersBase):
per_atom_energy: PerAtomEnergy = PerAtomEnergy()
per_atom_charge: PerAtomCharge = PerAtomCharge()
general_postprocessing_operation: GeneralPostProcessingOperation = (
GeneralPostProcessingOperation()
)
Expand Down Expand Up @@ -316,6 +322,7 @@ class Featurization(ParametersBase):

class PostProcessingParameter(ParametersBase):
per_atom_energy: PerAtomEnergy = PerAtomEnergy()
per_atom_charge: PerAtomCharge = PerAtomCharge()
general_postprocessing_operation: GeneralPostProcessingOperation = (
GeneralPostProcessingOperation()
)
Expand Down
95 changes: 95 additions & 0 deletions modelforge/potential/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
property_per_molecule = property_per_molecule_zeros.scatter_reduce(
0, indices, per_atom_property, reduce=self.reduction_mode
)

data[self.output_name] = property_per_molecule
if self.keep_per_atom_property is False:
del data[self.per_atom_property_name]
Expand Down Expand Up @@ -274,6 +275,100 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
from typing import Union


class ChargeConservation(torch.nn.Module):
def __init__(self, method="physnet"):

super().__init__()
self.method = method
if self.method == "physnet":
self.correct_partial_charges = self.physnet_charge_conservation
else:
raise ValueError(f"Unknown charge conservation method: {self.method}")

def forward(
self,
data: Dict[str, torch.Tensor],
):
"""
Apply charge conservation to partial charges.

Parameters
----------
per_atom_partial_charge : torch.Tensor
Flat tensor of partial charges for all atoms in the batch.
atomic_subsystem_indices : torch.Tensor
Tensor of integers indicating which molecule each atom belongs to.
total_charges : torch.Tensor
Tensor of desired total charges for each molecule.

Returns
-------
torch.Tensor
Tensor of corrected partial charges.
"""
data["per_atom_charge_corrected"] = self.correct_partial_charges(
data["per_atom_charge"],
data["atomic_subsystem_indices"],
data["per_molecule_charge"],
)
return data

def physnet_charge_conservation(
self,
per_atom_charge: torch.Tensor,
mol_indices: torch.Tensor,
total_charges: torch.Tensor,
) -> torch.Tensor:
"""
PhysNet charge conservation method based on equation 14 from the PhysNet
paper.

Correct the partial charges such that their sum matches the desired
total charge for each molecule.

Parameters
----------
partial_charges : torch.Tensor
Flat tensor of partial charges for all atoms in all molecules.
mol_indices : torch.Tensor
Tensor of integers indicating which molecule each atom belongs to.
total_charges : torch.Tensor
Tensor of desired total charges for each molecule.

Returns
-------
torch.Tensor
Tensor of corrected partial charges.
"""
# the general approach here is outline in equation 14 in the PhysNet
# paper: the difference between the sum of the predicted partial charges
# and the total charge is calculated and then distributed evenly among
# the predicted partial charges

# Calculate the sum of partial charges for each molecule

# for each atom i, calculate the sum of partial charges for all other
predicted_per_molecule_charge = torch.zeros(
total_charges.shape,
dtype=per_atom_charge.dtype,
device=total_charges.device,
).scatter_add_(0, mol_indices.long(), per_atom_charge)

# Calculate the correction factor for each molecule
correction_factors = (
total_charges - predicted_per_molecule_charge
) / mol_indices.bincount()

# Apply the correction to each atom's charge
per_atom_charge_corrected = per_atom_charge + correction_factors[mol_indices]

per_molecule_corrected_charge = torch.zeros_like(per_atom_charge).scatter_add_(
0, mol_indices.long(), per_atom_charge_corrected
)

return per_atom_charge_corrected


class CalculateAtomicSelfEnergy(torch.nn.Module):
"""
Calculates the atomic self energy for each molecule.
Expand Down
8 changes: 7 additions & 1 deletion modelforge/tests/data/potential_defaults/ani2x.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ predicted_properties = [
]

[potential.core_parameter.activation_function_parameter]
activation_function_name = "CeLU" # for the original ANI behavior please stick with CeLu since the alpha parameter is currently hard coded and might lead to different behavior when another activation function is used.
activation_function_name = "CeLU" # for the original ANI behavior please stick with CeLu since the alpha parameter is currently hard coded and might lead to different behavior when another activation function is used.

[potential.core_parameter.activation_function_parameter.activation_function_arguments]
alpha = 0.1
Expand All @@ -26,3 +26,9 @@ alpha = 0.1
normalize = true
from_atom_to_molecule_reduction = true
keep_per_atom_property = true

[potential.postprocessing_parameter.per_atom_charge]
conserve = true
strategy = "physnet"
from_atom_to_molecule_reduction = false
keep_per_atom_property = true
6 changes: 6 additions & 0 deletions modelforge/tests/data/potential_defaults/painn.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,9 @@ number_of_per_atom_features = 32
normalize = true
from_atom_to_molecule_reduction = true
keep_per_atom_property = true

[potential.postprocessing_parameter.per_atom_charge]
conserve = true
strategy = "physnet"
from_atom_to_molecule_reduction = false
keep_per_atom_property = true
5 changes: 5 additions & 0 deletions modelforge/tests/data/potential_defaults/physnet.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ number_of_per_atom_features = 32
normalize = true
from_atom_to_molecule_reduction = true
keep_per_atom_property = true

[potential.postprocessing_parameter.per_atom_charge]
conserve = true
strategy = "physnet"
keep_per_atom_property = true
6 changes: 6 additions & 0 deletions modelforge/tests/data/potential_defaults/sake.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,9 @@ maximum_atomic_number = 101
normalize = true
from_atom_to_molecule_reduction = true
keep_per_atom_property = true

[potential.postprocessing_parameter.per_atom_charge]
conserve = true
strategy = "physnet"
from_atom_to_molecule_reduction = false
keep_per_atom_property = true
5 changes: 2 additions & 3 deletions modelforge/tests/data/potential_defaults/schnet.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ from_atom_to_molecule_reduction = true
keep_per_atom_property = true

[potential.postprocessing_parameter.per_atom_charge]
normalize = true
from_atom_to_molecule_reduction = false
conserve = true
strategy = "physnet"
keep_per_atom_property = true


[potential.postprocessing_parameter.general_postprocessing_operation]
calculate_molecular_self_energy = true
7 changes: 7 additions & 0 deletions modelforge/tests/data/potential_defaults/tensornet.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@ predicted_properties = [
activation_function_name = "SiLU"

[potential.postprocessing_parameter]

[potential.postprocessing_parameter.per_atom_energy]
normalize = true
from_atom_to_molecule_reduction = true
keep_per_atom_property = true

[potential.postprocessing_parameter.per_atom_charge]
conserve = true
strategy = "physnet"
keep_per_atom_property = true

[potential.postprocessing_parameter.general_postprocessing_operation]
calculate_molecular_self_energy = true
Loading