Skip to content

Commit

Permalink
Merge pull request #238 from chrisiacovella/multi_nlists
Browse files Browse the repository at this point in the history
Modify neighbor list to handle multiple cutoffs.
  • Loading branch information
chrisiacovella authored Aug 27, 2024
2 parents 7f31bab + d93c82f commit 532cda7
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 61 deletions.
21 changes: 11 additions & 10 deletions modelforge/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
This module contains the classes for the ANI2x neural network potential.
"""

from __future__ import annotations


from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Tuple, Type
from .models import BaseNetwork, CoreNetwork
Expand All @@ -12,10 +15,6 @@

from modelforge.utils.prop import SpeciesAEV

if TYPE_CHECKING:
from modelforge.dataset.dataset import NNPInput
from .models import PairListOutputs


def triu_index(num_species: int) -> torch.Tensor:
"""
Expand Down Expand Up @@ -106,7 +105,6 @@ def __init__(
angle_sections: int,
nr_of_supported_elements: int = 7,
):

super().__init__()
from modelforge.potential.utils import CosineAttenuationFunction

Expand Down Expand Up @@ -382,7 +380,6 @@ class ANIInteraction(nn.Module):
"""

def __init__(self, *, aev_dim: int, activation_function: Type[torch.nn.Module]):

super().__init__()
# define atomic neural network
atomic_neural_networks = self.intialize_atomic_neural_network(
Expand Down Expand Up @@ -561,7 +558,7 @@ def __init__(
self.register_buffer("lookup_tensor", lookup_tensor)

def _model_specific_input_preparation(
self, data: "NNPInput", pairlist_output: "PairListOutputs"
self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs]
) -> AniNeuralNetworkData:
"""
Prepare the model-specific input data for the ANI2x model.
Expand All @@ -570,8 +567,8 @@ def _model_specific_input_preparation(
----------
data : NNPInput
The input data for the model.
pairlist_output : PairListOutputs
The pairlist output.
pairlist_output : Dict[str,PairListOutputs]
The output from the pairlist.
Returns
-------
Expand All @@ -580,6 +577,11 @@ def _model_specific_input_preparation(
"""
number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_data = AniNeuralNetworkData(
pair_indices=pairlist_output.pair_indices,
d_ij=pairlist_output.d_ij,
Expand Down Expand Up @@ -668,7 +670,6 @@ def __init__(
dataset_statistic: Optional[Dict[str, float]] = None,
potential_seed: Optional[int] = None,
) -> None:

from modelforge.utils.units import _convert_str_to_unit

self.only_unique_pairs = True # NOTE: need to be set before super().__init__
Expand Down
79 changes: 58 additions & 21 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,24 +311,31 @@ def forward(

class Neighborlist(Pairlist):
"""
Manage neighbor list calculations with a specified cutoff distance.
Manage neighbor list calculations with a specified cutoff distance(s).
This class extends Pairlist to consider a cutoff distance for neighbor calculations.
"""

def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = False):
def __init__(
self, cutoffs: Dict[str, unit.Quantity], only_unique_pairs: bool = False
):
"""
Initialize the Neighborlist with a specific cutoff distance.
Parameters
----------
cutoff : unit.Quantity
Cutoff distance for neighbor calculations.
cutoffs : Dict[str, unit.Quantity]
Cutoff distances for neighbor calculations.
only_unique_pairs : bool, optional
If True, only unique pairs are returned (default is False).
"""
super().__init__(only_unique_pairs=only_unique_pairs)
self.register_buffer("cutoff", torch.tensor(cutoff.to(unit.nanometer).m))

# self.register_buffer("cutoff", torch.tensor(cutoff.to(unit.nanometer).m))
self.register_buffer(
"cutoffs", torch.tensor([c.to(unit.nanometer).m for c in cutoffs.values()])
)
self.labels = list(cutoffs.keys())

def forward(
self,
Expand Down Expand Up @@ -364,16 +371,20 @@ def forward(
r_ij = self.calculate_r_ij(pair_indices, positions)
d_ij = self.calculate_d_ij(r_ij)

# Find pairs within the cutoff
in_cutoff = (d_ij <= self.cutoff).squeeze()
# Get the atom indices within the cutoff
pair_indices_within_cutoff = pair_indices[:, in_cutoff]
interacting_outputs = {}
for cutoff, label in zip(self.cutoffs, self.labels):
# Find pairs within the cutoff
in_cutoff = (d_ij <= cutoff).squeeze()
# Get the atom indices within the cutoff
pair_indices_within_cutoff = pair_indices[:, in_cutoff]

interacting_outputs[label] = PairListOutputs(
pair_indices=pair_indices_within_cutoff,
d_ij=d_ij[in_cutoff],
r_ij=r_ij[in_cutoff],
)

return PairListOutputs(
pair_indices=pair_indices_within_cutoff,
d_ij=d_ij[in_cutoff],
r_ij=r_ij[in_cutoff],
)
return interacting_outputs


from typing import Callable, Literal, Optional, Union
Expand Down Expand Up @@ -666,14 +677,16 @@ class ComputeInteractingAtomPairs(torch.nn.Module):
distances (d_ij), and displacement vectors (r_ij) for molecular simulations.
"""

def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = True):
def __init__(
self, cutoffs: Dict[str, unit.Quantity], only_unique_pairs: bool = True
):
"""
Initialize the ComputeInteractingAtomPairs module.
Parameters
----------
cutoff : unit.Quantity
The cutoff distance for neighbor list calculations.
cutoffs : Dict[str, unit.Quantity]
The cutoff distance(s) for neighbor list calculations.
only_unique_pairs : bool, optional
Whether to only use unique pairs in the pair list calculation, by
default True. This should be set to True for all message passing
Expand All @@ -684,7 +697,7 @@ def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = True):
from .models import Neighborlist

self.only_unique_pairs = only_unique_pairs
self.calculate_distances_and_pairlist = Neighborlist(cutoff, only_unique_pairs)
self.calculate_distances_and_pairlist = Neighborlist(cutoffs, only_unique_pairs)

def prepare_inputs(self, data: Union[NNPInput, NamedTuple]):
"""
Expand All @@ -703,7 +716,7 @@ def prepare_inputs(self, data: Union[NNPInput, NamedTuple]):
Returns
-------
PairListOutputs
A namedtuple containing the pair indices, Euclidean distances
A Dict for each cutoff type, where each entry is a namedtuple containing the pair indices, Euclidean distances
(d_ij), and displacement vectors (r_ij).
"""
# ---------------------------
Expand Down Expand Up @@ -736,6 +749,7 @@ def prepare_inputs(self, data: Union[NNPInput, NamedTuple]):
pair_indices=pair_list.to(torch.int64),
)

# this will return a Dict of the PairListOutputs for each cutoff we specify
return pairlist_output

def _input_checks(self, data: Union[NNPInput, NamedTuple]):
Expand Down Expand Up @@ -994,6 +1008,8 @@ def __init__(
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic: Optional[Dict[str, float]],
maximum_interaction_radius: unit.Quantity,
maximum_dispersion_interaction_radius: Optional[unit.Quantity] = None,
maximum_coulomb_interaction_radius: Optional[unit.Quantity] = None,
potential_seed: Optional[int] = None,
):
"""
Expand All @@ -1006,7 +1022,11 @@ def __init__(
dataset_statistic : Optional[Dict[str, float]]
Dataset statistics for normalization.
maximum_interaction_radius : unit.Quantity
cutoff radius.
cutoff radius for local interactions
maximum_dispersion_interaction_radius : unit.Quantity, optional
cutoff radius for dispersion interactions.
maximum_coulomb_interaction_radius : unit.Quantity, optional
cutoff radius for Coulomb interactions.
potential_seed : Optional[int], optional
Value used for torch.manual_seed, by default None.
"""
Expand Down Expand Up @@ -1040,8 +1060,25 @@ def __init__(
raise RuntimeError(
"The only_unique_pairs attribute is not set in the child class. Please set it to True or False before calling super().__init__."
)

# to handle multiple cutoffs, we will create a dictionary with the cutoffs
# the dictionary will make it more transparent which PairListOutputs belong to which cutoff

cutoffs = {}
cutoffs["maximum_interaction_radius"] = _convert_str_to_unit(
maximum_interaction_radius
)
if maximum_dispersion_interaction_radius is not None:
cutoffs["maximum_dispersion_interaction_radius"] = _convert_str_to_unit(
maximum_dispersion_interaction_radius
)
if maximum_coulomb_interaction_radius is not None:
cutoffs["maximum_coulomb_interaction_radius"] = _convert_str_to_unit(
maximum_coulomb_interaction_radius
)

self.compute_interacting_pairs = ComputeInteractingAtomPairs(
cutoff=_convert_str_to_unit(maximum_interaction_radius),
cutoffs=cutoffs,
only_unique_pairs=self.only_unique_pairs,
)

Expand Down
11 changes: 8 additions & 3 deletions modelforge/potential/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(
)

def _model_specific_input_preparation(
self, data: NNPInput, pairlist_output: PairListOutputs
self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs]
) -> PaiNNNeuralNetworkData:
"""
Prepare the model-specific input for the PaiNN network.
Expand All @@ -147,8 +147,8 @@ def _model_specific_input_preparation(
----------
data : NNPInput
The input data.
pairlist_output : PairListOutputs
The pairlist output.
pairlist_output : dict[str, PairListOutputs]
The output from the pairlist.
Returns
-------
Expand All @@ -159,6 +159,11 @@ def _model_specific_input_preparation(

number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_input = PaiNNNeuralNetworkData(
pair_indices=pairlist_output.pair_indices,
d_ij=pairlist_output.d_ij,
Expand Down
9 changes: 7 additions & 2 deletions modelforge/potential/physnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def __init__(
self.atomic_shift = nn.Parameter(torch.zeros(maximum_atomic_number, 2))

def _model_specific_input_preparation(
self, data: "NNPInput", pairlist_output: "PairListOutputs"
self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs]
) -> PhysNetNeuralNetworkData:
"""
Prepare model-specific input data.
Expand All @@ -495,7 +495,7 @@ def _model_specific_input_preparation(
----------
data : NNPInput
Input data containing atomic information.
pairlist_output : PairListOutputs
pairlist_output : Dict[str, PairListOutputs]
Output from the pairlist calculation.
Returns
Expand All @@ -505,6 +505,11 @@ def _model_specific_input_preparation(
"""
number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_input = PhysNetNeuralNetworkData(
pair_indices=pairlist_output.pair_indices,
d_ij=pairlist_output.d_ij,
Expand Down
11 changes: 8 additions & 3 deletions modelforge/potential/sake.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(
)

def _model_specific_input_preparation(
self, data: "NNPInput", pairlist_output: "PairListOutputs"
self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs]
) -> SAKENeuralNetworkInput:
"""
Prepare the model-specific input.
Expand All @@ -147,8 +147,8 @@ def _model_specific_input_preparation(
----------
data : NNPInput
Input data.
pairlist_output : PairListOutputs
Pairlist output.
pairlist_output : Dict[str,PairListOutputs]
Pairlist output(s)
Returns
-------
Expand All @@ -159,6 +159,11 @@ def _model_specific_input_preparation(

number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_input = SAKENeuralNetworkInput(
pair_indices=pairlist_output.pair_indices,
number_of_atoms=number_of_atoms,
Expand Down
13 changes: 9 additions & 4 deletions modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
number_of_radial_basis_functions,
featurization_config=featurization_config,
)
# Intialize interaction blocks
# Initialize interaction blocks
if shared_interactions:
self.interaction_modules = nn.ModuleList(
[
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(
)

def _model_specific_input_preparation(
self, data: "NNPInput", pairlist_output: PairListOutputs
self, data: "NNPInput", pairlist_output: Dict[str, PairListOutputs]
) -> SchnetNeuralNetworkData:
"""
Prepare the input data for the SchNet model.
Expand All @@ -143,8 +143,8 @@ def _model_specific_input_preparation(
----------
data : NNPInput
The input data for the model.
pairlist_output : PairListOutputs
The pairlist output.
pairlist_output : Dict[str, PairListOutputs]
The pairlist output(s).
Returns
-------
Expand All @@ -153,6 +153,11 @@ def _model_specific_input_preparation(
"""
number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_input = SchnetNeuralNetworkData(
pair_indices=pairlist_output.pair_indices,
d_ij=pairlist_output.d_ij,
Expand Down
Loading

0 comments on commit 532cda7

Please sign in to comment.