multiple output layers for ANI2x
wiederm committed Aug 16, 2024
1 parent ff1cd3d commit 5ecf6bc
Showing 3 changed files with 78 additions and 25 deletions.
95 changes: 71 additions & 24 deletions modelforge/potential/ani.py
@@ -3,18 +3,15 @@
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Tuple, Type
from .models import BaseNetwork, CoreNetwork
from typing import Dict, Tuple, Type, List

import torch
from loguru import logger as log
from torch import nn

from modelforge.utils.prop import SpeciesAEV

if TYPE_CHECKING:
from modelforge.dataset.dataset import NNPInput
from .models import PairListOutputs
from .models import BaseNetwork, CoreNetwork


def triu_index(num_species: int) -> torch.Tensor:
@@ -369,6 +366,27 @@ def _preprocess_angular_aev(self, data: Dict[str, torch.Tensor]):
}


class MultiOutputHeadNetwork(nn.Module):

def __init__(self, shared_layers: nn.Sequential, output_dims: int):
super().__init__()
self.shared_layers = shared_layers
        # get the output dimension from the last shared Linear layer
        input_dim = shared_layers[-2].out_features

# Create a list of output heads
self.output_heads = nn.ModuleList(
[nn.Linear(input_dim, 1) for _ in range(output_dims)]
)

def forward(self, x):
x = self.shared_layers(x)
outputs = [head(x) for head in self.output_heads]
# Concatenate the outputs into a single tensor along the last dimension
return torch.cat(outputs, dim=1)
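
As a quick orientation for the new head network, a minimal shape check (an illustrative sketch, not part of the commit; the layer sizes and the 1008-dimensional AEV are assumptions):

import torch
from torch import nn

# toy shared stack ending in Linear + activation, mirroring what create_network builds below
shared = nn.Sequential(
    nn.Linear(1008, 256), nn.CELU(0.1),
    nn.Linear(256, 128), nn.CELU(0.1),
)
heads = MultiOutputHeadNetwork(shared, output_dims=2)  # e.g. energy + charge
out = heads(torch.randn(5, 1008))                      # 5 atoms
assert out.shape == (5, 2)                             # one column per predicted property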


class ANIInteraction(nn.Module):
"""
Atomic neural network interaction module for ANI.
@@ -381,9 +399,16 @@ class ANIInteraction(nn.Module):
The activation function to use.
"""

def __init__(self, *, aev_dim: int, activation_function: Type[torch.nn.Module]):
def __init__(
self,
*,
aev_dim: int,
activation_function: Type[torch.nn.Module],
predicted_properties: List[Dict[str, str]],
):

super().__init__()
self.predicted_properties = predicted_properties
# define atomic neural network
atomic_neural_networks = self.intialize_atomic_neural_network(
aev_dim, activation_function
@@ -428,14 +453,19 @@ def create_network(layers):
nn.Sequential
The created neural network.
"""
network_layers = []
shared_network_layers = []
input_dim = aev_dim
for units in layers:
network_layers.append(nn.Linear(input_dim, units))
network_layers.append(activation_function)
shared_network_layers.append(nn.Linear(input_dim, units))
shared_network_layers.append(activation_function)
input_dim = units
network_layers.append(nn.Linear(input_dim, 1))
return nn.Sequential(*network_layers)

# Create a MultiOutputHeadNetwork with the specified output
# dimensions
shared_layers = nn.Sequential(*shared_network_layers)
return MultiOutputHeadNetwork(
shared_layers, output_dims=len(self.predicted_properties)
)

return {
element: create_network(layers)
@@ -450,7 +480,7 @@ def create_network(layers):
}.items()
}

def forward(self, input: Tuple[torch.Tensor, torch.Tensor]):
def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
"""
Forward pass to compute atomic energies from AEVs.
@@ -465,16 +495,23 @@ def forward(self, input: Tuple[torch.Tensor, torch.Tensor]):
The computed atomic energies.
"""
species, aev = input
output = aev.new_zeros(species.shape)
per_atom_property = torch.zeros(
(species.shape[0], len(self.predicted_properties)),
dtype=aev.dtype,
device=aev.device,
)

for i, model in enumerate(self.atomic_networks):
            # select the atoms whose species index equals i
mask = torch.eq(species, i)
midx = mask.nonzero().flatten()
if midx.shape[0] > 0:
input_ = aev.index_select(0, midx)
output[midx] = model(input_).flatten()
per_element_index = mask.nonzero().flatten()
# if element present, pass it through the network
if per_element_index.shape[0] > 0:
input_ = aev.index_select(0, per_element_index)
                per_element_prediction = model(input_)
                per_atom_property[per_element_index, :] = per_element_prediction

return output.view_as(species)
return per_atom_property
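
The loop above routes each atom to its element-specific network and scatters the per-head predictions back into one (n_atoms, n_properties) tensor. The same pattern in standalone toy form (illustrative sizes and networks, not the commit's code):

import torch
from torch import nn

species = torch.tensor([0, 1, 0, 1])                 # per-atom element indices
aev = torch.randn(4, 8)                              # toy AEVs
nets = [nn.Linear(8, 2) for _ in range(2)]           # one toy net per element, 2 outputs each
out = torch.zeros(4, 2)                              # (n_atoms, n_properties)
for i, net in enumerate(nets):
    idx = torch.eq(species, i).nonzero().flatten()   # atoms of element i
    if idx.shape[0] > 0:
        out[idx, :] = net(aev.index_select(0, idx))  # scatter predictions back by row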


class ANI2xCore(CoreNetwork):
@@ -512,12 +549,14 @@ def __init__(
activation_function: Type[torch.nn.Module],
angular_dist_divisions: int,
angle_sections: int,
predicted_properties: List[Dict[str, str]],
) -> None:
# number of elements in ANI2x
self.num_species = 7

log.debug("Initializing the ANI2x architecture.")
super().__init__(activation_function)
self.predicted_properties = predicted_properties

# Initialize representation block
self.ani_representation_module = ANIRepresentation(
@@ -545,6 +584,7 @@ def __init__(
self.interaction_modules = ANIInteraction(
aev_dim=self.aev_length,
activation_function=self.activation_function,
predicted_properties=predicted_properties,
)

# ----- ATOMIC NUMBER LOOKUP --------
@@ -611,16 +651,21 @@ def compute_properties(self, data: AniNeuralNetworkData) -> Dict[str, torch.Tens

# compute the representation (atomic environment vectors) for each atom
representation = self.ani_representation_module(data)
# compute the atomic energies
E_i = self.interaction_modules(representation)

return {
"per_atom_energy": E_i,
# compute the atomic properties
predictions = self.interaction_modules(representation)
# generate the output results
results = {
"per_atom_scalar_representation": torch.tensor([0]),
"atomic_subsystem_indices": data.atomic_subsystem_indices,
}
# extract predictions per property
for dim, property in enumerate(self.predicted_properties):
results[property["name"]] = predictions[:, dim]

return results
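
The final loop splits the (n_atoms, n_properties) prediction tensor into one named entry per property. In isolation (a sketch with made-up values):

import torch

predicted_properties = [
    {"name": "per_atom_energy", "type": "scalar"},
    {"name": "per_atom_charge", "type": "scalar"},
]
predictions = torch.randn(4, 2)  # (n_atoms, n_properties) from the interaction module
results = {prop["name"]: predictions[:, dim] for dim, prop in enumerate(predicted_properties)}
# results["per_atom_energy"], results["per_atom_charge"]: 1D tensors of length n_atoms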


from typing import Union, Optional, Dict
from typing import Dict, Optional, Union


class ANI2x(BaseNetwork):
@@ -665,6 +710,7 @@ def __init__(
angle_sections: int,
activation_function_parameter: Dict,
postprocessing_parameter: Dict[str, Dict[str, bool]],
predicted_properties: List[Dict[str, str]],
dataset_statistic: Optional[Dict[str, float]] = None,
potential_seed: Optional[int] = None,
) -> None:
@@ -695,6 +741,7 @@ def __init__(
activation_function=activation_function,
angular_dist_divisions=angular_dist_divisions,
angle_sections=angle_sections,
predicted_properties=predicted_properties,
)

def _config_prior(self):
3 changes: 2 additions & 1 deletion modelforge/potential/parameters.py
@@ -146,6 +146,7 @@ class CoreParameter(ParametersBase):
minimum_interaction_radius_for_angular_features: Union[str, unit.Quantity]
angular_dist_divisions: int
activation_function_parameter: ActivationFunctionConfig
predicted_properties: List[PredictedPropertiesParameter]

converted_units = field_validator(
"maximum_interaction_radius",
@@ -275,7 +276,7 @@ class Featurization(ParametersBase):
featurization: Featurization
activation_function_parameter: ActivationFunctionConfig
predicted_properties: List[PredictedPropertiesParameter]

converted_units = field_validator("maximum_interaction_radius")(
_convert_str_to_unit
)
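
PredictedPropertiesParameter itself is defined elsewhere in parameters.py; judging from the TOML entries below, a minimal pydantic model along these lines would accept them (a hypothetical reconstruction, not the repository's actual definition):

from pydantic import BaseModel

class PredictedPropertiesParameter(BaseModel):  # hypothetical sketch
    name: str  # e.g. "per_atom_energy"
    type: str  # e.g. "scalar"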
5 changes: 5 additions & 0 deletions modelforge/tests/data/potential_defaults/ani2x.toml
@@ -9,6 +9,11 @@ number_of_radial_basis_functions = 16
maximum_interaction_radius_for_angular_features = "3.5 angstrom"
minimum_interaction_radius_for_angular_features = "0.8 angstrom"
angular_dist_divisions = 8
predicted_properties = [
{ name = "per_atom_energy", type = "scalar" },
{ name = "per_atom_charge", type = "scalar" },
# we could also define per_atom_force to be consistent?
]

[potential.core_parameter.activation_function_parameter]
activation_function_name = "CeLU" # for the original ANI behavior, stick with CeLU: its alpha parameter is currently hard-coded, and another activation function might lead to different behavior.
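
Assuming the new keys sit under [potential.core_parameter] (the table implied by the activation-function subsection), the predicted_properties list round-trips into Python as plain dicts, matching the List[Dict[str, str]] hint in ani.py. A sketch using the standard-library parser:

import tomllib  # Python 3.11+

with open("modelforge/tests/data/potential_defaults/ani2x.toml", "rb") as f:
    config = tomllib.load(f)

props = config["potential"]["core_parameter"]["predicted_properties"]
# [{'name': 'per_atom_energy', 'type': 'scalar'}, {'name': 'per_atom_charge', 'type': 'scalar'}]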
