Skip to content

Commit

Permalink
extracting parilist from schnet implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Sep 13, 2023
1 parent 6a5dfe6 commit 56cfa31
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 83 deletions.
5 changes: 3 additions & 2 deletions modelforge/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
Coordinates for each atom in the molecule.
- 'E': torch.Tensor, shape []
Scalar energy value for the molecule.
- 'idx': int
Index of the molecule in the dataset.
"""
Z = torch.tensor(self.properties_of_interest["Z"][idx], dtype=torch.int64)
R = torch.tensor(self.properties_of_interest["R"][idx], dtype=torch.float32)
E = torch.tensor(self.properties_of_interest["E"][idx], dtype=torch.float32)
return {"Z": Z, "R": R, "E": E}
return {"Z": Z, "R": R, "E": E, "idx": idx}


class HDF5Dataset:
Expand Down
43 changes: 43 additions & 0 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,49 @@
from modelforge.utils import SpeciesEnergies


class PairList(nn.Module):
def __init__(self, cutoff: float = 5.0):
"""
Initialize the PairList class.
"""
super().__init__()
from .utils import neighbor_pairs_nopbc

self.calculate_neighbors = neighbor_pairs_nopbc
self.cutoff = cutoff

def compute_distance(
self, atom_index12: torch.Tensor, R: torch.Tensor
) -> torch.Tensor:
"""
Compute distances based on atom indices and coordinates.
Parameters
----------
atom_index12 : torch.Tensor, shape [n_pairs, 2]
Atom indices for pairs of atoms
R : torch.Tensor, shape [batch_size, n_atoms, n_dims]
Atom coordinates.
Returns
-------
torch.Tensor, shape [n_pairs]
Computed distances.
"""

coordinates = R.flatten(0, 1)
selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view(
2, -1, 3
)
vec = selected_coordinates[0] - selected_coordinates[1]
return vec.norm(2, -1)

def forward(self, mask, R) -> Dict[str, torch.Tensor]:
atom_index12 = self.calculate_neighbors(mask, R, self.cutoff)
d_ij = self.compute_distance(atom_index12, R)
return {"atom_index12": atom_index12, "d_ij": d_ij}


class BaseNNP(pl.LightningModule):
"""
Abstract base class for neural network potentials.
Expand Down
124 changes: 45 additions & 79 deletions modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn as nn
from loguru import logger
from typing import Dict, Tuple
from typing import Dict, Tuple, List

