diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 53142d86..907a246f 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -208,12 +208,11 @@ def as_jax_namedtuple(self) -> NamedTuple: """Export the dataclass fields and values as a named tuple. Convert pytorch tensors to jax arrays.""" - from dataclasses import dataclass, fields + from dataclasses import fields import collections from modelforge.utils.io import import_ convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax - # from pytorch2jax.pytorch2jax import convert_to_jax NNPInputTuple = collections.namedtuple( "NNPInputTuple", [field.name for field in fields(self)] diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 5285fa4c..7f7d42f0 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -774,8 +774,16 @@ class PostProcessing(torch.nn.Module): to compute per-molecule properties from per-atom properties. """ - _SUPPORTED_PROPERTIES = ["per_atom_energy", "general_postprocessing_operation"] - _SUPPORTED_OPERATIONS = ["normalize", "from_atom_to_molecule_reduction"] + _SUPPORTED_PROPERTIES = [ + "per_atom_energy", + "per_atom_charge", + "general_postprocessing_operation", + ] + _SUPPORTED_OPERATIONS = [ + "normalize", + "from_atom_to_molecule_reduction", + "conserve_integer_charge", + ] def __init__( self, @@ -859,6 +867,7 @@ def _initialize_postprocessing( FromAtomToMoleculeReduction, ScaleValues, CalculateAtomicSelfEnergy, + ChargeConservation, ) for property, operations in postprocessing_parameter.items(): @@ -875,7 +884,13 @@ def _initialize_postprocessing( prostprocessing_sequence_names = [] # for each property parse the requested operations - if property == "per_atom_energy": + if property == "per_atom_charge": + if operations.get("conserve", False): + postprocessing_sequence.append( + ChargeConservation(operations["strategy"]) + ) + prostprocessing_sequence_names.append("conserve_charge") + elif property == "per_atom_energy": if operations.get("normalize", False): ( mean, @@ -1110,6 +1125,27 @@ def prepare_pairwise_properties(self, data): self.compute_interacting_pairs._input_checks(data) return self.compute_interacting_pairs.prepare_inputs(data) + def _add_addiontal_properties( + self, data, output: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """ + Add additional properties to the output dictionary. + + Parameters + ---------- + data : Union[NNPInput, NamedTuple] + The input data. + output: Dict[str, torch.Tensor] + The output dictionary to add properties to. + + Returns + ------- + Dict[str, torch.Tensor] + """ + + output["per_molecule_charge"] = data.total_charge + return output + def compute(self, data, core_input): """ Compute the core model's output. @@ -1128,7 +1164,7 @@ def compute(self, data, core_input): """ return self.core_module(data, core_input) - def forward(self, input_data: NNPInput): + def forward(self, input_data: NNPInput) -> Dict[str, torch.Tensor]: """ Executes the forward pass of the model. @@ -1150,8 +1186,11 @@ def forward(self, input_data: NNPInput): # compute all interacting pairs with distances pairwise_properties = self.prepare_pairwise_properties(input_data) # prepare the input for the forward pass - output = self.compute(input_data, pairwise_properties) + output = self.compute( + input_data, pairwise_properties + ) # FIXME: putput and processed_output are currently a dictionary, we really want to change this to a dataclass # perform postprocessing operations + output = self._add_addiontal_properties(input_data, output) processed_output = self.postprocessing(output) return processed_output diff --git a/modelforge/potential/parameters.py b/modelforge/potential/parameters.py index 54e84d89..eb22a4d8 100644 --- a/modelforge/potential/parameters.py +++ b/modelforge/potential/parameters.py @@ -136,6 +136,12 @@ class PerAtomEnergy(ParametersBase): keep_per_atom_property: bool = False +class PerAtomCharge(ParametersBase): + conserve: bool = False + strategy: str = "default" + keep_per_atom_property: bool = False + + class ANI2xParameters(ParametersBase): class CoreParameter(ParametersBase): angle_sections: int @@ -157,6 +163,7 @@ class CoreParameter(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -189,6 +196,7 @@ class Featurization(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -201,11 +209,6 @@ class PostProcessingParameter(ParametersBase): class TensorNetParameters(ParametersBase): class CoreParameter(ParametersBase): - # class Featurization(ParametersBase): - # properties_to_featurize: List[str] - # max_Z: int - # number_of_per_atom_features: int - number_of_per_atom_features: int number_of_interaction_layers: int number_of_radial_basis_functions: int @@ -222,6 +225,7 @@ class CoreParameter(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -254,6 +258,7 @@ class Featurization(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -285,6 +290,7 @@ class Featurization(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) @@ -316,6 +322,7 @@ class Featurization(ParametersBase): class PostProcessingParameter(ParametersBase): per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() general_postprocessing_operation: GeneralPostProcessingOperation = ( GeneralPostProcessingOperation() ) diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index ee7da6a5..e94f41eb 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -123,6 +123,7 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: property_per_molecule = property_per_molecule_zeros.scatter_reduce( 0, indices, per_atom_property, reduce=self.reduction_mode ) + data[self.output_name] = property_per_molecule if self.keep_per_atom_property is False: del data[self.per_atom_property_name] @@ -274,6 +275,100 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: from typing import Union +class ChargeConservation(torch.nn.Module): + def __init__(self, method="physnet"): + + super().__init__() + self.method = method + if self.method == "physnet": + self.correct_partial_charges = self.physnet_charge_conservation + else: + raise ValueError(f"Unknown charge conservation method: {self.method}") + + def forward( + self, + data: Dict[str, torch.Tensor], + ): + """ + Apply charge conservation to partial charges. + + Parameters + ---------- + per_atom_partial_charge : torch.Tensor + Flat tensor of partial charges for all atoms in the batch. + atomic_subsystem_indices : torch.Tensor + Tensor of integers indicating which molecule each atom belongs to. + total_charges : torch.Tensor + Tensor of desired total charges for each molecule. + + Returns + ------- + torch.Tensor + Tensor of corrected partial charges. + """ + data["per_atom_charge_corrected"] = self.correct_partial_charges( + data["per_atom_charge"], + data["atomic_subsystem_indices"], + data["per_molecule_charge"], + ) + return data + + def physnet_charge_conservation( + self, + per_atom_charge: torch.Tensor, + mol_indices: torch.Tensor, + total_charges: torch.Tensor, + ) -> torch.Tensor: + """ + PhysNet charge conservation method based on equation 14 from the PhysNet + paper. + + Correct the partial charges such that their sum matches the desired + total charge for each molecule. + + Parameters + ---------- + partial_charges : torch.Tensor + Flat tensor of partial charges for all atoms in all molecules. + mol_indices : torch.Tensor + Tensor of integers indicating which molecule each atom belongs to. + total_charges : torch.Tensor + Tensor of desired total charges for each molecule. + + Returns + ------- + torch.Tensor + Tensor of corrected partial charges. + """ + # the general approach here is outline in equation 14 in the PhysNet + # paper: the difference between the sum of the predicted partial charges + # and the total charge is calculated and then distributed evenly among + # the predicted partial charges + + # Calculate the sum of partial charges for each molecule + + # for each atom i, calculate the sum of partial charges for all other + predicted_per_molecule_charge = torch.zeros( + total_charges.shape, + dtype=per_atom_charge.dtype, + device=total_charges.device, + ).scatter_add_(0, mol_indices.long(), per_atom_charge) + + # Calculate the correction factor for each molecule + correction_factors = ( + total_charges - predicted_per_molecule_charge + ) / mol_indices.bincount() + + # Apply the correction to each atom's charge + per_atom_charge_corrected = per_atom_charge + correction_factors[mol_indices] + + per_molecule_corrected_charge = torch.zeros_like(per_atom_charge).scatter_add_( + 0, mol_indices.long(), per_atom_charge_corrected + ) + + return per_atom_charge_corrected + + class CalculateAtomicSelfEnergy(torch.nn.Module): """ Calculates the atomic self energy for each molecule. diff --git a/modelforge/tests/data/potential_defaults/ani2x.toml b/modelforge/tests/data/potential_defaults/ani2x.toml index 941f0a22..475eb9d7 100644 --- a/modelforge/tests/data/potential_defaults/ani2x.toml +++ b/modelforge/tests/data/potential_defaults/ani2x.toml @@ -16,7 +16,7 @@ predicted_properties = [ ] [potential.core_parameter.activation_function_parameter] -activation_function_name = "CeLU" # for the original ANI behavior please stick with CeLu since the alpha parameter is currently hard coded and might lead to different behavior when another activation function is used. +activation_function_name = "CeLU" # for the original ANI behavior please stick with CeLu since the alpha parameter is currently hard coded and might lead to different behavior when another activation function is used. [potential.core_parameter.activation_function_parameter.activation_function_arguments] alpha = 0.1 @@ -26,3 +26,9 @@ alpha = 0.1 normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +from_atom_to_molecule_reduction = false +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/painn.toml b/modelforge/tests/data/potential_defaults/painn.toml index 91dc7a75..bf78c30a 100644 --- a/modelforge/tests/data/potential_defaults/painn.toml +++ b/modelforge/tests/data/potential_defaults/painn.toml @@ -26,3 +26,9 @@ number_of_per_atom_features = 32 normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +from_atom_to_molecule_reduction = false +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/physnet.toml b/modelforge/tests/data/potential_defaults/physnet.toml index 931d9e74..cd3b7c7f 100644 --- a/modelforge/tests/data/potential_defaults/physnet.toml +++ b/modelforge/tests/data/potential_defaults/physnet.toml @@ -26,3 +26,8 @@ number_of_per_atom_features = 32 normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/sake.toml b/modelforge/tests/data/potential_defaults/sake.toml index 11eeafed..d1b350b7 100644 --- a/modelforge/tests/data/potential_defaults/sake.toml +++ b/modelforge/tests/data/potential_defaults/sake.toml @@ -26,3 +26,9 @@ maximum_atomic_number = 101 normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +from_atom_to_molecule_reduction = false +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index 7d84efc1..38f8ddf8 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -29,10 +29,9 @@ from_atom_to_molecule_reduction = true keep_per_atom_property = true [potential.postprocessing_parameter.per_atom_charge] -normalize = true -from_atom_to_molecule_reduction = false +conserve = true +strategy = "physnet" keep_per_atom_property = true - [potential.postprocessing_parameter.general_postprocessing_operation] calculate_molecular_self_energy = true diff --git a/modelforge/tests/data/potential_defaults/tensornet.toml b/modelforge/tests/data/potential_defaults/tensornet.toml index 09158b5b..efe4cf71 100644 --- a/modelforge/tests/data/potential_defaults/tensornet.toml +++ b/modelforge/tests/data/potential_defaults/tensornet.toml @@ -19,9 +19,16 @@ predicted_properties = [ activation_function_name = "SiLU" [potential.postprocessing_parameter] + [potential.postprocessing_parameter.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true + +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +strategy = "physnet" +keep_per_atom_property = true + [potential.postprocessing_parameter.general_postprocessing_operation] calculate_molecular_self_energy = true diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 5d764595..e56b5032 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -3,6 +3,106 @@ from modelforge.potential import _Implemented_NNPs from modelforge.dataset import _ImplementedDatasets from modelforge.potential import NeuralNetworkPotentialFactory +import torch +from modelforge.utils.io import import_ + + +def initialize_model(simulation_environment: str, config): + """Initialize the model based on the simulation environment and configuration.""" + return NeuralNetworkPotentialFactory.generate_potential( + use="inference", + simulation_environment=simulation_environment, + potential_parameter=config["potential"], + ) + + +def prepare_input_for_model(nnp_input, model): + """Prepare the input for the model based on the simulation environment.""" + if "JAX" in str(type(model)): + return nnp_input.as_jax_namedtuple() + return nnp_input + + +def validate_output_shapes(output, nr_of_mols): + """Validate the output shapes to ensure they are correct.""" + assert len(output["per_molecule_energy"]) == nr_of_mols + assert "per_atom_energy" in output + assert "per_atom_charge" in output + assert "per_atom_charge_corrected" in output + + +def validate_charge_conservation( + per_molecule_charge: torch.Tensor, + per_molecule_charge_corrected: torch.Tensor, + per_molecule_charge_from_dataset: torch.Tensor, + model_name: str, +): + """Ensure charge conservation by validating the corrected charges.""" + + if "PhysNet".lower() in model_name.lower(): + print( + "Physnet starts with all zero partial charges" + ) # NOTE: I am not sure if this is correct + else: + assert not torch.allclose(per_molecule_charge, per_molecule_charge_corrected) + assert torch.allclose( + per_molecule_charge_from_dataset.to(torch.float32), + per_molecule_charge_corrected, + atol=1e-5, + ) + + +def validate_energy_conservation(output: torch.Tensor): + """Ensure that the total energy is the sum of atomic energies.""" + assert torch.allclose( + output["per_molecule_energy"][0], + output["per_atom_energy"][0:5].sum(dim=0), + atol=1e-5, + ) + assert torch.allclose( + output["per_molecule_energy"][1], + output["per_atom_energy"][5:9].sum(dim=0), + atol=1e-5, + ) + + +def validate_chemical_equivalence(output): + """Ensure that chemically equivalent hydrogens have equal energies.""" + assert torch.allclose( + output["per_atom_energy"][1:4], output["per_atom_energy"][1], atol=1e-4 + ) + assert torch.allclose( + output["per_atom_energy"][6:8], output["per_atom_energy"][6], atol=1e-4 + ) + + +def retrieve_molecular_charges(output, atomic_subsystem_indices): + """Retrieve per-molecule charge from per-atom charges.""" + per_molecule_charge = torch.zeros_like(output["per_molecule_energy"]).index_add_( + 0, atomic_subsystem_indices, output["per_atom_charge"] + ) + per_molecule_charge_corrected = torch.zeros_like( + output["per_molecule_energy"] + ).index_add_(0, atomic_subsystem_indices, output["per_atom_charge_corrected"]) + return per_molecule_charge, per_molecule_charge_corrected + + +def convert_to_pytorch_if_needed(output, nnp_input, model): + """Convert output to PyTorch tensors if the model is in JAX.""" + if "JAX" in str(type(model)): + convert_to_pyt = import_("pytorch2jax").pytorch2jax.convert_to_pyt + output["per_molecule_energy"] = convert_to_pyt(output["per_molecule_energy"]) + output["per_atom_charge"] = convert_to_pyt(output["per_atom_charge"]) + output["per_atom_charge_corrected"] = convert_to_pyt( + output["per_atom_charge_corrected"] + ) + output["per_molecule_charge"] = convert_to_pyt( + output["per_molecule_charge"] + ).to(torch.float32) + atomic_subsystem_indices = convert_to_pyt(nnp_input.atomic_subsystem_indices) + else: + atomic_subsystem_indices = nnp_input.atomic_subsystem_indices + return output, atomic_subsystem_indices def load_configs_into_pydantic_models(potential_name: str, dataset_name: str): @@ -424,59 +524,37 @@ def test_forward_pass( potential_name, simulation_environment, single_batch_with_batchsize_64 ): # this test sends a single batch from different datasets through the model - import torch + # get input and set up model nnp_input = single_batch_with_batchsize_64.nnp_input - - # read default parameters config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] + model = initialize_model(simulation_environment, config) + nnp_input = prepare_input_for_model(nnp_input, model) - # test the forward pass through each of the models - model = NeuralNetworkPotentialFactory.generate_potential( - use="inference", - simulation_environment=simulation_environment, - potential_parameter=config["potential"], - ) - if "JAX" in str(type(model)): - nnp_input = nnp_input.as_jax_namedtuple() - + # perform the forward pass through each of the models output = model(nnp_input) - # test that we get an energie per molecule - assert len(output["per_molecule_energy"]) == nr_of_mols - - # check that per-atom charge/energies are present - assert "per_atom_energy" in output - assert "per_atom_charge" in output - - # the batch consists of methane (CH4) and amamonium (NH3) - # which have chemically equivalent hydrogens at the minimum geometry. - # This has to be reflected in the atomic energies E_i, which - # have to be equal for all hydrogens - if "JAX" not in str(type(model)): - from loguru import logger as log - - # assert that the following tensor has equal values for dim=0 index 1 to 4 and 6 to 8 - - assert torch.allclose( - output["per_atom_energy"][1:4], output["per_atom_energy"][1], atol=1e-4 - ) - assert torch.allclose( - output["per_atom_energy"][6:8], output["per_atom_energy"][6], atol=1e-4 - ) + # validate the output + validate_output_shapes(output, nr_of_mols) + output, atomic_subsystem_indices = convert_to_pytorch_if_needed( + output, nnp_input, model + ) - # make sure that the total energy is \sum E_i - assert torch.allclose( - output["per_molecule_energy"][0], - output["per_atom_energy"][0:5].sum(dim=0), - atol=1e-5, - ) - assert torch.allclose( - output["per_molecule_energy"][1], - output["per_atom_energy"][5:9].sum(dim=0), - atol=1e-5, - ) + # test that charge correction is working + per_molecule_charge, per_molecule_charge_corrected = retrieve_molecular_charges( + output, atomic_subsystem_indices + ) + validate_charge_conservation( + per_molecule_charge, + per_molecule_charge_corrected, + output["per_molecule_charge"], + potential_name, + ) + # check that per-atom energies are correct + if "JAX" not in simulation_environment: + validate_chemical_equivalence(output) + validate_energy_conservation(output) @pytest.mark.parametrize( @@ -980,8 +1058,8 @@ def test_casting(potential_name, single_batch_with_batchsize_64): ) @pytest.mark.parametrize("simulation_environment", ["PyTorch"]) def test_equivariant_energies_and_forces( - potential_name, - simulation_environment, + potential_name: str, + simulation_environment: str, single_batch_with_batchsize_64, equivariance_utils, ):