Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify neighbor list to handle multiple cutoffs. #238

Merged
merged 6 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions modelforge/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
This module contains the classes for the ANI2x neural network potential.
"""

from __future__ import annotations


from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Tuple, Type
from .models import BaseNetwork, CoreNetwork
Expand All @@ -12,10 +15,6 @@

from modelforge.utils.prop import SpeciesAEV

if TYPE_CHECKING:
from modelforge.dataset.dataset import NNPInput
from .models import PairListOutputs


def triu_index(num_species: int) -> torch.Tensor:
"""
Expand Down Expand Up @@ -106,7 +105,6 @@ def __init__(
angle_sections: int,
nr_of_supported_elements: int = 7,
):

super().__init__()
from modelforge.potential.utils import CosineAttenuationFunction

Expand Down Expand Up @@ -382,7 +380,6 @@ class ANIInteraction(nn.Module):
"""

def __init__(self, *, aev_dim: int, activation_function: Type[torch.nn.Module]):

super().__init__()
# define atomic neural network
atomic_neural_networks = self.intialize_atomic_neural_network(
Expand Down Expand Up @@ -561,7 +558,7 @@ def __init__(
self.register_buffer("lookup_tensor", lookup_tensor)

def _model_specific_input_preparation(
self, data: "NNPInput", pairlist_output: "PairListOutputs"
self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs]
) -> AniNeuralNetworkData:
"""
Prepare the model-specific input data for the ANI2x model.
Expand All @@ -570,8 +567,8 @@ def _model_specific_input_preparation(
----------
data : NNPInput
The input data for the model.
pairlist_output : PairListOutputs
The pairlist output.
pairlist_output : Dict[str,PairListOutputs]
The output from the pairlist.

Returns
-------
Expand All @@ -580,6 +577,11 @@ def _model_specific_input_preparation(
"""
number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_data = AniNeuralNetworkData(
pair_indices=pairlist_output.pair_indices,
d_ij=pairlist_output.d_ij,
Expand Down Expand Up @@ -668,7 +670,6 @@ def __init__(
dataset_statistic: Optional[Dict[str, float]] = None,
potential_seed: Optional[int] = None,
) -> None:

from modelforge.utils.units import _convert_str_to_unit

self.only_unique_pairs = True # NOTE: need to be set before super().__init__
Expand Down
79 changes: 58 additions & 21 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,24 +311,31 @@ def forward(

class Neighborlist(Pairlist):
"""
Manage neighbor list calculations with a specified cutoff distance.
Manage neighbor list calculations with a specified cutoff distance(s).

This class extends Pairlist to consider a cutoff distance for neighbor calculations.
"""

def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = False):
def __init__(
self, cutoffs: Dict[str, unit.Quantity], only_unique_pairs: bool = False
):
"""
Initialize the Neighborlist with a specific cutoff distance.

Parameters
----------
cutoff : unit.Quantity
Cutoff distance for neighbor calculations.
cutoffs : Dict[str, unit.Quantity]
Cutoff distances for neighbor calculations.
only_unique_pairs : bool, optional
If True, only unique pairs are returned (default is False).
"""
super().__init__(only_unique_pairs=only_unique_pairs)
self.register_buffer("cutoff", torch.tensor(cutoff.to(unit.nanometer).m))

# self.register_buffer("cutoff", torch.tensor(cutoff.to(unit.nanometer).m))
self.register_buffer(
"cutoffs", torch.tensor([c.to(unit.nanometer).m for c in cutoffs.values()])
)
self.labels = list(cutoffs.keys())

def forward(
self,
Expand Down Expand Up @@ -364,16 +371,20 @@ def forward(
r_ij = self.calculate_r_ij(pair_indices, positions)
d_ij = self.calculate_d_ij(r_ij)

# Find pairs within the cutoff
in_cutoff = (d_ij <= self.cutoff).squeeze()
# Get the atom indices within the cutoff
pair_indices_within_cutoff = pair_indices[:, in_cutoff]
interacting_outputs = {}
for cutoff, label in zip(self.cutoffs, self.labels):
# Find pairs within the cutoff
in_cutoff = (d_ij <= cutoff).squeeze()
# Get the atom indices within the cutoff
pair_indices_within_cutoff = pair_indices[:, in_cutoff]

interacting_outputs[label] = PairListOutputs(
pair_indices=pair_indices_within_cutoff,
d_ij=d_ij[in_cutoff],
r_ij=r_ij[in_cutoff],
)

return PairListOutputs(
pair_indices=pair_indices_within_cutoff,
d_ij=d_ij[in_cutoff],
r_ij=r_ij[in_cutoff],
)
return interacting_outputs


from typing import Callable, Literal, Optional, Union
Expand Down Expand Up @@ -666,14 +677,16 @@ class ComputeInteractingAtomPairs(torch.nn.Module):
distances (d_ij), and displacement vectors (r_ij) for molecular simulations.
"""

def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = True):
def __init__(
self, cutoffs: Dict[str, unit.Quantity], only_unique_pairs: bool = True
):
"""
Initialize the ComputeInteractingAtomPairs module.

Parameters
----------
cutoff : unit.Quantity
The cutoff distance for neighbor list calculations.
cutoffs : Dict[str, unit.Quantity]
The cutoff distance(s) for neighbor list calculations.
only_unique_pairs : bool, optional
Whether to only use unique pairs in the pair list calculation, by
default True. This should be set to True for all message passing
Expand All @@ -684,7 +697,7 @@ def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = True):
from .models import Neighborlist

self.only_unique_pairs = only_unique_pairs
self.calculate_distances_and_pairlist = Neighborlist(cutoff, only_unique_pairs)
self.calculate_distances_and_pairlist = Neighborlist(cutoffs, only_unique_pairs)

def prepare_inputs(self, data: Union[NNPInput, NamedTuple]):
"""
Expand All @@ -703,7 +716,7 @@ def prepare_inputs(self, data: Union[NNPInput, NamedTuple]):
Returns
-------
PairListOutputs
A namedtuple containing the pair indices, Euclidean distances
A Dict for each cutoff type, where each entry is a namedtuple containing the pair indices, Euclidean distances
(d_ij), and displacement vectors (r_ij).
"""
# ---------------------------
Expand Down Expand Up @@ -736,6 +749,7 @@ def prepare_inputs(self, data: Union[NNPInput, NamedTuple]):
pair_indices=pair_list.to(torch.int64),
)

# this will return a Dict of the PairListOutputs for each cutoff we specify
return pairlist_output

def _input_checks(self, data: Union[NNPInput, NamedTuple]):
Expand Down Expand Up @@ -994,6 +1008,8 @@ def __init__(
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic: Optional[Dict[str, float]],
maximum_interaction_radius: unit.Quantity,
maximum_dispersion_interaction_radius: Optional[unit.Quantity] = None,
maximum_coulomb_interaction_radius: Optional[unit.Quantity] = None,
potential_seed: Optional[int] = None,
):
"""
Expand All @@ -1006,7 +1022,11 @@ def __init__(
dataset_statistic : Optional[Dict[str, float]]
Dataset statistics for normalization.
maximum_interaction_radius : unit.Quantity
cutoff radius.
cutoff radius for local interactions
maximum_dispersion_interaction_radius : unit.Quantity, optional
cutoff radius for dispersion interactions.
maximum_coulomb_interaction_radius : unit.Quantity, optional
cutoff radius for Coulomb interactions.
potential_seed : Optional[int], optional
Value used for torch.manual_seed, by default None.
"""
Expand Down Expand Up @@ -1040,8 +1060,25 @@ def __init__(
raise RuntimeError(
"The only_unique_pairs attribute is not set in the child class. Please set it to True or False before calling super().__init__."
)

# to handle multiple cutoffs, we will create a dictionary with the cutoffs
# the dictionary will make it more transparent which PairListOutputs belong to which cutoff

cutoffs = {}
cutoffs["maximum_interaction_radius"] = _convert_str_to_unit(
maximum_interaction_radius
)
if maximum_dispersion_interaction_radius is not None:
cutoffs["maximum_dispersion_interaction_radius"] = _convert_str_to_unit(
maximum_dispersion_interaction_radius
)
if maximum_coulomb_interaction_radius is not None:
cutoffs["maximum_coulomb_interaction_radius"] = _convert_str_to_unit(
maximum_coulomb_interaction_radius
)

self.compute_interacting_pairs = ComputeInteractingAtomPairs(
cutoff=_convert_str_to_unit(maximum_interaction_radius),
cutoffs=cutoffs,
only_unique_pairs=self.only_unique_pairs,
)

Expand Down
11 changes: 8 additions & 3 deletions modelforge/potential/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(
)

def _model_specific_input_preparation(
self, data: NNPInput, pairlist_output: PairListOutputs
self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs]
) -> PaiNNNeuralNetworkData:
"""
Prepare the model-specific input for the PaiNN network.
Expand All @@ -147,8 +147,8 @@ def _model_specific_input_preparation(
----------
data : NNPInput
The input data.
pairlist_output : PairListOutputs
The pairlist output.
pairlist_output : dict[str, PairListOutputs]
The output from the pairlist.

Returns
-------
Expand All @@ -159,6 +159,11 @@ def _model_specific_input_preparation(

number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_input = PaiNNNeuralNetworkData(
pair_indices=pairlist_output.pair_indices,
d_ij=pairlist_output.d_ij,
Expand Down
9 changes: 7 additions & 2 deletions modelforge/potential/physnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def __init__(
self.atomic_shift = nn.Parameter(torch.zeros(maximum_atomic_number, 2))

def _model_specific_input_preparation(
self, data: "NNPInput", pairlist_output: "PairListOutputs"
self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs]
) -> PhysNetNeuralNetworkData:
"""
Prepare model-specific input data.
Expand All @@ -495,7 +495,7 @@ def _model_specific_input_preparation(
----------
data : NNPInput
Input data containing atomic information.
pairlist_output : PairListOutputs
pairlist_output : Dict[str, PairListOutputs]
Output from the pairlist calculation.

Returns
Expand All @@ -505,6 +505,11 @@ def _model_specific_input_preparation(
"""
number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_input = PhysNetNeuralNetworkData(
pair_indices=pairlist_output.pair_indices,
d_ij=pairlist_output.d_ij,
Expand Down
11 changes: 8 additions & 3 deletions modelforge/potential/sake.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(
)

def _model_specific_input_preparation(
self, data: "NNPInput", pairlist_output: "PairListOutputs"
self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs]
) -> SAKENeuralNetworkInput:
"""
Prepare the model-specific input.
Expand All @@ -147,8 +147,8 @@ def _model_specific_input_preparation(
----------
data : NNPInput
Input data.
pairlist_output : PairListOutputs
Pairlist output.
pairlist_output : Dict[str,PairListOutputs]
Pairlist output(s)

Returns
-------
Expand All @@ -159,6 +159,11 @@ def _model_specific_input_preparation(

number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_input = SAKENeuralNetworkInput(
pair_indices=pairlist_output.pair_indices,
number_of_atoms=number_of_atoms,
Expand Down
13 changes: 9 additions & 4 deletions modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
number_of_radial_basis_functions,
featurization_config=featurization_config,
)
# Intialize interaction blocks
# Initialize interaction blocks
if shared_interactions:
self.interaction_modules = nn.ModuleList(
[
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(
)

def _model_specific_input_preparation(
self, data: "NNPInput", pairlist_output: PairListOutputs
self, data: "NNPInput", pairlist_output: Dict[str, PairListOutputs]
) -> SchnetNeuralNetworkData:
"""
Prepare the input data for the SchNet model.
Expand All @@ -143,8 +143,8 @@ def _model_specific_input_preparation(
----------
data : NNPInput
The input data for the model.
pairlist_output : PairListOutputs
The pairlist output.
pairlist_output : Dict[str, PairListOutputs]
The pairlist output(s).

Returns
-------
Expand All @@ -153,6 +153,11 @@ def _model_specific_input_preparation(
"""
number_of_atoms = data.atomic_numbers.shape[0]

# Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter
# e.g. "maximum_interaction_radius"

pairlist_output = pairlist_output["maximum_interaction_radius"]

nnp_input = SchnetNeuralNetworkData(
pair_indices=pairlist_output.pair_indices,
d_ij=pairlist_output.d_ij,
Expand Down
Loading
Loading