Skip to content

Commit

Permalink
add test for charge conservation
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Sep 16, 2024
1 parent 3f2e67c commit e4892e6
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 53 deletions.
4 changes: 2 additions & 2 deletions modelforge/potential/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ class PerAtomEnergy(ParametersBase):

class CoulomPotential(ParametersBase):
electrostatic_strategy: str = "coulomb"
maximum_interaction_radius: Union[str, unit.Quantity]
maximum_coulomb_interaction_radius: Union[str, unit.Quantity]
from_atom_to_molecule_reduction: bool = False
keep_per_atom_property: bool = False

converted_units = field_validator(
"maximum_interaction_radius",
"maximum_coulomb_interaction_radius",
)(_convert_str_to_unit)


Expand Down
102 changes: 51 additions & 51 deletions modelforge/potential/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,64 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
data[self.output_name] = data[self.property] * self.stddev + self.mean
return data

def default_charge_conservation(
per_atom_charge: torch.Tensor,
total_charges: torch.Tensor,
mol_indices: torch.Tensor,
) -> torch.Tensor:
"""
PhysNet charge conservation method based on equation 14 from the PhysNet
paper.
Correct the partial charges such that their sum matches the desired
total charge for each molecule.
Parameters
----------
partial_charges : torch.Tensor
Flat tensor of partial charges for all atoms in all molecules.
total_charges : torch.Tensor
Tensor of desired total charges for each molecule.
mol_indices : torch.Tensor
Tensor of integers indicating which molecule each atom belongs to.
Returns
-------
torch.Tensor
Tensor of corrected partial charges.
"""
# the general approach here is outline in equation 14 in the PhysNet
# paper: the difference between the sum of the predicted partial charges
# and the total charge is calculated and then distributed evenly among
# the predicted partial charges

# Calculate the sum of partial charges for each molecule

# for each atom i, calculate the sum of partial charges for all other
predicted_per_molecule_charge = torch.zeros(
total_charges.shape,
dtype=per_atom_charge.dtype,
device=total_charges.device,
).scatter_add_(0, mol_indices.long(), per_atom_charge)

# Calculate the correction factor for each molecule
correction_factors = (
total_charges - predicted_per_molecule_charge
) / mol_indices.bincount()

# Apply the correction to each atom's charge
per_atom_charge_corrected = per_atom_charge + correction_factors[mol_indices]

return per_atom_charge_corrected


class ChargeConservation(torch.nn.Module):
def __init__(self, method="default"):

super().__init__()
self.method = method
if self.method == "default":
self.correct_partial_charges = self.default_charge_conservation
self.correct_partial_charges = default_charge_conservation
else:
raise ValueError(f"Unknown charge conservation method: {self.method}")

Expand Down Expand Up @@ -309,56 +359,6 @@ def forward(
)
return data

def default_charge_conservation(
self,
per_atom_charge: torch.Tensor,
mol_indices: torch.Tensor,
total_charges: torch.Tensor,
) -> torch.Tensor:
"""
PhysNet charge conservation method based on equation 14 from the PhysNet
paper.
Correct the partial charges such that their sum matches the desired
total charge for each molecule.
Parameters
----------
partial_charges : torch.Tensor
Flat tensor of partial charges for all atoms in all molecules.
mol_indices : torch.Tensor
Tensor of integers indicating which molecule each atom belongs to.
total_charges : torch.Tensor
Tensor of desired total charges for each molecule.
Returns
-------
torch.Tensor
Tensor of corrected partial charges.
"""
# the general approach here is outline in equation 14 in the PhysNet
# paper: the difference between the sum of the predicted partial charges
# and the total charge is calculated and then distributed evenly among
# the predicted partial charges

# Calculate the sum of partial charges for each molecule

# for each atom i, calculate the sum of partial charges for all other
predicted_per_molecule_charge = torch.zeros(
total_charges.shape,
dtype=per_atom_charge.dtype,
device=total_charges.device,
).scatter_add_(0, mol_indices.long(), per_atom_charge)

# Calculate the correction factor for each molecule
correction_factors = (
total_charges - predicted_per_molecule_charge
) / mol_indices.bincount()

# Apply the correction to each atom's charge
per_atom_charge_corrected = per_atom_charge + correction_factors[mol_indices]

return per_atom_charge_corrected


class CalculateAtomicSelfEnergy(torch.nn.Module):
Expand Down
59 changes: 59 additions & 0 deletions modelforge/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,65 @@ def prep_temp_dir(tmp_path_factory):
return fn




def test_charge_equilibration():
from modelforge.potential.processing import default_charge_conservation

# test charge equilibration
# ------------------------- #
# test case 1
partial_point_charges = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
atomic_subsystem_indices = torch.tensor([0, 0, 1, 1, 1, 1], dtype=torch.int64)
total_charge = torch.tensor([0.0, 1.0])
charges = default_charge_conservation(
partial_point_charges,
total_charge,
atomic_subsystem_indices,
)

assert torch.allclose(
torch.zeros_like(total_charge).scatter_add_(
0, atomic_subsystem_indices, charges
),
total_charge,
)

# ------------------------- #
# test case 2
partial_point_charges = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
total_charge = torch.tensor([-1.0, 2.0])
charges = default_charge_conservation(
partial_point_charges,
total_charge,
atomic_subsystem_indices,
)
assert torch.allclose(
torch.zeros_like(total_charge).scatter_add_(
0, atomic_subsystem_indices, charges
),
total_charge,
)

# ------------------------- #
# test case 3
partial_point_charges = torch.rand_like(
atomic_subsystem_indices, dtype=torch.float32
)
total_charge = torch.tensor([-1.0, 2.0])
charges = default_charge_conservation(
partial_point_charges,
total_charge,
atomic_subsystem_indices,
)
assert torch.allclose(
torch.zeros_like(total_charge).scatter_add_(
0, atomic_subsystem_indices, charges
),
total_charge,
)


def test_dense_layer():
from modelforge.potential.utils import DenseWithCustomDist
import torch
Expand Down

0 comments on commit e4892e6

Please sign in to comment.