from .models import BaseNNP
from .utils import (
Expand All @@ -15,7 +15,11 @@

class Schnet(BaseNNP):
def __init__(
self, n_atom_basis: int, n_interactions: int, n_filters: int = 0
self,
n_atom_basis: int,
n_interactions: int,
n_filters: int = 0,
cutoff: float = 5.0,
) -> None:
"""
Initialize the Schnet class.
Expand All @@ -30,16 +34,24 @@ def __init__(
n_filters : int, optional
Number of filters, defines the dimensionality of the intermediate features.
Default is 0.
cutoff : float, optional
Cutoff value for the pairlist. Default is 5.0.
"""
from .models import PairList

super().__init__()

self.calculate_distances_and_pairlist = PairList(cutoff)

self.representation = SchNetRepresentation(
n_atom_basis, n_filters, n_interactions
)
self.readout = EnergyReadout(n_atom_basis)
self.embedding = nn.Embedding(100, n_atom_basis, padding_idx=-1)

def calculate_energy(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
def calculate_energy(
self, inputs: Dict[str, torch.Tensor], cached_pairlist: bool = False
) -> torch.Tensor:
"""
Calculate the energy for a given input batch.
Expand All @@ -50,15 +62,22 @@ def calculate_energy(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
Atomic numbers for each atom in each molecule in the batch.
- 'R': torch.Tensor, shape [batch_size, n_atoms, 3]
Coordinates for each atom in each molecule in the batch.
cached_pairlist : bool, optional
Whether to use a cached pairlist. Default is False. NOTE: is this really needed?
Returns
-------
torch.Tensor, shape [batch_size]
Calculated energies for each molecule in the batch.
"""
x = self.representation(inputs)
logger.debug(f"{x.shape=}")
# compute atom and pair features (see Fig1 in 10.1063/1.5019779)
# initializing x^{l}_{0} as x^l)0 = aZ_i
Z = inputs["Z"]
x = self.embedding(Z)
mask = Z == -1
pairlist = self.calculate_distances_and_pairlist(mask, inputs["R"])

x = self.representation(x, pairlist)
# pool average over atoms
return self.readout(x)

Expand Down Expand Up @@ -136,39 +155,31 @@ def forward(
"""
batch_size, nr_of_atoms = x.shape[0], x.shape[1]

logger.debug(f"Input to feature: {x.shape=}")
logger.debug(f"Input to feature: {f_ij.shape=}")
logger.debug(f"Input to feature: {idx_i.shape=}")
logger.debug(f"Input to feature: {rcut_ij.shape=}")
x = self.intput_to_feature(x)
logger.debug(f"After input_to_feature call: {x.shape=}")
x = x.flatten(0, 1)
logger.debug(f"Flatten x: {x.shape=}")

# Filter generation networks
Wij = self.filter_network(f_ij)
Wij = Wij * rcut_ij[:, None]
Wij = Wij.to(dtype=x.dtype)
logger.debug(f"Wij {Wij.shape=}")

# continuous-filter convolutional layers
logger.debug(f"Before x[idx_j]: x.shape {x.shape=}")
logger.debug(f"idx_j.shape {idx_j.shape=}")
x_j = x[idx_j]
x_ij = x_j * Wij
logger.debug(f"After x_j * Wij: x_ij.shape {x_ij.shape=}")
x = scatter_add(x_ij, idx_i, dim_size=x.shape[0])
logger.debug(f"After scatter_add: x.shape {x.shape=}")
# Update features
x = self.feature_to_output(x)
logger.debug(f"After feature_to_output: x.shape {x.shape=}")
x = x.reshape(batch_size, nr_of_atoms, 128)
logger.debug(f"After reshape: x.shape {x.shape=}")
return x


class SchNetRepresentation(nn.Module):
def __init__(self, n_atom_basis: int, n_filters: int, n_interactions: int):
def __init__(
self,
n_atom_basis: int,
n_filters: int,
n_interactions: int,
):
"""
Initialize the SchNet representation layer.
Expand All @@ -182,9 +193,7 @@ def __init__(self, n_atom_basis: int, n_filters: int, n_interactions: int):
Number of interaction layers.
"""
super().__init__()
from .utils import neighbor_pairs_nopbc

self.embedding = nn.Embedding(100, n_atom_basis, padding_idx=-1)
self.interactions = nn.ModuleList(
[
SchNetInteractionBlock(n_atom_basis, n_filters)
Expand All @@ -193,7 +202,6 @@ def __init__(self, n_atom_basis: int, n_filters: int, n_interactions: int):
)
self.cutoff = 5.0
self.radial_basis = GaussianRBF(n_rbf=20, cutoff=self.cutoff)
self.calculate_neighbors = neighbor_pairs_nopbc

def _distance_to_radial_basis(
self, d_ij: torch.Tensor
Expand All @@ -216,77 +224,35 @@ def _distance_to_radial_basis(
rcut_ij = cosine_cutoff(d_ij, self.cutoff)
return f_ij, rcut_ij

def compute_distance(
self, atom_index12: torch.Tensor, R: torch.Tensor
def forward(
self, x: torch.Tensor, pairlist: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Compute distances based on atom indices and coordinates.
Parameters
----------
atom_index12 : torch.Tensor, shape [n_pairs, 2]
Atom indices for pairs of atoms
R : torch.Tensor, shape [batch_size, n_atoms, n_dims]
Atom coordinates.
Returns
-------
torch.Tensor, shape [n_pairs]
Computed distances.
"""

logger.debug(f"{atom_index12.shape=}")
logger.debug(f"{R.shape=}")
coordinates = R.flatten(0, 1)
logger.debug(f"{coordinates.shape=}")
selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view(
2, -1, 3
)
logger.debug(f"{selected_coordinates.shape=}")
vec = selected_coordinates[0] - selected_coordinates[1]
return vec.norm(2, -1)

def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Forward pass for the representation layer.
Parameters
----------
inputs : Dict[str, torch.Tensor]
Dictionary containing input tensors, specifically atomic numbers and coordinates.
- 'Z': Atomic numbers, shape [batch_size, n_atoms]
- 'R': Atom coordinates, shape [batch_size, n_atoms, 3]
x : torch.Tensor, shape [batch_size, n_atoms, n_atom_basis]
Input feature tensor for atoms.
pairlist: Dict[str, torch.Tensor]
Pairlist dictionary containing the following keys:
- 'atom_index12': torch.Tensor, shape [n_pairs, 2]
Atom indices for pairs of atoms
- 'd_ij': torch.Tensor, shape [n_pairs]
Pairwise distances between atoms.
Returns
-------
torch.Tensor, shape [batch_size, n_atoms, n_atom_basis]
Output tensor after forward pass.
"""
logger.debug("Compute distances ...")
Z = inputs["Z"]
R = inputs["R"]
mask = Z == -1
atom_index12 = pairlist["atom_index12"]
d_ij = pairlist["d_ij"]

atom_index12 = self.calculate_neighbors(mask, R, self.cutoff)
d_ij = self.compute_distance(atom_index12, R)
logger.debug(f"{d_ij.shape=}")
logger.debug("Convert distances to radial basis ...")
f_ij, rcut_ij = self._distance_to_radial_basis(d_ij)
logger.debug("Compute interaction block ...")

# compute atom and pair features (see Fig1 in 10.1063/1.5019779)
# initializing x^{l}_{0} as x^l)0 = aZ_i
logger.debug("Embedding inputs.Z")
logger.debug(f"{Z.shape=}")
x = self.embedding(Z)

logger.debug(f"After embedding: {x.shape=}")
idx_i = atom_index12[0]
idx_j = atom_index12[1]
idx_i, idx_j = atom_index12[0], atom_index12[1]
for interaction in self.interactions:
v = interaction(x, f_ij, idx_i, idx_j, rcut_ij)
x = x + v


return x
4 changes: 2 additions & 2 deletions modelforge/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_different_properties_of_interest(dataset):
dataset = factory.create_dataset(data)
raw_data_item = dataset[0]
assert isinstance(raw_data_item, dict)
assert len(raw_data_item) == 3
assert len(raw_data_item) == 4

data.properties_of_interest = ["return_energy", "geometry"]
assert data.properties_of_interest == [
Expand All @@ -82,7 +82,7 @@ def test_different_properties_of_interest(dataset):
raw_data_item = dataset[0]
print(raw_data_item)
assert isinstance(raw_data_item, dict)
assert len(raw_data_item) != 2 # NOTE: FIXME: This should be 2
assert len(raw_data_item) != 3


@pytest.mark.parametrize("dataset", DATASETS)
Expand Down

0 comments on commit 56cfa31

Please sign in to comment.