Skip to content

Commit

Permalink
Initial Base Classes and base functionality for Neural Network Potent…
Browse files Browse the repository at this point in the history
…ials (#5)

* delete notebook

* first outline of base NNP class

* remove useless example

* updating tests

* change dtype naming

* simplify base model

* update base implementation

* reference schnet impoementation

* adding imports, removing alternative implementstions

* adding test

* still work in progress

* adding comments

* adding docstrings

* adding tests

* bugfix

* remove print statments

* bugfix
  • Loading branch information
wiederm authored Sep 4, 2023
1 parent c8284eb commit 815c69c
Show file tree
Hide file tree
Showing 13 changed files with 672 additions and 151 deletions.
1 change: 1 addition & 0 deletions devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- lightning
- tensorboard
- torchvision
- ase

# Testing
- pytest
Expand Down
2 changes: 2 additions & 0 deletions modelforge/potential/__initi__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .schnet import Schnet
from .utils import Dense, GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add
Empty file removed modelforge/potential/features.py
Empty file.
67 changes: 67 additions & 0 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Dict, List, Optional

import torch
import torch.nn as nn

from modelforge.utils import Inputs, Properties, SpeciesEnergies


class BaseNNP(nn.Module):
"""
Abstract base class for neural network potentials.
This class defines the overall structure and ensures that subclasses
implement the `calculate_energies_and_forces` method.
"""

def __init__(self, dtype: torch.dtype, device: torch.device):
"""
Initialize the NeuralNetworkPotential class.
Parameters
----------
dtype : torch.dtype
Data type for the PyTorch tensors.
device : torch.device
Device ("cpu" or "cuda") on which computations will be performed.
"""
super().__init__()
self.dtype = dtype
self.device = device

def forward(
self,
inputs: Inputs,
) -> SpeciesEnergies:
"""
Forward pass for the neural network potential.
Parameters
----------
inputs : Inputs
An instance of the Inputs data class containing atomic numbers, positions, etc.
Returns
-------
SpeciesEnergies
An instance of the SpeciesEnergies data class containing species and calculated energies.
"""

E = self.calculate_energies_and_forces(inputs)
return SpeciesEnergies(inputs.Z, E)

def calculate_energies_and_forces(self, inputs: Optional[Inputs] = None):
"""
Placeholder for the method that should calculate energies and forces.
This method should be implemented in subclasses.
Raises
------
NotImplementedError
If the method is not overridden in the subclass.
"""
raise NotImplementedError("Subclasses must implement this method.")


223 changes: 223 additions & 0 deletions modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
from typing import Tuple
from loguru import logger

import numpy as np
import torch
import torch.nn as nn
from ase import Atoms
from ase.neighborlist import neighbor_list
from torch import dtype

from modelforge.utils import Inputs

from .models import BaseNNP
from .utils import Dense, GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add


class Schnet(BaseNNP):
"""
Implementation of the SchNet architecture for quantum mechanical property prediction.
"""

def __init__(
self,
n_atom_basis: int, # number of features per atom
n_interactions: int, # number of interaction blocks
n_filters: int = 0, # number of filters
dtype: dtype = torch.float32,
device: torch.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
),
):
"""
Initialize the SchNet model.
Parameters
----------
n_atom_basis : int
Number of features per atom.
n_interactions : int
Number of interaction blocks.
n_filters : int, optional
Number of filters, defaults to None.
dtype : torch.dtype, optional
Data type for PyTorch tensors, defaults to torch.float32.
device : torch.device, optional
Device ("cpu" or "cuda") on which computations will be performed.
"""

super().__init__(dtype, device)

# initialize atom embeddings
max_z: int = 100 # max nuclear charge (i.e. atomic number)
self.embedding = nn.Embedding(max_z, n_atom_basis, padding_idx=0)

# initialize radial basis functions and other constants
n_rbf = 20
self.radial_basis = GaussianRBF(n_rbf=n_rbf, cutoff=5.0)
self.cutoff = 5.0
self.activation = shifted_softplus
self.n_interactions = n_interactions
self.n_atom_basis = n_atom_basis

# initialize dense yalers for atom feature transformation
# Dense layers are applied consecutively to the initialized atom embeddings x^{l}_{0}
# to generate x_i^l+1 = W^lx^l_i + b^l
self.intput_to_feature = Dense(
n_atom_basis, n_filters, bias=False, activation=None
)
self.feature_to_output = nn.Sequential(
Dense(n_filters, n_atom_basis, activation=self.activation),
Dense(n_atom_basis, n_atom_basis, activation=None),
)

