diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index 2e84ff13..67019c77 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -2,6 +2,9 @@ This module contains the classes for the ANI2x neural network potential. """ +from __future__ import annotations + + from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Tuple, Type from .models import BaseNetwork, CoreNetwork @@ -12,10 +15,6 @@ from modelforge.utils.prop import SpeciesAEV -if TYPE_CHECKING: - from modelforge.dataset.dataset import NNPInput - from .models import PairListOutputs - def triu_index(num_species: int) -> torch.Tensor: """ @@ -106,7 +105,6 @@ def __init__( angle_sections: int, nr_of_supported_elements: int = 7, ): - super().__init__() from modelforge.potential.utils import CosineAttenuationFunction @@ -382,7 +380,6 @@ class ANIInteraction(nn.Module): """ def __init__(self, *, aev_dim: int, activation_function: Type[torch.nn.Module]): - super().__init__() # define atomic neural network atomic_neural_networks = self.intialize_atomic_neural_network( @@ -561,7 +558,7 @@ def __init__( self.register_buffer("lookup_tensor", lookup_tensor) def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" + self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs] ) -> AniNeuralNetworkData: """ Prepare the model-specific input data for the ANI2x model. @@ -570,8 +567,8 @@ def _model_specific_input_preparation( ---------- data : NNPInput The input data for the model. - pairlist_output : PairListOutputs - The pairlist output. + pairlist_output : Dict[str,PairListOutputs] + The output from the pairlist. Returns ------- @@ -580,6 +577,11 @@ def _model_specific_input_preparation( """ number_of_atoms = data.atomic_numbers.shape[0] + # Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter + # e.g. "maximum_interaction_radius" + + pairlist_output = pairlist_output["maximum_interaction_radius"] + nnp_data = AniNeuralNetworkData( pair_indices=pairlist_output.pair_indices, d_ij=pairlist_output.d_ij, @@ -668,7 +670,6 @@ def __init__( dataset_statistic: Optional[Dict[str, float]] = None, potential_seed: Optional[int] = None, ) -> None: - from modelforge.utils.units import _convert_str_to_unit self.only_unique_pairs = True # NOTE: need to be set before super().__init__ diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 5285fa4c..dd6a5ae1 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -311,24 +311,31 @@ def forward( class Neighborlist(Pairlist): """ - Manage neighbor list calculations with a specified cutoff distance. + Manage neighbor list calculations with a specified cutoff distance(s). This class extends Pairlist to consider a cutoff distance for neighbor calculations. """ - def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = False): + def __init__( + self, cutoffs: Dict[str, unit.Quantity], only_unique_pairs: bool = False + ): """ Initialize the Neighborlist with a specific cutoff distance. Parameters ---------- - cutoff : unit.Quantity - Cutoff distance for neighbor calculations. + cutoffs : Dict[str, unit.Quantity] + Cutoff distances for neighbor calculations. only_unique_pairs : bool, optional If True, only unique pairs are returned (default is False). """ super().__init__(only_unique_pairs=only_unique_pairs) - self.register_buffer("cutoff", torch.tensor(cutoff.to(unit.nanometer).m)) + + # self.register_buffer("cutoff", torch.tensor(cutoff.to(unit.nanometer).m)) + self.register_buffer( + "cutoffs", torch.tensor([c.to(unit.nanometer).m for c in cutoffs.values()]) + ) + self.labels = list(cutoffs.keys()) def forward( self, @@ -364,16 +371,20 @@ def forward( r_ij = self.calculate_r_ij(pair_indices, positions) d_ij = self.calculate_d_ij(r_ij) - # Find pairs within the cutoff - in_cutoff = (d_ij <= self.cutoff).squeeze() - # Get the atom indices within the cutoff - pair_indices_within_cutoff = pair_indices[:, in_cutoff] + interacting_outputs = {} + for cutoff, label in zip(self.cutoffs, self.labels): + # Find pairs within the cutoff + in_cutoff = (d_ij <= cutoff).squeeze() + # Get the atom indices within the cutoff + pair_indices_within_cutoff = pair_indices[:, in_cutoff] + + interacting_outputs[label] = PairListOutputs( + pair_indices=pair_indices_within_cutoff, + d_ij=d_ij[in_cutoff], + r_ij=r_ij[in_cutoff], + ) - return PairListOutputs( - pair_indices=pair_indices_within_cutoff, - d_ij=d_ij[in_cutoff], - r_ij=r_ij[in_cutoff], - ) + return interacting_outputs from typing import Callable, Literal, Optional, Union @@ -666,14 +677,16 @@ class ComputeInteractingAtomPairs(torch.nn.Module): distances (d_ij), and displacement vectors (r_ij) for molecular simulations. """ - def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = True): + def __init__( + self, cutoffs: Dict[str, unit.Quantity], only_unique_pairs: bool = True + ): """ Initialize the ComputeInteractingAtomPairs module. Parameters ---------- - cutoff : unit.Quantity - The cutoff distance for neighbor list calculations. + cutoffs : Dict[str, unit.Quantity] + The cutoff distance(s) for neighbor list calculations. only_unique_pairs : bool, optional Whether to only use unique pairs in the pair list calculation, by default True. This should be set to True for all message passing @@ -684,7 +697,7 @@ def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = True): from .models import Neighborlist self.only_unique_pairs = only_unique_pairs - self.calculate_distances_and_pairlist = Neighborlist(cutoff, only_unique_pairs) + self.calculate_distances_and_pairlist = Neighborlist(cutoffs, only_unique_pairs) def prepare_inputs(self, data: Union[NNPInput, NamedTuple]): """ @@ -703,7 +716,7 @@ def prepare_inputs(self, data: Union[NNPInput, NamedTuple]): Returns ------- PairListOutputs - A namedtuple containing the pair indices, Euclidean distances + A Dict for each cutoff type, where each entry is a namedtuple containing the pair indices, Euclidean distances (d_ij), and displacement vectors (r_ij). """ # --------------------------- @@ -736,6 +749,7 @@ def prepare_inputs(self, data: Union[NNPInput, NamedTuple]): pair_indices=pair_list.to(torch.int64), ) + # this will return a Dict of the PairListOutputs for each cutoff we specify return pairlist_output def _input_checks(self, data: Union[NNPInput, NamedTuple]): @@ -994,6 +1008,8 @@ def __init__( postprocessing_parameter: Dict[str, Dict[str, bool]], dataset_statistic: Optional[Dict[str, float]], maximum_interaction_radius: unit.Quantity, + maximum_dispersion_interaction_radius: Optional[unit.Quantity] = None, + maximum_coulomb_interaction_radius: Optional[unit.Quantity] = None, potential_seed: Optional[int] = None, ): """ @@ -1006,7 +1022,11 @@ def __init__( dataset_statistic : Optional[Dict[str, float]] Dataset statistics for normalization. maximum_interaction_radius : unit.Quantity - cutoff radius. + cutoff radius for local interactions + maximum_dispersion_interaction_radius : unit.Quantity, optional + cutoff radius for dispersion interactions. + maximum_coulomb_interaction_radius : unit.Quantity, optional + cutoff radius for Coulomb interactions. potential_seed : Optional[int], optional Value used for torch.manual_seed, by default None. """ @@ -1040,8 +1060,25 @@ def __init__( raise RuntimeError( "The only_unique_pairs attribute is not set in the child class. Please set it to True or False before calling super().__init__." ) + + # to handle multiple cutoffs, we will create a dictionary with the cutoffs + # the dictionary will make it more transparent which PairListOutputs belong to which cutoff + + cutoffs = {} + cutoffs["maximum_interaction_radius"] = _convert_str_to_unit( + maximum_interaction_radius + ) + if maximum_dispersion_interaction_radius is not None: + cutoffs["maximum_dispersion_interaction_radius"] = _convert_str_to_unit( + maximum_dispersion_interaction_radius + ) + if maximum_coulomb_interaction_radius is not None: + cutoffs["maximum_coulomb_interaction_radius"] = _convert_str_to_unit( + maximum_coulomb_interaction_radius + ) + self.compute_interacting_pairs = ComputeInteractingAtomPairs( - cutoff=_convert_str_to_unit(maximum_interaction_radius), + cutoffs=cutoffs, only_unique_pairs=self.only_unique_pairs, ) diff --git a/modelforge/potential/painn.py b/modelforge/potential/painn.py index c90a5f1a..6d297530 100644 --- a/modelforge/potential/painn.py +++ b/modelforge/potential/painn.py @@ -138,7 +138,7 @@ def __init__( ) def _model_specific_input_preparation( - self, data: NNPInput, pairlist_output: PairListOutputs + self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs] ) -> PaiNNNeuralNetworkData: """ Prepare the model-specific input for the PaiNN network. @@ -147,8 +147,8 @@ def _model_specific_input_preparation( ---------- data : NNPInput The input data. - pairlist_output : PairListOutputs - The pairlist output. + pairlist_output : dict[str, PairListOutputs] + The output from the pairlist. Returns ------- @@ -159,6 +159,11 @@ def _model_specific_input_preparation( number_of_atoms = data.atomic_numbers.shape[0] + # Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter + # e.g. "maximum_interaction_radius" + + pairlist_output = pairlist_output["maximum_interaction_radius"] + nnp_input = PaiNNNeuralNetworkData( pair_indices=pairlist_output.pair_indices, d_ij=pairlist_output.d_ij, diff --git a/modelforge/potential/physnet.py b/modelforge/potential/physnet.py index 20bc9ac4..ee20721d 100644 --- a/modelforge/potential/physnet.py +++ b/modelforge/potential/physnet.py @@ -486,7 +486,7 @@ def __init__( self.atomic_shift = nn.Parameter(torch.zeros(maximum_atomic_number, 2)) def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" + self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs] ) -> PhysNetNeuralNetworkData: """ Prepare model-specific input data. @@ -495,7 +495,7 @@ def _model_specific_input_preparation( ---------- data : NNPInput Input data containing atomic information. - pairlist_output : PairListOutputs + pairlist_output : Dict[str, PairListOutputs] Output from the pairlist calculation. Returns @@ -505,6 +505,11 @@ def _model_specific_input_preparation( """ number_of_atoms = data.atomic_numbers.shape[0] + # Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter + # e.g. "maximum_interaction_radius" + + pairlist_output = pairlist_output["maximum_interaction_radius"] + nnp_input = PhysNetNeuralNetworkData( pair_indices=pairlist_output.pair_indices, d_ij=pairlist_output.d_ij, diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index bed45f2e..49f5a48e 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -138,7 +138,7 @@ def __init__( ) def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" + self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs] ) -> SAKENeuralNetworkInput: """ Prepare the model-specific input. @@ -147,8 +147,8 @@ def _model_specific_input_preparation( ---------- data : NNPInput Input data. - pairlist_output : PairListOutputs - Pairlist output. + pairlist_output : Dict[str,PairListOutputs] + Pairlist output(s) Returns ------- @@ -159,6 +159,11 @@ def _model_specific_input_preparation( number_of_atoms = data.atomic_numbers.shape[0] + # Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter + # e.g. "maximum_interaction_radius" + + pairlist_output = pairlist_output["maximum_interaction_radius"] + nnp_input = SAKENeuralNetworkInput( pair_indices=pairlist_output.pair_indices, number_of_atoms=number_of_atoms, diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index 2640b456..9155832f 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -93,7 +93,7 @@ def __init__( number_of_radial_basis_functions, featurization_config=featurization_config, ) - # Intialize interaction blocks + # Initialize interaction blocks if shared_interactions: self.interaction_modules = nn.ModuleList( [ @@ -134,7 +134,7 @@ def __init__( ) def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: PairListOutputs + self, data: "NNPInput", pairlist_output: Dict[str, PairListOutputs] ) -> SchnetNeuralNetworkData: """ Prepare the input data for the SchNet model. @@ -143,8 +143,8 @@ def _model_specific_input_preparation( ---------- data : NNPInput The input data for the model. - pairlist_output : PairListOutputs - The pairlist output. + pairlist_output : Dict[str, PairListOutputs] + The pairlist output(s). Returns ------- @@ -153,6 +153,11 @@ def _model_specific_input_preparation( """ number_of_atoms = data.atomic_numbers.shape[0] + # Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter + # e.g. "maximum_interaction_radius" + + pairlist_output = pairlist_output["maximum_interaction_radius"] + nnp_input = SchnetNeuralNetworkData( pair_indices=pairlist_output.pair_indices, d_ij=pairlist_output.d_ij, diff --git a/modelforge/potential/tensornet.py b/modelforge/potential/tensornet.py index fb03985d..3e9079ad 100644 --- a/modelforge/potential/tensornet.py +++ b/modelforge/potential/tensornet.py @@ -367,7 +367,7 @@ def compute_properties( } def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" + self, data: NNPInput, pairlist_output: Dict[str, PairListOutputs] ) -> TensorNetNeuralNetworkData: """ Prepare the input data for the TensorNet model. @@ -376,8 +376,8 @@ def _model_specific_input_preparation( ---------- data : NNPInput The input data for the model. - pairlist_output : PairListOutputs - The pairlist output. + pairlist_output : Dict[str, PairListOutputs] + The pairlist output(s) Returns ------- @@ -386,6 +386,11 @@ def _model_specific_input_preparation( """ number_of_atoms = data.atomic_numbers.shape[0] + # Note, pairlist_output is a Dict where the key corresponds to the name of the cutoff parameter + # e.g. "maximum_interaction_radius" + + pairlist_output = pairlist_output["maximum_interaction_radius"] + nnpdata = TensorNetNeuralNetworkData( pair_indices=pairlist_output.pair_indices, d_ij=pairlist_output.d_ij, diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 9709c6b0..1faecbf6 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -660,8 +660,8 @@ def test_pairlist(): from openff.units import unit cutoff = 5.0 * unit.nanometer # no relevant cutoff - pairlist = Neighborlist(cutoff, only_unique_pairs=True) - r = pairlist(positions, atomic_subsystem_indices) + pairlist = Neighborlist({"cutoff1": cutoff}, only_unique_pairs=True) + r = pairlist(positions, atomic_subsystem_indices)["cutoff1"] pair_indices = r.pair_indices # pairlist describes the pairs of interacting atoms within a batch @@ -690,8 +690,8 @@ def test_pairlist(): # test with cutoff cutoff = 2.0 * unit.nanometer - pairlist = Neighborlist(cutoff, only_unique_pairs=True) - r = pairlist(positions, atomic_subsystem_indices) + pairlist = Neighborlist({"cutoff1": cutoff}, only_unique_pairs=True) + r = pairlist(positions, atomic_subsystem_indices)["cutoff1"] pair_indices = r.pair_indices assert torch.equal(pair_indices, torch.tensor([[0, 1, 3, 4], [1, 2, 4, 5]])) @@ -714,8 +714,8 @@ def test_pairlist(): # test with complete pairlist cutoff = 2.0 * unit.nanometer - pairlist = Neighborlist(cutoff, only_unique_pairs=False) - r = pairlist(positions, atomic_subsystem_indices) + pairlist = Neighborlist({"cutoff1": cutoff}, only_unique_pairs=False) + r = pairlist(positions, atomic_subsystem_indices)["cutoff1"] pair_indices = r.pair_indices print(pair_indices, flush=True) @@ -726,11 +726,13 @@ def test_pairlist(): # make sure that Pairlist and Neighborlist behave the same for large cutoffs cutoff = 10.0 * unit.nanometer only_unique_pairs = False - neighborlist = Neighborlist(cutoff, only_unique_pairs=only_unique_pairs) + neighborlist = Neighborlist( + {"cutoff1": cutoff}, only_unique_pairs=only_unique_pairs + ) pairlist = Pairlist(only_unique_pairs=only_unique_pairs) r = pairlist(positions, atomic_subsystem_indices) pair_indices = r.pair_indices - r = neighborlist(positions, atomic_subsystem_indices) + r = neighborlist(positions, atomic_subsystem_indices)["cutoff1"] neighbor_indices = r.pair_indices assert torch.equal(pair_indices, neighbor_indices) @@ -738,11 +740,13 @@ def test_pairlist(): # make sure that they are the same also for non-redundant pairs cutoff = 10.0 * unit.nanometer only_unique_pairs = True - neighborlist = Neighborlist(cutoff, only_unique_pairs=only_unique_pairs) + neighborlist = Neighborlist( + {"cutoff1": cutoff}, only_unique_pairs=only_unique_pairs + ) pairlist = Pairlist(only_unique_pairs=only_unique_pairs) r = pairlist(positions, atomic_subsystem_indices) pair_indices = r.pair_indices - r = neighborlist(positions, atomic_subsystem_indices) + r = neighborlist(positions, atomic_subsystem_indices)["cutoff1"] neighbor_indices = r.pair_indices assert torch.equal(pair_indices, neighbor_indices) @@ -750,16 +754,59 @@ def test_pairlist(): # this should fail cutoff = 2.0 * unit.nanometer only_unique_pairs = True - neighborlist = Neighborlist(cutoff, only_unique_pairs=only_unique_pairs) + neighborlist = Neighborlist( + {"cutoff1": cutoff}, only_unique_pairs=only_unique_pairs + ) pairlist = Pairlist(only_unique_pairs=only_unique_pairs) r = pairlist(positions, atomic_subsystem_indices) pair_indices = r.pair_indices - r = neighborlist(positions, atomic_subsystem_indices) + r = neighborlist(positions, atomic_subsystem_indices)["cutoff1"] neighbor_indices = r.pair_indices assert not pair_indices.shape == neighbor_indices.shape +def test_multiple_neighborlists(): + from modelforge.potential.models import Pairlist, Neighborlist + import torch + from openff.units import unit + + atomic_subsystem_indices = torch.tensor([0, 0, 0, 0, 0]) + + positions = torch.tensor( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [2.0, 0.0, 0.0], + [3.0, 0.0, 0.0], + [4.0, 0.0, 0.0], + ] + ) + + cutoff_short = 1.5 * unit.nanometer + cutoff_medium = 2.5 * unit.nanometer + cutoff_long = 3.5 * unit.nanometer + pairlist = Neighborlist( + {"short": cutoff_short, "medium": cutoff_medium, "long": cutoff_long}, + only_unique_pairs=True, + ) + r = pairlist(positions, atomic_subsystem_indices) + + assert torch.equal( + r["short"].pair_indices, torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]) + ) + + assert torch.equal( + r["medium"].pair_indices, + torch.tensor([[0, 0, 1, 1, 2, 2, 3], [1, 2, 2, 3, 3, 4, 4]]), + ) + + assert torch.equal( + r["long"].pair_indices, + torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3], [1, 2, 3, 2, 3, 4, 3, 4, 4]]), + ) + + def test_pairlist_precomputation(): from modelforge.potential.models import Pairlist import torch @@ -1113,7 +1160,7 @@ def test_pairlist_calculate_r_ij_and_d_ij(): # Create Pairlist instance # --------------------------- # # Only unique pairs - pairlist = Neighborlist(cutoff, only_unique_pairs=True) + pairlist = Neighborlist({"cutoff_1": cutoff}, only_unique_pairs=True) pair_indices = pairlist.enumerate_all_pairs(atomic_subsystem_indices) # Calculate r_ij and d_ij @@ -1135,7 +1182,7 @@ def test_pairlist_calculate_r_ij_and_d_ij(): # --------------------------- # # ALL pairs - pairlist = Neighborlist(cutoff, only_unique_pairs=False) + pairlist = Neighborlist({"cutoff_1": cutoff}, only_unique_pairs=False) pair_indices = pairlist.enumerate_all_pairs(atomic_subsystem_indices) # Calculate r_ij and d_ij diff --git a/modelforge/tests/test_tensornet.py b/modelforge/tests/test_tensornet.py index c7a9fd4e..92ed271c 100644 --- a/modelforge/tests/test_tensornet.py +++ b/modelforge/tests/test_tensornet.py @@ -106,7 +106,9 @@ def test_input(): ], ) tensornet.compute_interacting_pairs._input_checks(mf_input) - pairlist_output = tensornet.compute_interacting_pairs.prepare_inputs(mf_input) + pairlist_output = tensornet.compute_interacting_pairs.prepare_inputs(mf_input)[ + "maximum_interaction_radius" + ] # torchmd-net TensorNet if reference_data: