From afc950add88f477eeb826849c400841a931eda84 Mon Sep 17 00:00:00 2001 From: wiederm Date: Mon, 19 Aug 2024 08:56:52 +0200 Subject: [PATCH] update and bugfix --- modelforge/potential/models.py | 10 ++++---- modelforge/potential/parameters.py | 25 ++++++++----------- modelforge/potential/processing.py | 6 ++--- .../tests/data/potential_defaults/schnet.toml | 4 ++- modelforge/tests/test_models.py | 1 + 5 files changed, 22 insertions(+), 24 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index bb4cb680..a95dad4c 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -905,9 +905,9 @@ def _initialize_postprocessing( if coulomb_potential.get("from_atom_to_molecule_reduction, False"): postprocessing_sequence.append( FromAtomToMoleculeReduction( - per_atom_property_name="per_atom_energy", + per_atom_property_name="per_atom_electrostatic_energy", index_name="atomic_subsystem_indices", - output_name="per_molecule_energy", + output_name="per_molecule_electrostatic_energy", keep_per_atom_property=operations.get( "keep_per_atom_property", False ), @@ -1015,9 +1015,9 @@ def forward(self, data: Dict[str, torch.Tensor]): self.registered_chained_operations[property](data) # delte pairwise property object before returning - if 'pairwise_properties' in data: - del data['pairwise_properties'] - + if "pairwise_properties" in data: + del data["pairwise_properties"] + return data diff --git a/modelforge/potential/parameters.py b/modelforge/potential/parameters.py index 7164cecb..179cc973 100644 --- a/modelforge/potential/parameters.py +++ b/modelforge/potential/parameters.py @@ -4,26 +4,20 @@ from __future__ import annotations +from enum import Enum +from typing import List, Optional, Type, Union + +import torch +from openff.units import unit from pydantic import ( BaseModel, - field_validator, ConfigDict, - model_validator, computed_field, + field_validator, + model_validator, ) -from openff.units import unit -from typing import Union, List, Optional, Type -from modelforge.utils.units import _convert_str_to_unit -from enum import Enum -import torch - - -# needed to typecast to torch.nn.Module - -""" -This module contains pydantic models for storing the parameters of -""" +from modelforge.utils.units import _convert_str_to_unit class CaseInsensitiveEnum(str, Enum): @@ -139,6 +133,9 @@ class PerAtomEnergy(ParametersBase): class CoulomPotential(ParametersBase): electrostatic_strategy: str = "coulomb" maximum_interaction_radius: Union[str, unit.Quantity] + from_atom_to_molecule_reduction: bool = False + keep_per_atom_property: bool = False + converted_units = field_validator( "maximum_interaction_radius", diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index e0aef7f1..2d0ba188 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -504,10 +504,8 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: ).bool() coulomb_interactions.masked_fill_(mask, 0) - # Aggregate the interactions for each atom - coulomb_interactions_per_atom = torch.zeros_like(per_atom_charge).scatter_add_( - 0, idx_j.long(), coulomb_interactions - ) + # Sum over all interactions for each atom + coulomb_interactions_per_atom = coulomb_interactions.sum(dim=1) data["per_atom_electrostatic_energy"] = coulomb_interactions_per_atom diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index f750b020..2d4a4ef9 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -10,7 +10,6 @@ shared_interactions = false predicted_properties = [ { name = "per_atom_energy", type = "scalar" }, { name = "per_atom_charge", type = "scalar" }, - # we could also define per_atom_force to be consistent? ] [potential.core_parameter.activation_function_parameter] @@ -31,9 +30,12 @@ keep_per_atom_property = true conserve = true conserve_strategy = "default" keep_per_atom_property = true + [potential.postprocessing_parameter.per_atom_charge.coulomb_potential] electrostatic_strategy = "coulomb" maximum_interaction_radius = "10.0 angstrom" +from_atom_to_molecule_reduction = true +keep_per_atom_property = true [potential.postprocessing_parameter.general_postprocessing_operation] calculate_molecular_self_energy = true diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index e56b5032..4e52ee2a 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -29,6 +29,7 @@ def validate_output_shapes(output, nr_of_mols): assert "per_atom_energy" in output assert "per_atom_charge" in output assert "per_atom_charge_corrected" in output + assert "per_atom_electrostatic_energy" in output def validate_charge_conservation(