Skip to content

Commit

Permalink
update and bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Aug 19, 2024
1 parent a454347 commit afc950a
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 24 deletions.
10 changes: 5 additions & 5 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,9 +905,9 @@ def _initialize_postprocessing(
if coulomb_potential.get("from_atom_to_molecule_reduction, False"):
postprocessing_sequence.append(
FromAtomToMoleculeReduction(
per_atom_property_name="per_atom_energy",
per_atom_property_name="per_atom_electrostatic_energy",
index_name="atomic_subsystem_indices",
output_name="per_molecule_energy",
output_name="per_molecule_electrostatic_energy",
keep_per_atom_property=operations.get(
"keep_per_atom_property", False
),
Expand Down Expand Up @@ -1015,9 +1015,9 @@ def forward(self, data: Dict[str, torch.Tensor]):
self.registered_chained_operations[property](data)

# delte pairwise property object before returning
if 'pairwise_properties' in data:
del data['pairwise_properties']
if "pairwise_properties" in data:
del data["pairwise_properties"]

return data


Expand Down
25 changes: 11 additions & 14 deletions modelforge/potential/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,20 @@

from __future__ import annotations

from enum import Enum
from typing import List, Optional, Type, Union

import torch
from openff.units import unit
from pydantic import (
BaseModel,
field_validator,
ConfigDict,
model_validator,
computed_field,
field_validator,
model_validator,
)
from openff.units import unit
from typing import Union, List, Optional, Type
from modelforge.utils.units import _convert_str_to_unit
from enum import Enum

import torch


# needed to typecast to torch.nn.Module

"""
This module contains pydantic models for storing the parameters of
"""
from modelforge.utils.units import _convert_str_to_unit


class CaseInsensitiveEnum(str, Enum):
Expand Down Expand Up @@ -139,6 +133,9 @@ class PerAtomEnergy(ParametersBase):
class CoulomPotential(ParametersBase):
electrostatic_strategy: str = "coulomb"
maximum_interaction_radius: Union[str, unit.Quantity]
from_atom_to_molecule_reduction: bool = False
keep_per_atom_property: bool = False


converted_units = field_validator(
"maximum_interaction_radius",
Expand Down
6 changes: 2 additions & 4 deletions modelforge/potential/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,10 +504,8 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
).bool()
coulomb_interactions.masked_fill_(mask, 0)

# Aggregate the interactions for each atom
coulomb_interactions_per_atom = torch.zeros_like(per_atom_charge).scatter_add_(
0, idx_j.long(), coulomb_interactions
)
# Sum over all interactions for each atom
coulomb_interactions_per_atom = coulomb_interactions.sum(dim=1)

data["per_atom_electrostatic_energy"] = coulomb_interactions_per_atom

Expand Down
4 changes: 3 additions & 1 deletion modelforge/tests/data/potential_defaults/schnet.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ shared_interactions = false
predicted_properties = [
{ name = "per_atom_energy", type = "scalar" },
{ name = "per_atom_charge", type = "scalar" },
# we could also define per_atom_force to be consistent?
]

[potential.core_parameter.activation_function_parameter]
Expand All @@ -31,9 +30,12 @@ keep_per_atom_property = true
conserve = true
conserve_strategy = "default"
keep_per_atom_property = true

[potential.postprocessing_parameter.per_atom_charge.coulomb_potential]
electrostatic_strategy = "coulomb"
maximum_interaction_radius = "10.0 angstrom"
from_atom_to_molecule_reduction = true
keep_per_atom_property = true

[potential.postprocessing_parameter.general_postprocessing_operation]
calculate_molecular_self_energy = true
1 change: 1 addition & 0 deletions modelforge/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def validate_output_shapes(output, nr_of_mols):
assert "per_atom_energy" in output
assert "per_atom_charge" in output
assert "per_atom_charge_corrected" in output
assert "per_atom_electrostatic_energy" in output


def validate_charge_conservation(
Expand Down

0 comments on commit afc950a

Please sign in to comment.