# Initialize filter network
self.filter_network = nn.Sequential(
Dense(n_rbf, n_filters, activation=self.activation),
Dense(n_filters, n_filters),
)

def _setup_ase_system(self, inputs: Inputs) -> Atoms:
"""
Transform inputs to an ASE Atoms object.
Parameters
----------
inputs : Inputs
Input features including atomic numbers and positions.
Returns
-------
ase.Atoms
Transformed ASE Atoms object.
"""
_atomic_numbers = torch.clone(inputs.Z)
atomic_numbers = list(_atomic_numbers.detach().cpu().numpy())
positions = list(inputs.R.detach().cpu().numpy())
ase_atoms = Atoms(numbers=atomic_numbers, positions=positions)
return ase_atoms

def _compute_distances(
self, inputs: Inputs
) -> Tuple[torch.Tensor, np.ndarray, np.ndarray]:
"""
Compute atomic distances using ASE's neighbor list.
Parameters
----------
inputs : Inputs
Input features including atomic numbers and positions.
Returns
-------
torch.Tensor, np.ndarray, np.ndarray
Pairwise distances, index of atom i, and index of atom j.
"""

ase_atoms = self._setup_ase_system(inputs)
idx_i, idx_j, _, r_ij = neighbor_list(
"ijSD", ase_atoms, 5.0, self_interaction=False
)
r_ij = torch.from_numpy(r_ij)
return r_ij, idx_i, idx_j

def _distance_to_radial_basis(self, r_ij):
"""
Transform distances to radial basis functions.
Parameters
----------
r_ij : torch.Tensor
Pairwise distances between atoms.
Returns
-------
torch.Tensor, torch.Tensor
Radial basis functions and cutoff values.
"""
d_ij = torch.norm(r_ij, dim=1) # calculate pairwise distances
f_ij = self.radial_basis(d_ij)
rcut_ij = cosine_cutoff(d_ij, self.cutoff)
return f_ij, rcut_ij

def _interaction_block(self, inputs: Inputs, f_ij, idx_i, idx_j, rcut_ij):
"""
Compute the interaction block which updates atom features.
Parameters
----------
inputs : Inputs
Input features including atomic numbers and positions.
f_ij : torch.Tensor
Radial basis functions.
idx_i : np.ndarray
Indices of center atoms.
idx_j : np.ndarray
Indices of neighboring atoms.
rcut_ij : torch.Tensor
Cutoff values for each pair of atoms.
Returns
-------
torch.Tensor
Updated atom features.
"""

# compute atom and pair features (see Fig1 in 10.1063/1.5019779)
# initializing x^{l}_{0} as x^l)0 = aZ_i
logger.debug("Embedding inputs.Z")
x_emb = self.embedding(inputs.Z)
logger.debug("After embedding: x.shape", x_emb.shape)
idx_i = torch.from_numpy(idx_i).to(self.device, torch.int64)

# interaction blocks
for _ in range(self.n_interactions):
# atom wise update of features
logger.debug(f"Input to feature: x.shape {x_emb.shape}")
x = self.intput_to_feature(x_emb)
logger.debug("After input_to_feature call: x.shape {x.shape}")

# Filter generation networks
Wij = self.filter_network(f_ij)
Wij = Wij * rcut_ij[:, None]
Wij = Wij.to(dtype=self.dtype)

# continuous-filter convolutional layers
x_j = x[idx_j]
x_ij = x_j * Wij
logger.debug("After x_j * Wij: x_ij.shape {x_ij.shape}")
x = scatter_add(x_ij, idx_i, dim_size=x.shape[0])
logger.debug("After scatter_add: x.shape {x.shape}")
# Update features
x = self.feature_to_output(x)
x_emb = x_emb + x

return x_emb

def calculate_energies_and_forces(self, inputs: Inputs) -> torch.Tensor:
"""
Compute energies and forces for given atomic configurations.
Parameters
----------
inputs : Inputs
Input features including atomic numbers and positions.
Returns
-------
torch.Tensor
Energies and forces for the given configurations.
"""
logger.debug("Compute distances ...")
r_ij, idx_i, idx_j = self._compute_distances(inputs)
logger.debug("Convert distances to radial basis ...")
f_ij, rcut_ij = self._distance_to_radial_basis(r_ij)
logger.debug("Compute interaction block ...")
x = self._interaction_block(inputs, f_ij, idx_i, idx_j, rcut_ij)
return x
Loading

0 comments on commit 815c69c

Please sign in to comment.