From f9ac633f1edfd7293f2bb326ca73e7e0c0a9504a Mon Sep 17 00:00:00 2001 From: wiederm Date: Sat, 17 Aug 2024 08:52:10 +0200 Subject: [PATCH 1/5] charge conservation --- modelforge/potential/models.py | 21 +++++- modelforge/potential/parameters.py | 17 +++-- modelforge/potential/processing.py | 72 +++++++++++++++++++ .../tests/data/potential_defaults/ani2x.toml | 8 ++- .../tests/data/potential_defaults/painn.toml | 6 ++ .../data/potential_defaults/physnet.toml | 6 ++ .../tests/data/potential_defaults/sake.toml | 6 ++ .../tests/data/potential_defaults/schnet.toml | 5 ++ 8 files changed, 132 insertions(+), 9 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 5285fa4c..b75952c6 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -774,8 +774,16 @@ class PostProcessing(torch.nn.Module): to compute per-molecule properties from per-atom properties. """ - _SUPPORTED_PROPERTIES = ["per_atom_energy", "general_postprocessing_operation"] - _SUPPORTED_OPERATIONS = ["normalize", "from_atom_to_molecule_reduction"] + _SUPPORTED_PROPERTIES = [ + "per_atom_energy", + "per_atom_charge", + "general_postprocessing_operation", + ] + _SUPPORTED_OPERATIONS = [ + "normalize", + "from_atom_to_molecule_reduction", + "conserve_integer_charge", + ] def __init__( self, @@ -859,6 +867,7 @@ def _initialize_postprocessing( FromAtomToMoleculeReduction, ScaleValues, CalculateAtomicSelfEnergy, + ChargeConservation, ) for property, operations in postprocessing_parameter.items(): @@ -875,7 +884,13 @@ def _initialize_postprocessing( prostprocessing_sequence_names = [] # for each property parse the requested operations - if property == "per_atom_energy": + if property == "per_atom_charge": + if operations.get("conserve_integer_charge", False): + postprocessing_sequence.append( + ChargeConservation(operations["strategy"]) + ) + prostprocessing_sequence_names.append("conserve_charge") + elif property == "per_atom_energy": if operations.get("normalize", False): ( mean, diff --git a/modelforge/potential/parameters.py b/modelforge/potential/parameters.py index 54e84d89..eb22a4d8 100644 --- a/modelforge/potential/parameters.py +++ b/modelforge/potential/parameters.py @@ -136,6 +136,12 @@ class PerAtomEnergy(ParametersBase): keep_per_atom_property: bool = False +class PerAtomCharge(ParametersBase): + conserve: bool = False + strategy: str = "default" + keep_per_atom_property: bool = False + + class ANI2xParameters(ParametersBase): class CoreParameter(ParametersBase): angle_sections: int @@ -157,6 +163,7 @@ class CoreParameter(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -189,6 +196,7 @@ class Featurization(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -201,11 +209,6 @@ class PostProcessingParameter(ParametersBase): class TensorNetParameters(ParametersBase): class CoreParameter(ParametersBase): - # class Featurization(ParametersBase): - # properties_to_featurize: List[str] - # max_Z: int - # number_of_per_atom_features: int - number_of_per_atom_features: int number_of_interaction_layers: int number_of_radial_basis_functions: int @@ -222,6 +225,7 @@ class CoreParameter(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -254,6 +258,7 @@ class Featurization(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -285,6 +290,7 @@ class Featurization(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -316,6 +322,7 @@ class Featurization(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index ee7da6a5..237467fd 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -274,6 +274,78 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: from typing import Union +class ChargeConservation(torch.nn.Module): + def __init__(self, method="physnet"): + + super().__init__() + self.method = method + + def forward( + self, + per_atom_partial_charge: torch.Tensor, + atomic_subsystem_indices: torch.Tensor, + total_charges: torch.Tensor, + ): + """ + Apply charge conservation to partial charges. + + Parameters + ---------- + per_atom_partial_charge : torch.Tensor + Flat tensor of partial charges for all atoms in the batch. + atomic_subsystem_indices : torch.Tensor + Tensor of integers indicating which molecule each atom belongs to. + total_charges : torch.Tensor + Tensor of desired total charges for each molecule. + + Returns + ------- + torch.Tensor + Tensor of corrected partial charges. + """ + if self.method == "physnet": + return self.physnet_charge_conservation( + per_atom_partial_charge, atomic_subsystem_indices, total_charges + ) + else: + raise ValueError(f"Unknown charge conservation method: {self.method}") + + def physnet_charge_conservation(self, partial_charges, mol_indices, total_charges): + """ + PhysNet charge conservation method based on equation 14 from the PhysNet + paper. + + Correct the partial charges such that their sum matches the desired + total charge for each molecule. + + Parameters + ---------- + partial_charges : torch.Tensor + Flat tensor of partial charges for all atoms in all molecules. + mol_indices : torch.Tensor + Tensor of integers indicating which molecule each atom belongs to. + total_charges : torch.Tensor + Tensor of desired total charges for each molecule. + + Returns + ------- + torch.Tensor + Tensor of corrected partial charges. + """ + # Calculate the sum of partial charges for each molecule + charge_sums = torch.zeros_like(total_charges).scatter_add_( + 0, mol_indices, partial_charges + ) + + # Calculate the correction factor for each molecule + correction_factors = (total_charges - charge_sums) / mol_indices.bincount() + + # Apply the correction to each atom's charge + corrected_charges = partial_charges + correction_factors[mol_indices] + + return corrected_charges + + class CalculateAtomicSelfEnergy(torch.nn.Module): """ Calculates the atomic self energy for each molecule. diff --git a/modelforge/tests/data/potential_defaults/ani2x.toml b/modelforge/tests/data/potential_defaults/ani2x.toml index 941f0a22..475eb9d7 100644 --- a/modelforge/tests/data/potential_defaults/ani2x.toml +++ b/modelforge/tests/data/potential_defaults/ani2x.toml @@ -16,7 +16,7 @@ predicted_properties = [ ] [potential.core_parameter.activation_function_parameter] -activation_function_name = "CeLU" # for the original ANI behavior please stick with CeLu since the alpha parameter is currently hard coded and might lead to different behavior when another activation function is used. +activation_function_name = "CeLU" # for the original ANI behavior please stick with CeLu since the alpha parameter is currently hard coded and might lead to different behavior when another activation function is used. [potential.core_parameter.activation_function_parameter.activation_function_arguments] alpha = 0.1 @@ -26,3 +26,9 @@ alpha = 0.1 normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +from_atom_to_molecule_reduction = false +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/painn.toml b/modelforge/tests/data/potential_defaults/painn.toml index 91dc7a75..bf78c30a 100644 --- a/modelforge/tests/data/potential_defaults/painn.toml +++ b/modelforge/tests/data/potential_defaults/painn.toml @@ -26,3 +26,9 @@ number_of_per_atom_features = 32 normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +from_atom_to_molecule_reduction = false +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/physnet.toml b/modelforge/tests/data/potential_defaults/physnet.toml index 931d9e74..aace1c8a 100644 --- a/modelforge/tests/data/potential_defaults/physnet.toml +++ b/modelforge/tests/data/potential_defaults/physnet.toml @@ -26,3 +26,9 @@ number_of_per_atom_features = 32 normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve_integer_charge = true +strategy = "physnet" +from_atom_to_molecule_reduction = false +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/sake.toml b/modelforge/tests/data/potential_defaults/sake.toml index 11eeafed..d1b350b7 100644 --- a/modelforge/tests/data/potential_defaults/sake.toml +++ b/modelforge/tests/data/potential_defaults/sake.toml @@ -26,3 +26,9 @@ maximum_atomic_number = 101 normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +from_atom_to_molecule_reduction = false +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index 7d84efc1..b9836849 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -36,3 +36,8 @@ keep_per_atom_property = true [potential.postprocessing_parameter.general_postprocessing_operation] calculate_molecular_self_energy = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +keep_per_atom_property = true From c099d416ff66f98f48da324fb937cfb4838fb40b Mon Sep 17 00:00:00 2001 From: wiederm Date: Sat, 17 Aug 2024 09:56:06 +0200 Subject: [PATCH 2/5] partial charge correction --- modelforge/potential/models.py | 30 +++++++++++++++++-- modelforge/potential/processing.py | 27 +++++++++-------- .../tests/data/potential_defaults/schnet.toml | 10 ++----- modelforge/tests/test_models.py | 10 +++++++ 4 files changed, 54 insertions(+), 23 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index b75952c6..7f7d42f0 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -885,7 +885,7 @@ def _initialize_postprocessing( # for each property parse the requested operations if property == "per_atom_charge": - if operations.get("conserve_integer_charge", False): + if operations.get("conserve", False): postprocessing_sequence.append( ChargeConservation(operations["strategy"]) ) @@ -1125,6 +1125,27 @@ def prepare_pairwise_properties(self, data): self.compute_interacting_pairs._input_checks(data) return self.compute_interacting_pairs.prepare_inputs(data) + def _add_addiontal_properties( + self, data, output: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """ + Add additional properties to the output dictionary. + + Parameters + ---------- + data : Union[NNPInput, NamedTuple] + The input data. + output: Dict[str, torch.Tensor] + The output dictionary to add properties to. + + Returns + ------- + Dict[str, torch.Tensor] + """ + + output["per_molecule_charge"] = data.total_charge + return output + def compute(self, data, core_input): """ Compute the core model's output. @@ -1143,7 +1164,7 @@ def compute(self, data, core_input): """ return self.core_module(data, core_input) - def forward(self, input_data: NNPInput): + def forward(self, input_data: NNPInput) -> Dict[str, torch.Tensor]: """ Executes the forward pass of the model. @@ -1165,8 +1186,11 @@ def forward(self, input_data: NNPInput): # compute all interacting pairs with distances pairwise_properties = self.prepare_pairwise_properties(input_data) # prepare the input for the forward pass - output = self.compute(input_data, pairwise_properties) + output = self.compute( + input_data, pairwise_properties + ) # FIXME: putput and processed_output are currently a dictionary, we really want to change this to a dataclass # perform postprocessing operations + output = self._add_addiontal_properties(input_data, output) processed_output = self.postprocessing(output) return processed_output diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index 237467fd..ad82cc3b 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -123,6 +123,7 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: property_per_molecule = property_per_molecule_zeros.scatter_reduce( 0, indices, per_atom_property, reduce=self.reduction_mode ) + data[self.output_name] = property_per_molecule if self.keep_per_atom_property is False: del data[self.per_atom_property_name] @@ -279,12 +280,14 @@ def __init__(self, method="physnet"): super().__init__() self.method = method + if self.method == "physnet": + self.correct_partial_charges = self.physnet_charge_conservation + else: + raise ValueError(f"Unknown charge conservation method: {self.method}") def forward( self, - per_atom_partial_charge: torch.Tensor, - atomic_subsystem_indices: torch.Tensor, - total_charges: torch.Tensor, + data: Dict[str, torch.Tensor], ): """ Apply charge conservation to partial charges. @@ -303,12 +306,12 @@ def forward( torch.Tensor Tensor of corrected partial charges. """ - if self.method == "physnet": - return self.physnet_charge_conservation( - per_atom_partial_charge, atomic_subsystem_indices, total_charges - ) - else: - raise ValueError(f"Unknown charge conservation method: {self.method}") + data["corrected_partial_charges"] = self.correct_partial_charges( + data["per_atom_charge"], + data["atomic_subsystem_indices"], + data["per_molecule_charge"], + ) + return data def physnet_charge_conservation(self, partial_charges, mol_indices, total_charges): """ @@ -333,9 +336,9 @@ def physnet_charge_conservation(self, partial_charges, mol_indices, total_charge Tensor of corrected partial charges. """ # Calculate the sum of partial charges for each molecule - charge_sums = torch.zeros_like(total_charges).scatter_add_( - 0, mol_indices, partial_charges - ) + charge_sums = torch.zeros_like(total_charges.long()).scatter_add_( + 0, mol_indices.long(), partial_charges.long() + ) # FIXME: why do I have to change the scr/dst to long? # Calculate the correction factor for each molecule correction_factors = (total_charges - charge_sums) / mol_indices.bincount() diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index b9836849..38f8ddf8 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -29,15 +29,9 @@ from_atom_to_molecule_reduction = true keep_per_atom_property = true [potential.postprocessing_parameter.per_atom_charge] -normalize = true -from_atom_to_molecule_reduction = false +conserve = true +strategy = "physnet" keep_per_atom_property = true - [potential.postprocessing_parameter.general_postprocessing_operation] calculate_molecular_self_energy = true - -[potential.postprocessing_parameter.per_atom_charge] -conserve = true -strategy = "physnet" -keep_per_atom_property = true diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 5d764595..69ec6335 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -450,6 +450,16 @@ def test_forward_pass( assert "per_atom_energy" in output assert "per_atom_charge" in output + # check that the per-atom charges sum to integer charges + per_molecule_charge = torch.zeros_like(output["per_molecule_energy"]).index_add_( + 0, nnp_input.atomic_subsystem_indices, output["per_atom_charge"] + ) + per_molecule_corrected_charge = torch.zeros_like( + output["per_molecule_energy"] + ).index_add_( + 0, nnp_input.atomic_subsystem_indices, output["per_atom_corrected_charge"] + ) + # the batch consists of methane (CH4) and amamonium (NH3) # which have chemically equivalent hydrogens at the minimum geometry. # This has to be reflected in the atomic energies E_i, which From ec89a5567d6e8336cce73725b875e44e5627c675 Mon Sep 17 00:00:00 2001 From: wiederm Date: Sat, 17 Aug 2024 21:23:09 +0200 Subject: [PATCH 3/5] adjusting neural charge equilibration --- modelforge/potential/models.py | 1 + modelforge/potential/processing.py | 38 ++++++++++++++++++++++++------ modelforge/tests/test_models.py | 13 ++++++++-- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 7f7d42f0..a4c258e7 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -1144,6 +1144,7 @@ def _add_addiontal_properties( """ output["per_molecule_charge"] = data.total_charge + output["pair_list"] = data.pair_list return output def compute(self, data, core_input): diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index ad82cc3b..23075039 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -306,14 +306,21 @@ def forward( torch.Tensor Tensor of corrected partial charges. """ - data["corrected_partial_charges"] = self.correct_partial_charges( + data["per_atom_charge_corrected"] = self.correct_partial_charges( data["per_atom_charge"], data["atomic_subsystem_indices"], data["per_molecule_charge"], + data["pair_list"], ) return data - def physnet_charge_conservation(self, partial_charges, mol_indices, total_charges): + def physnet_charge_conservation( + self, + partial_charges: torch.Tensor, + mol_indices: torch.Tensor, + total_charges: torch.Tensor, + pair_list: torch.Tensor, + ): """ PhysNet charge conservation method based on equation 14 from the PhysNet paper. @@ -335,16 +342,33 @@ def physnet_charge_conservation(self, partial_charges, mol_indices, total_charge torch.Tensor Tensor of corrected partial charges. """ + # the general approach here is outline in equation 14 in the PhysNet + # paper: For each atom in a given molecule, we calculate the sum over + # the partial charge of all other atoms in the molecule. The difference + # between the sum of partial charges for all atoms j plus the partial + # charge for atom i and the total charge of the molecule is the + # correction factor that has to be added to the partial charge of atom + # i. + # Calculate the sum of partial charges for each molecule - charge_sums = torch.zeros_like(total_charges.long()).scatter_add_( - 0, mol_indices.long(), partial_charges.long() - ) # FIXME: why do I have to change the scr/dst to long? + partial_charges_for_atom_j = partial_charges[pair_list[1]] + + # for each atom i, calculate the sum of partial charges for all other + per_atom_i_correction = torch.zeros( + partial_charges.shape, dtype=torch.float32, device=total_charges.device + ).scatter_add_(0, pair_list[1].long(), partial_charges_for_atom_j) # Calculate the correction factor for each molecule - correction_factors = (total_charges - charge_sums) / mol_indices.bincount() + correction_factors = ( + total_charges[mol_indices] - per_atom_i_correction + ) / mol_indices.bincount()[mol_indices] # Apply the correction to each atom's charge - corrected_charges = partial_charges + correction_factors[mol_indices] + corrected_charges = partial_charges - correction_factors + + charge_sums_corrected = torch.zeros( + total_charges.shape, dtype=torch.float32, device=total_charges.device + ).scatter_add_(0, mol_indices.long(), corrected_charges) return corrected_charges diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 69ec6335..3b9b0431 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -450,14 +450,23 @@ def test_forward_pass( assert "per_atom_energy" in output assert "per_atom_charge" in output - # check that the per-atom charges sum to integer charges + # retrieve per-molecule charge from per-atom (corrected) charges per_molecule_charge = torch.zeros_like(output["per_molecule_energy"]).index_add_( 0, nnp_input.atomic_subsystem_indices, output["per_atom_charge"] ) per_molecule_corrected_charge = torch.zeros_like( output["per_molecule_energy"] ).index_add_( - 0, nnp_input.atomic_subsystem_indices, output["per_atom_corrected_charge"] + 0, nnp_input.atomic_subsystem_indices, output["per_atom_charge_corrected"] + ) + per_molecule_charge_from_dataset = output["per_molecule_charge"].to(torch.float32) + + # make sure that corrected and uncorrected partial charges differ + assert not torch.allclose(per_molecule_charge, per_molecule_corrected_charge) + + # make sure that the experimental charges are the same as the sum over the corrected partial charges + assert torch.allclose( + per_molecule_charge_from_dataset, per_molecule_corrected_charge ) # the batch consists of methane (CH4) and amamonium (NH3) From 897d1b929c1876fe28b009d1ac2de3a6b8d8d926 Mon Sep 17 00:00:00 2001 From: wiederm Date: Sun, 18 Aug 2024 10:05:25 +0200 Subject: [PATCH 4/5] refactor test, finished charge equiibration, add docs --- modelforge/dataset/dataset.py | 3 +- modelforge/potential/models.py | 1 - modelforge/potential/processing.py | 36 ++-- .../data/potential_defaults/physnet.toml | 3 +- .../data/potential_defaults/tensornet.toml | 7 + modelforge/tests/test_models.py | 187 ++++++++++++------ 6 files changed, 147 insertions(+), 90 deletions(-) diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 53142d86..907a246f 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -208,12 +208,11 @@ def as_jax_namedtuple(self) -> NamedTuple: """Export the dataclass fields and values as a named tuple. Convert pytorch tensors to jax arrays.""" - from dataclasses import dataclass, fields + from dataclasses import fields import collections from modelforge.utils.io import import_ convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax - # from pytorch2jax.pytorch2jax import convert_to_jax NNPInputTuple = collections.namedtuple( "NNPInputTuple", [field.name for field in fields(self)] diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index a4c258e7..7f7d42f0 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -1144,7 +1144,6 @@ def _add_addiontal_properties( """ output["per_molecule_charge"] = data.total_charge - output["pair_list"] = data.pair_list return output def compute(self, data, core_input): diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index 23075039..385d289a 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -310,17 +310,15 @@ def forward( data["per_atom_charge"], data["atomic_subsystem_indices"], data["per_molecule_charge"], - data["pair_list"], ) return data def physnet_charge_conservation( self, - partial_charges: torch.Tensor, + per_atom_charge: torch.Tensor, mol_indices: torch.Tensor, total_charges: torch.Tensor, - pair_list: torch.Tensor, - ): + ) -> torch.Tensor: """ PhysNet charge conservation method based on equation 14 from the PhysNet paper. @@ -343,34 +341,30 @@ def physnet_charge_conservation( Tensor of corrected partial charges. """ # the general approach here is outline in equation 14 in the PhysNet - # paper: For each atom in a given molecule, we calculate the sum over - # the partial charge of all other atoms in the molecule. The difference - # between the sum of partial charges for all atoms j plus the partial - # charge for atom i and the total charge of the molecule is the - # correction factor that has to be added to the partial charge of atom - # i. + # paper: the difference between the sum of the predicted partial charges + # and the total charge is calculated and then distributed evenly among + # the predicted partial charges # Calculate the sum of partial charges for each molecule - partial_charges_for_atom_j = partial_charges[pair_list[1]] # for each atom i, calculate the sum of partial charges for all other - per_atom_i_correction = torch.zeros( - partial_charges.shape, dtype=torch.float32, device=total_charges.device - ).scatter_add_(0, pair_list[1].long(), partial_charges_for_atom_j) + predicted_per_molecule_charge = torch.zeros( + total_charges.shape, dtype=torch.float32, device=total_charges.device + ).scatter_add_(0, mol_indices.long(), per_atom_charge) # Calculate the correction factor for each molecule correction_factors = ( - total_charges[mol_indices] - per_atom_i_correction - ) / mol_indices.bincount()[mol_indices] + total_charges - predicted_per_molecule_charge + ) / mol_indices.bincount() # Apply the correction to each atom's charge - corrected_charges = partial_charges - correction_factors + per_atom_charge_corrected = per_atom_charge + correction_factors[mol_indices] - charge_sums_corrected = torch.zeros( - total_charges.shape, dtype=torch.float32, device=total_charges.device - ).scatter_add_(0, mol_indices.long(), corrected_charges) + per_molecule_corrected_charge = torch.zeros_like(per_atom_charge).scatter_add_( + 0, mol_indices.long(), per_atom_charge_corrected + ) - return corrected_charges + return per_atom_charge_corrected class CalculateAtomicSelfEnergy(torch.nn.Module): diff --git a/modelforge/tests/data/potential_defaults/physnet.toml b/modelforge/tests/data/potential_defaults/physnet.toml index aace1c8a..cd3b7c7f 100644 --- a/modelforge/tests/data/potential_defaults/physnet.toml +++ b/modelforge/tests/data/potential_defaults/physnet.toml @@ -28,7 +28,6 @@ from_atom_to_molecule_reduction = true keep_per_atom_property = true [potential.postprocessing_parameter.per_atom_charge] -conserve_integer_charge = true +conserve = true strategy = "physnet" -from_atom_to_molecule_reduction = false keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/tensornet.toml b/modelforge/tests/data/potential_defaults/tensornet.toml index 09158b5b..efe4cf71 100644 --- a/modelforge/tests/data/potential_defaults/tensornet.toml +++ b/modelforge/tests/data/potential_defaults/tensornet.toml @@ -19,9 +19,16 @@ predicted_properties = [ activation_function_name = "SiLU" [potential.postprocessing_parameter] + [potential.postprocessing_parameter.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +keep_per_atom_property = true + [potential.postprocessing_parameter.general_postprocessing_operation] calculate_molecular_self_energy = true diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 3b9b0431..e56b5032 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -3,6 +3,106 @@ from modelforge.potential import _Implemented_NNPs from modelforge.dataset import _ImplementedDatasets from modelforge.potential import NeuralNetworkPotentialFactory +import torch +from modelforge.utils.io import import_ + + +def initialize_model(simulation_environment: str, config): + """Initialize the model based on the simulation environment and configuration.""" + return NeuralNetworkPotentialFactory.generate_potential( + use="inference", + simulation_environment=simulation_environment, + potential_parameter=config["potential"], + ) + + +def prepare_input_for_model(nnp_input, model): + """Prepare the input for the model based on the simulation environment.""" + if "JAX" in str(type(model)): + return nnp_input.as_jax_namedtuple() + return nnp_input + + +def validate_output_shapes(output, nr_of_mols): + """Validate the output shapes to ensure they are correct.""" + assert len(output["per_molecule_energy"]) == nr_of_mols + assert "per_atom_energy" in output + assert "per_atom_charge" in output + assert "per_atom_charge_corrected" in output + + +def validate_charge_conservation( + per_molecule_charge: torch.Tensor, + per_molecule_charge_corrected: torch.Tensor, + per_molecule_charge_from_dataset: torch.Tensor, + model_name: str, +): + """Ensure charge conservation by validating the corrected charges.""" + + if "PhysNet".lower() in model_name.lower(): + print( + "Physnet starts with all zero partial charges" + ) # NOTE: I am not sure if this is correct + else: + assert not torch.allclose(per_molecule_charge, per_molecule_charge_corrected) + assert torch.allclose( + per_molecule_charge_from_dataset.to(torch.float32), + per_molecule_charge_corrected, + atol=1e-5, + ) + + +def validate_energy_conservation(output: torch.Tensor): + """Ensure that the total energy is the sum of atomic energies.""" + assert torch.allclose( + output["per_molecule_energy"][0], + output["per_atom_energy"][0:5].sum(dim=0), + atol=1e-5, + ) + assert torch.allclose( + output["per_molecule_energy"][1], + output["per_atom_energy"][5:9].sum(dim=0), + atol=1e-5, + ) + + +def validate_chemical_equivalence(output): + """Ensure that chemically equivalent hydrogens have equal energies.""" + assert torch.allclose( + output["per_atom_energy"][1:4], output["per_atom_energy"][1], atol=1e-4 + ) + assert torch.allclose( + output["per_atom_energy"][6:8], output["per_atom_energy"][6], atol=1e-4 + ) + + +def retrieve_molecular_charges(output, atomic_subsystem_indices): + """Retrieve per-molecule charge from per-atom charges.""" + per_molecule_charge = torch.zeros_like(output["per_molecule_energy"]).index_add_( + 0, atomic_subsystem_indices, output["per_atom_charge"] + ) + per_molecule_charge_corrected = torch.zeros_like( + output["per_molecule_energy"] + ).index_add_(0, atomic_subsystem_indices, output["per_atom_charge_corrected"]) + return per_molecule_charge, per_molecule_charge_corrected + + +def convert_to_pytorch_if_needed(output, nnp_input, model): + """Convert output to PyTorch tensors if the model is in JAX.""" + if "JAX" in str(type(model)): + convert_to_pyt = import_("pytorch2jax").pytorch2jax.convert_to_pyt + output["per_molecule_energy"] = convert_to_pyt(output["per_molecule_energy"]) + output["per_atom_charge"] = convert_to_pyt(output["per_atom_charge"]) + output["per_atom_charge_corrected"] = convert_to_pyt( + output["per_atom_charge_corrected"] + ) + output["per_molecule_charge"] = convert_to_pyt( + output["per_molecule_charge"] + ).to(torch.float32) + atomic_subsystem_indices = convert_to_pyt(nnp_input.atomic_subsystem_indices) + else: + atomic_subsystem_indices = nnp_input.atomic_subsystem_indices + return output, atomic_subsystem_indices def load_configs_into_pydantic_models(potential_name: str, dataset_name: str): @@ -424,78 +524,37 @@ def test_forward_pass( potential_name, simulation_environment, single_batch_with_batchsize_64 ): # this test sends a single batch from different datasets through the model - import torch + # get input and set up model nnp_input = single_batch_with_batchsize_64.nnp_input - - # read default parameters config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] + model = initialize_model(simulation_environment, config) + nnp_input = prepare_input_for_model(nnp_input, model) - # test the forward pass through each of the models - model = NeuralNetworkPotentialFactory.generate_potential( - use="inference", - simulation_environment=simulation_environment, - potential_parameter=config["potential"], - ) - if "JAX" in str(type(model)): - nnp_input = nnp_input.as_jax_namedtuple() - + # perform the forward pass through each of the models output = model(nnp_input) - # test that we get an energie per molecule - assert len(output["per_molecule_energy"]) == nr_of_mols - - # check that per-atom charge/energies are present - assert "per_atom_energy" in output - assert "per_atom_charge" in output - - # retrieve per-molecule charge from per-atom (corrected) charges - per_molecule_charge = torch.zeros_like(output["per_molecule_energy"]).index_add_( - 0, nnp_input.atomic_subsystem_indices, output["per_atom_charge"] + # validate the output + validate_output_shapes(output, nr_of_mols) + output, atomic_subsystem_indices = convert_to_pytorch_if_needed( + output, nnp_input, model ) - per_molecule_corrected_charge = torch.zeros_like( - output["per_molecule_energy"] - ).index_add_( - 0, nnp_input.atomic_subsystem_indices, output["per_atom_charge_corrected"] - ) - per_molecule_charge_from_dataset = output["per_molecule_charge"].to(torch.float32) - # make sure that corrected and uncorrected partial charges differ - assert not torch.allclose(per_molecule_charge, per_molecule_corrected_charge) - - # make sure that the experimental charges are the same as the sum over the corrected partial charges - assert torch.allclose( - per_molecule_charge_from_dataset, per_molecule_corrected_charge + # test that charge correction is working + per_molecule_charge, per_molecule_charge_corrected = retrieve_molecular_charges( + output, atomic_subsystem_indices ) - - # the batch consists of methane (CH4) and amamonium (NH3) - # which have chemically equivalent hydrogens at the minimum geometry. - # This has to be reflected in the atomic energies E_i, which - # have to be equal for all hydrogens - if "JAX" not in str(type(model)): - from loguru import logger as log - - # assert that the following tensor has equal values for dim=0 index 1 to 4 and 6 to 8 - - assert torch.allclose( - output["per_atom_energy"][1:4], output["per_atom_energy"][1], atol=1e-4 - ) - assert torch.allclose( - output["per_atom_energy"][6:8], output["per_atom_energy"][6], atol=1e-4 - ) - - # make sure that the total energy is \sum E_i - assert torch.allclose( - output["per_molecule_energy"][0], - output["per_atom_energy"][0:5].sum(dim=0), - atol=1e-5, - ) - assert torch.allclose( - output["per_molecule_energy"][1], - output["per_atom_energy"][5:9].sum(dim=0), - atol=1e-5, - ) + validate_charge_conservation( + per_molecule_charge, + per_molecule_charge_corrected, + output["per_molecule_charge"], + potential_name, + ) + # check that per-atom energies are correct + if "JAX" not in simulation_environment: + validate_chemical_equivalence(output) + validate_energy_conservation(output) @pytest.mark.parametrize( @@ -999,8 +1058,8 @@ def test_casting(potential_name, single_batch_with_batchsize_64): ) @pytest.mark.parametrize("simulation_environment", ["PyTorch"]) def test_equivariant_energies_and_forces( - potential_name, - simulation_environment, + potential_name: str, + simulation_environment: str, single_batch_with_batchsize_64, equivariance_utils, ): From f743d74d9cf7a48ecb821d3db22c5ab4e78972f0 Mon Sep 17 00:00:00 2001 From: wiederm Date: Sun, 18 Aug 2024 10:07:51 +0200 Subject: [PATCH 5/5] bugfix --- modelforge/potential/processing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index 385d289a..e94f41eb 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -349,7 +349,9 @@ def physnet_charge_conservation( # for each atom i, calculate the sum of partial charges for all other predicted_per_molecule_charge = torch.zeros( - total_charges.shape, dtype=torch.float32, device=total_charges.device + total_charges.shape, + dtype=per_atom_charge.dtype, + device=total_charges.device, ).scatter_add_(0, mol_indices.long(), per_atom_charge) # Calculate the correction factor for each molecule