diff --git a/README.md b/README.md index 5b78f1c3..0592b769 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,12 @@ modelforge ============================== [//]: # (Badges) -[![GitHub Actions Build Status](https://github.com/REPLACE_WITH_OWNER_ACCOUNT/modelforge/workflows/CI/badge.svg)](https://github.com/REPLACE_WITH_OWNER_ACCOUNT/modelforge/actions?query=workflow%3ACI) -[![codecov](https://codecov.io/gh/REPLACE_WITH_OWNER_ACCOUNT/modelforge/branch/main/graph/badge.svg)](https://codecov.io/gh/REPLACE_WITH_OWNER_ACCOUNT/modelforge/branch/main) - +[![GitHub Actions Build Status](https://github.com/choderalab/modelforge/workflows/CI/badge.svg)](https://github.com/choderalab/modelforge/actions?query=workflow%3ACI) +[![codecov](https://codecov.io/gh/choderalab/modelforge/branch/main/graph/badge.svg)](https://codecov.io/gh/choderalab/modelforge/branch/main) +[![Github release](https://badgen.net/github/release/choderalab/modelforge)](https://github.com/choderalab/modelforge/) +[![GitHub license](https://img.shields.io/github/license/choderalab/modelforge?color=green)](https://github.com/choderalab/modelforge/blob/main/LICENSE) +[![GitHub issues](https://img.shields.io/github/issues/choderalab/modelforge?style=flat)](https://github.com/choderalab/modelforge/issues) +[![GitHub stars](https://img.shields.io/github/stars/choderalab/modelforge)](https://github.com/choderalab/modelforge/stargazers) Infrastructure to implement and train NNPs diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index bef01088..4ee2dd7c 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -21,7 +21,6 @@ dependencies: - torchvision - openff-units - pint - - ase # Testing - pytest diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 4dc08bac..a7d2ca97 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -8,6 +8,8 @@ from loguru import logger from torch.utils.data import DataLoader +from modelforge.utils.prop import PropertyNames + from .transformation import default_transformation from .utils import RandomSplittingStrategy, SplittingStrategy @@ -20,34 +22,28 @@ class TorchDataset(torch.utils.data.Dataset): ---------- dataset : np.ndarray The underlying numpy dataset. - prop : List[str] - List of property names to extract from the dataset. + property_name : PropertyNames + Property names to extract from the dataset. preloaded : bool, optional If True, preconverts the properties to PyTorch tensors to save time during item fetching. Default is False. - Examples - -------- - >>> numpy_data = np.load("data_file.npz") - >>> properties = ["geometry", "atomic_numbers"] - >>> torch_dataset = TorchDataset(numpy_data, properties) - >>> data_point = torch_dataset[0] """ def __init__( self, dataset: np.ndarray, - prop: List[str], + property_name: PropertyNames, preloaded: bool = False, ): - self.properties_of_interest = [dataset[p] for p in prop] - self.length = len(dataset[prop[0]]) - self.preloaded = preloaded + self.properties_of_interest = { + "Z": dataset[property_name.Z], + "R": dataset[property_name.R], + "E": dataset[property_name.E], + } - if preloaded: - self.properties_of_interest = [ - torch.tensor(p) for p in self.properties_of_interest - ] + self.length = len(self.properties_of_interest["Z"]) + self.preloaded = preloaded def __len__(self) -> int: """ @@ -60,7 +56,7 @@ def __len__(self) -> int: """ return self.length - def __getitem__(self, idx: int) -> Tuple[torch.Tensor]: + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: """ Fetch a tuple of the values for the properties of interest for a given molecule index. @@ -71,20 +67,20 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor]: Returns ------- - Tuple[torch.Tensor] - Tuple containing tensors for properties of interest of the molecule. - - Examples - -------- - >>> data_point = torch_dataset[5] - >>> geometry, atomic_numbers = data_point + dict, contains: + - 'Z': torch.Tensor, shape [n_atoms] + Atomic numbers for each atom in the molecule. + - 'R': torch.Tensor, shape [n_atoms, 3] + Coordinates for each atom in the molecule. + - 'E': torch.Tensor, shape [] + Scalar energy value for the molecule. + - 'idx': int + Index of the molecule in the dataset. """ - if self.preloaded: - return tuple(prop[idx] for prop in self.properties_of_interest) - else: - return tuple( - torch.tensor(prop[idx]) for prop in self.properties_of_interest - ) + Z = torch.tensor(self.properties_of_interest["Z"][idx], dtype=torch.int64) + R = torch.tensor(self.properties_of_interest["R"][idx], dtype=torch.float32) + E = torch.tensor(self.properties_of_interest["E"][idx], dtype=torch.float32) + return {"Z": Z, "R": R, "E": E, "idx": idx} class HDF5Dataset: @@ -300,11 +296,10 @@ def create_dataset( logger.info(f"Creating {data.dataset_name} dataset") DatasetFactory._load_or_process_data(data, label_transform, transform) - return TorchDataset(data.numpy_data, data.properties_of_interest) + return TorchDataset(data.numpy_data, data._property_names) class TorchDataModule(pl.LightningDataModule): - """ A custom data module class to handle data loading and preparation for PyTorch Lightning training. @@ -346,7 +341,7 @@ def prepare_data(self) -> None: factory = DatasetFactory() self.dataset = factory.create_dataset(self.data) - def setup(self, stage: str): + def setup(self, stage: str) -> None: """ Splits the data into training, validation, and test sets based on the stage. @@ -365,11 +360,35 @@ def setup(self, stage: str): _, _, test_dataset = self.SplittingStrategy().split(self.dataset) self.test_dataset = test_dataset - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: + """ + Create a DataLoader for the training dataset. + + Returns + ------- + DataLoader + DataLoader containing the training dataset. + """ return DataLoader(self.train_dataset, batch_size=self.batch_size) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: + """ + Create a DataLoader for the validation dataset. + + Returns + ------- + DataLoader + DataLoader containing the validation dataset. + """ return DataLoader(self.val_dataset, batch_size=self.batch_size) - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: + """ + Create a DataLoader for the test dataset. + + Returns + ------- + DataLoader + DataLoader containing the test dataset. + """ return DataLoader(self.test_dataset, batch_size=self.batch_size) diff --git a/modelforge/dataset/qm9.py b/modelforge/dataset/qm9.py index b6fa872b..aa252987 100644 --- a/modelforge/dataset/qm9.py +++ b/modelforge/dataset/qm9.py @@ -25,6 +25,14 @@ class QM9Dataset(HDF5Dataset): >>> data._download() """ + from modelforge.utils import PropertyNames + + _property_names = PropertyNames( + "atomic_numbers", + "geometry", + "return_energy", + ) + _available_properties = [ "geometry", "atomic_numbers", diff --git a/modelforge/dataset/utils.py b/modelforge/dataset/utils.py index 71b49158..87ca3a2d 100644 --- a/modelforge/dataset/utils.py +++ b/modelforge/dataset/utils.py @@ -205,7 +205,7 @@ def pad_to_max_length(data: List[np.ndarray]) -> List[np.ndarray]: max_length = max(len(arr) for arr in data) return [ - np.pad(arr, (0, max_length - len(arr)), "constant", constant_values=-1) + np.pad(arr, (0, max_length - len(arr)), "constant", constant_values=0) for arr in data ] diff --git a/modelforge/potential/__initi__.py b/modelforge/potential/__initi__.py index 61327ce3..746b040a 100644 --- a/modelforge/potential/__initi__.py +++ b/modelforge/potential/__initi__.py @@ -1,2 +1,2 @@ from .schnet import Schnet -from .utils import Dense, GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add +from .utils import GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index cace0b71..deafc51f 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -1,67 +1,155 @@ -from typing import Dict, List, Optional +from typing import Dict +import lightning as pl import torch import torch.nn as nn +from torch.optim import AdamW -from modelforge.utils import Inputs, Properties, SpeciesEnergies +from modelforge.utils import SpeciesEnergies -class BaseNNP(nn.Module): +class PairList(nn.Module): + def __init__(self, cutoff: float = 5.0): + """ + Initialize the PairList class. + """ + super().__init__() + from .utils import neighbor_pairs_nopbc + + self.calculate_neighbors = neighbor_pairs_nopbc + self.cutoff = cutoff + + def compute_distance( + self, atom_index12: torch.Tensor, R: torch.Tensor + ) -> torch.Tensor: + """ + Compute distances based on atom indices and coordinates. + + Parameters + ---------- + atom_index12 : torch.Tensor, shape [n_pairs, 2] + Atom indices for pairs of atoms + R : torch.Tensor, shape [batch_size, n_atoms, n_dims] + Atom coordinates. + + Returns + ------- + torch.Tensor, shape [n_pairs] + Computed distances. + """ + + coordinates = R.flatten(0, 1) + selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view( + 2, -1, 3 + ) + vec = selected_coordinates[0] - selected_coordinates[1] + return vec.norm(2, -1) + + def forward(self, mask, R) -> Dict[str, torch.Tensor]: + atom_index12 = self.calculate_neighbors(mask, R, self.cutoff) + d_ij = self.compute_distance(atom_index12, R) + return {"atom_index12": atom_index12, "d_ij": d_ij} + + +class BaseNNP(pl.LightningModule): """ Abstract base class for neural network potentials. This class defines the overall structure and ensures that subclasses implement the `calculate_energies_and_forces` method. + + Methods + ------- + forward(inputs: dict) -> SpeciesEnergies: + Forward pass for the neural network potential. + calculate_energy(inputs: dict) -> torch.Tensor: + Placeholder for the method that should calculate energies and forces. + training_step(batch, batch_idx) -> torch.Tensor: + Defines the train loop. + configure_optimizers() -> AdamW: + Configures the optimizer. """ - def __init__(self, dtype: torch.dtype, device: torch.device): + def __init__(self): """ - Initialize the NeuralNetworkPotential class. - - Parameters - ---------- - dtype : torch.dtype - Data type for the PyTorch tensors. - device : torch.device - Device ("cpu" or "cuda") on which computations will be performed. - + Initialize the NNP class. """ super().__init__() - self.dtype = dtype - self.device = device - def forward( - self, - inputs: Inputs, - ) -> SpeciesEnergies: + def forward(self, inputs: Dict[str, torch.Tensor]) -> SpeciesEnergies: """ Forward pass for the neural network potential. Parameters ---------- - inputs : Inputs - An instance of the Inputs data class containing atomic numbers, positions, etc. + inputs : dict + A dictionary containing atomic numbers, positions, etc. Returns ------- SpeciesEnergies An instance of the SpeciesEnergies data class containing species and calculated energies. - """ + assert isinstance(inputs, Dict) # + E = self.calculate_energy(inputs) + return SpeciesEnergies(inputs["Z"], E) - E = self.calculate_energies_and_forces(inputs) - return SpeciesEnergies(inputs.Z, E) - - def calculate_energies_and_forces(self, inputs: Optional[Inputs] = None): + def calculate_energy(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: """ Placeholder for the method that should calculate energies and forces. This method should be implemented in subclasses. + Parameters + ---------- + inputs : dict + A dictionary containing atomic numbers, positions, etc. + + Returns + ------- + torch.Tensor + The calculated energy tensor. + + Raises ------ NotImplementedError If the method is not overridden in the subclass. - """ raise NotImplementedError("Subclasses must implement this method.") + def training_step( + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """ + Defines the training loop. + + Parameters + ---------- + batch : dict + Batch data. + batch_idx : int + Batch index. + + Returns + ------- + torch.Tensor + The loss tensor. + """ + + E_hat = self.forward(batch) # wrap_vals_from_dataloader(batch)) + loss = nn.functional.mse_loss(E_hat.energies, batch["E"]) + # Logging to TensorBoard (if installed) by default + self.log("train_loss", loss) + return loss + + def configure_optimizers(self) -> AdamW: + """ + Configures the optimizer for training. + + Returns + ------- + AdamW + The AdamW optimizer. + """ + optimizer = AdamW(self.parameters(), lr=1e-3) + return optimizer diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index 6705bf01..6748ee46 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -1,223 +1,258 @@ -from typing import Tuple -from loguru import logger - -import numpy as np -import torch import torch.nn as nn -from ase import Atoms -from ase.neighborlist import neighbor_list -from torch import dtype - -from modelforge.utils import Inputs +from loguru import logger +from typing import Dict, Tuple, List from .models import BaseNNP -from .utils import Dense, GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add +from .utils import ( + EnergyReadout, + GaussianRBF, + ShiftedSoftplus, + cosine_cutoff, + scatter_add, +) +import torch class Schnet(BaseNNP): - """ - Implementation of the SchNet architecture for quantum mechanical property prediction. - """ - def __init__( self, - n_atom_basis: int, # number of features per atom - n_interactions: int, # number of interaction blocks - n_filters: int = 0, # number of filters - dtype: dtype = torch.float32, - device: torch.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ), - ): + n_atom_basis: int, + n_interactions: int, + n_filters: int = 0, + cutoff: float = 5.0, + ) -> None: """ - Initialize the SchNet model. + Initialize the Schnet class. + Parameters ---------- n_atom_basis : int - Number of features per atom. + Number of atom basis, defines the dimensionality of the output features. n_interactions : int - Number of interaction blocks. + Number of interaction blocks in the architecture. n_filters : int, optional - Number of filters, defaults to None. - dtype : torch.dtype, optional - Data type for PyTorch tensors, defaults to torch.float32. - device : torch.device, optional - Device ("cpu" or "cuda") on which computations will be performed. - + Number of filters, defines the dimensionality of the intermediate features. + Default is 0. + cutoff : float, optional + Cutoff value for the pairlist. Default is 5.0. """ + from .models import PairList - super().__init__(dtype, device) + super().__init__() - # initialize atom embeddings - max_z: int = 100 # max nuclear charge (i.e. atomic number) - self.embedding = nn.Embedding(max_z, n_atom_basis, padding_idx=0) + self.calculate_distances_and_pairlist = PairList(cutoff) - # initialize radial basis functions and other constants - n_rbf = 20 - self.radial_basis = GaussianRBF(n_rbf=n_rbf, cutoff=5.0) - self.cutoff = 5.0 - self.activation = shifted_softplus - self.n_interactions = n_interactions - self.n_atom_basis = n_atom_basis - - # initialize dense yalers for atom feature transformation - # Dense layers are applied consecutively to the initialized atom embeddings x^{l}_{0} - # to generate x_i^l+1 = W^lx^l_i + b^l - self.intput_to_feature = Dense( - n_atom_basis, n_filters, bias=False, activation=None - ) - self.feature_to_output = nn.Sequential( - Dense(n_filters, n_atom_basis, activation=self.activation), - Dense(n_atom_basis, n_atom_basis, activation=None), + self.representation = SchNetRepresentation( + n_atom_basis, n_filters, n_interactions ) + self.readout = EnergyReadout(n_atom_basis) + self.embedding = nn.Embedding(100, n_atom_basis, padding_idx=-1) - # Initialize filter network - self.filter_network = nn.Sequential( - Dense(n_rbf, n_filters, activation=self.activation), - Dense(n_filters, n_filters), - ) - - def _setup_ase_system(self, inputs: Inputs) -> Atoms: + def calculate_energy( + self, inputs: Dict[str, torch.Tensor], cached_pairlist: bool = False + ) -> torch.Tensor: """ - Transform inputs to an ASE Atoms object. + Calculate the energy for a given input batch. Parameters ---------- - inputs : Inputs - Input features including atomic numbers and positions. - + inputs : dict, contains + - 'Z': torch.Tensor, shape [batch_size, n_atoms] + Atomic numbers for each atom in each molecule in the batch. + - 'R': torch.Tensor, shape [batch_size, n_atoms, 3] + Coordinates for each atom in each molecule in the batch. + cached_pairlist : bool, optional + Whether to use a cached pairlist. Default is False. NOTE: is this really needed? Returns ------- - ase.Atoms - Transformed ASE Atoms object. + torch.Tensor, shape [batch_size] + Calculated energies for each molecule in the batch. + + """ + # compute atom and pair features (see Fig1 in 10.1063/1.5019779) + # initializing x^{l}_{0} as x^l)0 = aZ_i + Z = inputs["Z"] + x = self.embedding(Z) + mask = Z == -1 + pairlist = self.calculate_distances_and_pairlist(mask, inputs["R"]) + + x = self.representation(x, pairlist) + # pool average over atoms + return self.readout(x) + + +def sequential_block(in_features: int, out_features: int): + """ + Create a sequential block for the neural network. + + Parameters + ---------- + in_features : int + Number of input features. + out_features : int + Number of output features. + + Returns + ------- + nn.Sequential + Sequential layer block. + """ + return nn.Sequential( + nn.Linear(in_features, out_features), + ShiftedSoftplus(), + nn.Linear(out_features, out_features), + ) + + +class SchNetInteractionBlock(nn.Module): + def __init__(self, n_atom_basis: int, n_filters: int): + """ + Initialize the SchNet interaction block. + + Parameters + ---------- + n_atom_basis : int + Number of atom basis, defines the dimensionality of the output features. + n_filters : int + Number of filters, defines the dimensionality of the intermediate features. """ - _atomic_numbers = torch.clone(inputs.Z) - atomic_numbers = list(_atomic_numbers.detach().cpu().numpy()) - positions = list(inputs.R.detach().cpu().numpy()) - ase_atoms = Atoms(numbers=atomic_numbers, positions=positions) - return ase_atoms - - def _compute_distances( - self, inputs: Inputs - ) -> Tuple[torch.Tensor, np.ndarray, np.ndarray]: + super().__init__() + n_rbf = 20 + self.intput_to_feature = nn.Linear(n_atom_basis, n_filters) + self.feature_to_output = sequential_block(n_filters, n_atom_basis) + self.filter_network = sequential_block(n_rbf, n_filters) + + def forward( + self, + x: torch.Tensor, + f_ij: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, + rcut_ij: torch.Tensor, + ) -> torch.Tensor: """ - Compute atomic distances using ASE's neighbor list. + Forward pass for the interaction block. Parameters ---------- - inputs : Inputs - Input features including atomic numbers and positions. + x : torch.Tensor, shape [batch_size, n_atoms, n_atom_basis] + Input feature tensor for atoms. + f_ij : torch.Tensor, shape [n_pairs, n_rbf] + Radial basis functions for pairs of atoms. + idx_i : torch.Tensor, shape [n_pairs] + Indices for the first atom in each pair. + idx_j : torch.Tensor, shape [n_pairs] + Indices for the second atom in each pair. + rcut_ij : torch.Tensor, shape [n_pairs] + Cutoff values for each pair. Returns ------- - torch.Tensor, np.ndarray, np.ndarray - Pairwise distances, index of atom i, and index of atom j. + torch.Tensor, shape [batch_size, n_atoms, n_atom_basis] + Updated feature tensor after interaction block. + """ + batch_size, nr_of_atoms = x.shape[0], x.shape[1] + + x = self.intput_to_feature(x) + x = x.flatten(0, 1) + + # Filter generation networks + Wij = self.filter_network(f_ij) + Wij = Wij * rcut_ij[:, None] + Wij = Wij.to(dtype=x.dtype) + + # continuous-filter convolutional layers + x_j = x[idx_j] + x_ij = x_j * Wij + x = scatter_add(x_ij, idx_i, dim_size=x.shape[0]) + # Update features + x = self.feature_to_output(x) + x = x.reshape(batch_size, nr_of_atoms, 128) + return x + +class SchNetRepresentation(nn.Module): + def __init__( + self, + n_atom_basis: int, + n_filters: int, + n_interactions: int, + ): """ + Initialize the SchNet representation layer. - ase_atoms = self._setup_ase_system(inputs) - idx_i, idx_j, _, r_ij = neighbor_list( - "ijSD", ase_atoms, 5.0, self_interaction=False + Parameters + ---------- + n_atom_basis : int + Number of atom basis. + n_filters : int + Number of filters. + n_interactions : int + Number of interaction layers. + """ + super().__init__() + + self.interactions = nn.ModuleList( + [ + SchNetInteractionBlock(n_atom_basis, n_filters) + for _ in range(n_interactions) + ] ) - r_ij = torch.from_numpy(r_ij) - return r_ij, idx_i, idx_j + self.cutoff = 5.0 + self.radial_basis = GaussianRBF(n_rbf=20, cutoff=self.cutoff) - def _distance_to_radial_basis(self, r_ij): + def _distance_to_radial_basis( + self, d_ij: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Transform distances to radial basis functions. + Convert distances to radial basis functions. Parameters ---------- - r_ij : torch.Tensor + d_ij : torch.Tensor, shape [n_pairs] Pairwise distances between atoms. Returns ------- - torch.Tensor, torch.Tensor - Radial basis functions and cutoff values. - + Tuple[torch.Tensor, torch.Tensor] + - Radial basis functions, shape [n_pairs, n_rbf] + - cutoff values, shape [n_pairs] """ - d_ij = torch.norm(r_ij, dim=1) # calculate pairwise distances f_ij = self.radial_basis(d_ij) rcut_ij = cosine_cutoff(d_ij, self.cutoff) return f_ij, rcut_ij - def _interaction_block(self, inputs: Inputs, f_ij, idx_i, idx_j, rcut_ij): + def forward( + self, x: torch.Tensor, pairlist: Dict[str, torch.Tensor] + ) -> torch.Tensor: """ - Compute the interaction block which updates atom features. + Forward pass for the representation layer. Parameters ---------- - inputs : Inputs - Input features including atomic numbers and positions. - f_ij : torch.Tensor - Radial basis functions. - idx_i : np.ndarray - Indices of center atoms. - idx_j : np.ndarray - Indices of neighboring atoms. - rcut_ij : torch.Tensor - Cutoff values for each pair of atoms. - + x : torch.Tensor, shape [batch_size, n_atoms, n_atom_basis] + Input feature tensor for atoms. + pairlist: Dict[str, torch.Tensor] + Pairlist dictionary containing the following keys: + - 'atom_index12': torch.Tensor, shape [n_pairs, 2] + Atom indices for pairs of atoms + - 'd_ij': torch.Tensor, shape [n_pairs] + Pairwise distances between atoms. Returns ------- - torch.Tensor - Updated atom features. - + torch.Tensor, shape [batch_size, n_atoms, n_atom_basis] + Output tensor after forward pass. """ + atom_index12 = pairlist["atom_index12"] + d_ij = pairlist["d_ij"] - # compute atom and pair features (see Fig1 in 10.1063/1.5019779) - # initializing x^{l}_{0} as x^l)0 = aZ_i - logger.debug("Embedding inputs.Z") - x_emb = self.embedding(inputs.Z) - logger.debug("After embedding: x.shape", x_emb.shape) - idx_i = torch.from_numpy(idx_i).to(self.device, torch.int64) - - # interaction blocks - for _ in range(self.n_interactions): - # atom wise update of features - logger.debug(f"Input to feature: x.shape {x_emb.shape}") - x = self.intput_to_feature(x_emb) - logger.debug("After input_to_feature call: x.shape {x.shape}") - - # Filter generation networks - Wij = self.filter_network(f_ij) - Wij = Wij * rcut_ij[:, None] - Wij = Wij.to(dtype=self.dtype) - - # continuous-filter convolutional layers - x_j = x[idx_j] - x_ij = x_j * Wij - logger.debug("After x_j * Wij: x_ij.shape {x_ij.shape}") - x = scatter_add(x_ij, idx_i, dim_size=x.shape[0]) - logger.debug("After scatter_add: x.shape {x.shape}") - # Update features - x = self.feature_to_output(x) - x_emb = x_emb + x - - return x_emb - - def calculate_energies_and_forces(self, inputs: Inputs) -> torch.Tensor: - """ - Compute energies and forces for given atomic configurations. + f_ij, rcut_ij = self._distance_to_radial_basis(d_ij) - Parameters - ---------- - inputs : Inputs - Input features including atomic numbers and positions. + idx_i, idx_j = atom_index12[0], atom_index12[1] + for interaction in self.interactions: + v = interaction(x, f_ij, idx_i, idx_j, rcut_ij) + x = x + v - Returns - ------- - torch.Tensor - Energies and forces for the given configurations. - - """ - logger.debug("Compute distances ...") - r_ij, idx_i, idx_j = self._compute_distances(inputs) - logger.debug("Convert distances to radial basis ...") - f_ij, rcut_ij = self._distance_to_radial_basis(r_ij) - logger.debug("Compute interaction block ...") - x = self._interaction_block(inputs, f_ij, idx_i, idx_j, rcut_ij) return x diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index ea5448a3..5ef1ffe5 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -3,15 +3,34 @@ import torch.nn as nn import numpy as np import torch.nn.functional as F +from loguru import logger def _scatter_add( - x: torch.Tensor, idx_i: torch.Tensor, dim_size: int, dim: int = 0 + src: torch.Tensor, index: torch.Tensor, dim_size: int, dim: int ) -> torch.Tensor: - shape = list(x.shape) + """ + Performs a scatter addition operation. + + Parameters + ---------- + src : torch.Tensor + Source tensor. + index : torch.Tensor + Index tensor. + dim_size : int + Dimension size. + dim : int + + Returns + ------- + torch.Tensor + The result of the scatter addition. + """ + shape = list(src.shape) shape[dim] = dim_size - tmp = torch.zeros(shape, dtype=x.dtype, device=x.device) - y = tmp.index_add(dim, idx_i, x) + tmp = torch.zeros(shape, dtype=src.dtype, device=src.device) + y = tmp.index_add(dim, index, src) return y @@ -34,69 +53,6 @@ def scatter_add( return _scatter_add(x, idx_i, dim_size, dim) -class Dense(nn.Linear): - r"""Fully connected linear layer with activation function. - - .. math:: - y = activation(x W^T + b) - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - activation: Union[Callable, nn.Module] = None, - dtype: torch.dtype = torch.float32, - device: torch.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ), - ): - """ - Initialize the Dense layer. - - Parameters - ---------- - in_features : int - Number of input features. - out_features : int - Number of output features. - bias : bool, optional - If False, the layer will not adapt bias. - activation : Callable or nn.Module, optional - Activation function, default is None (Identity). - dtype : torch.dtype, optional - Data type for PyTorch tensors. - device : torch.device, optional - Device ("cpu" or "cuda") on which computations will be performed. - """ - super().__init__(in_features, out_features, bias) - - # Initialize activation function - self.activation = activation if activation is not None else nn.Identity() - - # Initialize weight matrix - self.weight = nn.init.xavier_uniform_(self.weight).to(device, dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass through the layer. - - Parameters - ---------- - x : torch.Tensor - Input tensor. - - Returns - ------- - torch.Tensor - Transformed tensor. - """ - - y = F.linear(x, self.weight, self.bias) - y = self.activation(y) - return y - def gaussian_rbf( inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor @@ -125,31 +81,73 @@ def gaussian_rbf( return y.to(dtype=torch.float32) -def cosine_cutoff(input: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor: +def cosine_cutoff(d_ij: torch.Tensor, cutoff: float) -> torch.Tensor: """ - Behler-style cosine cutoff function. + Compute the cosine cutoff for a distance tensor. Parameters ---------- - inputs : torch.Tensor - Input tensor. - cutoff : torch.Tensor - Cutoff radius. - + d_ij : torch.Tensor + Pairwise distance tensor. + cutoff : float + Cutoff distance. Returns ------- torch.Tensor - Transformed tensor. + The cosine cutoff tensor. """ # Compute values of cutoff function - input_cut = 0.5 * (torch.cos(input * np.pi / cutoff) + 1.0) + input_cut = 0.5 * (torch.cos(d_ij * np.pi / cutoff) + 1.0) # Remove contributions beyond the cutoff radius - input_cut *= input < cutoff + input_cut *= d_ij < cutoff return input_cut -def shifted_softplus(x: torch.Tensor) -> torch.Tensor: +class EnergyReadout(nn.Module): + """ + Defines the energy readout module. + + Methods + ------- + forward(x: torch.Tensor) -> torch.Tensor: + Forward pass for the energy readout. + """ + + def __init__(self, n_atom_basis: int): + """ + Initialize the EnergyReadout class. + + Parameters + ---------- + n_atom_basis : int + Number of atom basis. + """ + super().__init__() + self.energy_layer = nn.Linear(n_atom_basis, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the energy readout. + + Parameters + ---------- + x : torch.Tensor, shape [batch, n_atoms, n_atom_basis] + Input tensor for the forward pass. + + Returns + ------- + torch.Tensor + The output tensor. + """ + x = self.energy_layer( + x + ) # in [batch, n_atoms, n_atom_basis], out [batch, n_atoms, 1] + total_energy = x.sum(dim=1) # in [batch, n_atoms, 1], out [batch, 1] + return total_energy + + +class ShiftedSoftplus(nn.Module): """ Compute shifted soft-plus activation function. @@ -163,73 +161,111 @@ def shifted_softplus(x: torch.Tensor) -> torch.Tensor: torch.Tensor Transformed tensor. """ - return nn.functional.softplus(x) - np.log(2.0) + + def __init__(self): + super().__init__() + + def forward(self, x): + return nn.functional.softplus(x) - np.log(2.0) class GaussianRBF(nn.Module): """ - Gaussian radial basis functions (RBF). + Gaussian Radial Basis Function module. + + Methods + ------- + forward(x: torch.Tensor) -> torch.Tensor: + Forward pass for the GaussianRBF. """ def __init__( self, n_rbf: int, cutoff: float, - start: float = 0.0, - trainable: bool = False, - device: torch.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ), - dtype: torch.dtype = torch.float32, ): """ - Initialize Gaussian RBF layer. + Initialize the GaussianRBF class. Parameters ---------- n_rbf : int Number of radial basis functions. cutoff : float - Cutoff distance for RBF. - start : float, optional - Starting distance for RBF, defaults to 0.0. - trainable : bool, optional - If True, widths and offsets are trainable parameters. - device : torch.device, optional - Device ("cpu" or "cuda") on which computations will be performed. - dtype : torch.dtype, optional - Data type for PyTorch tensors. - + The cutoff distance. """ super().__init__() self.n_rbf = n_rbf # compute offset and width of Gaussian functions - offset = torch.linspace(start, cutoff, n_rbf, dtype=dtype, device=device) + offset = torch.linspace(0, cutoff, n_rbf) widths = torch.tensor( torch.abs(offset[1] - offset[0]) * torch.ones_like(offset), - device=device, - dtype=dtype, ) - if trainable: - self.widths = nn.Parameter(widths) - self.offsets = nn.Parameter(offset) - else: - self.register_buffer("widths", widths) - self.register_buffer("offsets", offset) + self.register_buffer("widths", widths) + self.register_buffer("offsets", offset) def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ - Forward pass through the layer. + Forward pass for the GaussianRBF. Parameters ---------- - inputs : torch.Tensor - Input tensor. + x : torch.Tensor + Input tensor for the forward pass. Returns ------- torch.Tensor - Transformed tensor. + The output tensor. """ return gaussian_rbf(inputs, self.offsets, self.widths) + + +# taken from torchani repository: https://github.com/aiqm/torchani +def neighbor_pairs_nopbc( + mask: torch.Tensor, R: torch.Tensor, cutoff: float +) -> torch.Tensor: + """ + Calculate neighbor pairs without periodic boundary conditions. + Parameters + ---------- + mask : torch.Tensor + Mask tensor to indicate invalid atoms, shape (batch_size, n_atoms). + R : torch.Tensor + Coordinates tensor, shape (batch_size, n_atoms, 3). + cutoff : float + Cutoff distance for neighbors. + + Returns + ------- + torch.Tensor + Tensor containing indices of neighbor pairs, shape (n_pairs, 2). + + Notes + ----- + This function assumes no periodic boundary conditions and calculates neighbor pairs based solely on the cutoff distance. + + Examples + -------- + >>> mask = torch.tensor([[0, 0, 1], [1, 0, 0]]) + >>> R = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],[[3.0, 3.0, 3.0], [4.0, 4.0, 4.0], [5.0, 5.0, 5.0]]]) + >>> cutoff = 1.5 + >>> neighbor_pairs_nopbc(mask, R, cutoff) + """ + import math + + R = R.detach().masked_fill(mask.unsqueeze(-1), math.nan) + current_device = R.device + num_atoms = mask.shape[1] + num_mols = mask.shape[0] + p12_all = torch.triu_indices(num_atoms, num_atoms, 1, device=current_device) + p12_all_flattened = p12_all.view(-1) + + pair_coordinates = R.index_select(1, p12_all_flattened).view(num_mols, 2, -1, 3) + distances = (pair_coordinates[:, 0, ...] - pair_coordinates[:, 1, ...]).norm(2, -1) + in_cutoff = (distances <= cutoff).nonzero() + molecule_index, pair_index = in_cutoff.unbind(1) + molecule_index *= num_atoms + atom_index12 = p12_all[:, pair_index] + molecule_index + return atom_index12 diff --git a/modelforge/tests/helper_functinos.py b/modelforge/tests/helper_functinos.py index 217cccd4..8de7a0be 100644 --- a/modelforge/tests/helper_functinos.py +++ b/modelforge/tests/helper_functinos.py @@ -2,32 +2,50 @@ from modelforge.dataset.dataset import TorchDataModule from modelforge.dataset.qm9 import QM9Dataset +from modelforge.potential.schnet import Schnet from modelforge.utils import Inputs +from modelforge.potential.models import BaseNNP +from typing import Optional -def default_input(): - train_loader = initialize_dataloader() - R, Z, E = train_loader.dataset[0] - padded_values = -Z.eq(-1).sum().item() - Z_ = Z[:padded_values] - R_ = R[:padded_values] - return Inputs(Z_, R_, E) +MODELS_TO_TEST = [Schnet] +DATASETS = [QM9Dataset] -def initialize_dataloader() -> torch.utils.data.DataLoader: - data = QM9Dataset(for_unit_testing=True) +def setup_simple_model(model_class) -> Optional[BaseNNP]: + if model_class is Schnet: + return Schnet(n_atom_basis=128, n_interactions=3, n_filters=64) + else: + raise NotImplementedError + + +def return_single_batch(dataset, mode: str): + train_loader = initialize_dataset(dataset, mode) + for batch in train_loader.train_dataloader(): + return batch + + +def initialize_dataset(dataset, mode: str) -> TorchDataModule: + data = dataset(for_unit_testing=True) data_module = TorchDataModule(data) data_module.prepare_data() - data_module.setup("fit") - return data_module.train_dataloader() - - -methane_coordinates = torch.tensor( - [ - [0.0, 0.0, 0.0], - [0.63918859, 0.63918859, 0.63918859], - [-0.63918859, -0.63918859, 0.63918859], - [-0.63918859, 0.63918859, -0.63918859], - [0.63918859, -0.63918859, -0.63918859], - ] -) + data_module.setup(mode) + return data_module + + +def methane_input(): + Z = torch.tensor([[6, 1, 1, 1, 1]], dtype=torch.int64) + R = torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [0.63918859, 0.63918859, 0.63918859], + [-0.63918859, -0.63918859, 0.63918859], + [-0.63918859, 0.63918859, -0.63918859], + [0.63918859, -0.63918859, -0.63918859], + ] + ] + ) + E = torch.tensor([0.0]) + return {"Z": Z, "R": R, "E": E} + diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index 74b28283..17254511 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -6,10 +6,10 @@ import torch from torch.utils.data import DataLoader -from modelforge.dataset.dataset import DatasetFactory, TorchDataset +from modelforge.dataset.dataset import DatasetFactory, TorchDataModule, TorchDataset from modelforge.dataset.qm9 import QM9Dataset -DATASETS = [QM9Dataset] +from .helper_functinos import initialize_dataset, DATASETS @pytest.fixture( @@ -42,7 +42,7 @@ def _cleanup(): _cleanup() -def generate_dataset(dataset) -> TorchDataset: +def generate_torch_dataset(dataset) -> TorchDataset: factory = DatasetFactory() data = dataset(for_unit_testing=True) return factory.create_dataset(data) @@ -69,8 +69,8 @@ def test_different_properties_of_interest(dataset): dataset = factory.create_dataset(data) raw_data_item = dataset[0] - assert isinstance(raw_data_item, tuple) - assert len(raw_data_item) == 3 + assert isinstance(raw_data_item, dict) + assert len(raw_data_item) == 4 data.properties_of_interest = ["return_energy", "geometry"] assert data.properties_of_interest == [ @@ -81,8 +81,8 @@ def test_different_properties_of_interest(dataset): dataset = factory.create_dataset(data) raw_data_item = dataset[0] print(raw_data_item) - assert isinstance(raw_data_item, tuple) - assert len(raw_data_item) == 2 + assert isinstance(raw_data_item, dict) + assert len(raw_data_item) != 3 @pytest.mark.parametrize("dataset", DATASETS) @@ -119,14 +119,18 @@ def test_different_scenarios_of_file_availability(dataset): @pytest.mark.parametrize("dataset", DATASETS) def test_data_item_format(dataset): """Test the format of individual data items in the dataset.""" - dataset = generate_dataset(dataset) + from typing import Dict - raw_data_item = dataset[0] - assert isinstance(raw_data_item, tuple) - assert len(raw_data_item) == 3 - assert isinstance(raw_data_item[0], torch.Tensor) - assert isinstance(raw_data_item[1], torch.Tensor) - assert isinstance(raw_data_item[2], torch.Tensor) + dataset = initialize_dataset(dataset, mode="fit") + + raw_data_item = dataset.dataset[0] + assert isinstance(raw_data_item, Dict) + assert isinstance(raw_data_item["Z"], torch.Tensor) + assert isinstance(raw_data_item["R"], torch.Tensor) + assert isinstance(raw_data_item["E"], torch.Tensor) + print(raw_data_item) + + assert raw_data_item["Z"].shape[0] == raw_data_item["R"].shape[0] def test_padding(): @@ -150,12 +154,23 @@ def test_dataset_generation(dataset): """Test the splitting of the dataset.""" from modelforge.dataset.utils import RandomSplittingStrategy - dataset = generate_dataset(dataset) - train_dataset, val_dataset, test_dataset = RandomSplittingStrategy().split(dataset) + dataset = initialize_dataset(dataset, mode="fit") + train_dataloader = dataset.train_dataloader() + val_dataloader = dataset.val_dataloader() - assert len(train_dataset) == 80 - assert len(test_dataset) == 10 - assert len(val_dataset) == 10 + try: + dataset.test_dataloader() + except AttributeError: + # this isn't set when dataset is in 'fit' mode + pass + + # the dataloader automatically splits and batches the dataset + # for the trianing set it batches the 80 datapoints in + # a batch of 64 and a batch of 16 samples + assert len(train_dataloader) == 2 # nr of batches + v = [v_ for v_ in train_dataloader] + assert len(v[0]["Z"]) == 64 + assert len(v[1]["Z"]) == 16 @pytest.mark.parametrize("dataset", DATASETS) @@ -163,10 +178,10 @@ def test_dataset_splitting(dataset): """Test random_split on the the dataset.""" from modelforge.dataset.utils import RandomSplittingStrategy - dataset = generate_dataset(dataset) + dataset = generate_torch_dataset(dataset) train_dataset, val_dataset, test_dataset = RandomSplittingStrategy().split(dataset) - energy = train_dataset[0][2].item() + energy = train_dataset[0]["E"].item() assert np.isclose(energy, -157.09958704371914) print(energy) @@ -192,7 +207,7 @@ def test_file_cache_methods(dataset): # generate files to test _from_hdf5() from modelforge.dataset.transformation import default_transformation - _ = generate_dataset(dataset) + _ = initialize_dataset(dataset, mode="str") data = dataset(for_unit_testing=True) @@ -226,22 +241,3 @@ def test_numpy_dataset_assignment(dataset): assert hasattr(data, "numpy_data") assert isinstance(data.numpy_data, np.lib.npyio.NpzFile) - - -@pytest.mark.parametrize("dataset", DATASETS) -def test_dataset_dataloaders(dataset): - """ - Test if the data loaders return the expected batch sizes. - """ - from modelforge.dataset.utils import RandomSplittingStrategy - from torch.utils.data import DataLoader - - dataset = generate_dataset(dataset) - train_dataset, val_dataset, test_dataset = RandomSplittingStrategy().split(dataset) - train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True) - - for batch in train_dataloader: - assert len(batch) == 3 # coordinates, atomic_numbers, return_energy - assert ( - batch[0].size(0) == 64 or batch[0].size(0) == 16 - ) # default batch size (last batch has sieze 32) diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 25e3ed68..59a0f3d8 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -1,30 +1,25 @@ -from typing import Optional - import pytest -import torch from modelforge.potential.models import BaseNNP -from modelforge.potential.schnet import Schnet -from .helper_functinos import default_input - -MODELS_TO_TEST = [Schnet] - -def setup_simple_model(model_class) -> Optional[BaseNNP]: - if model_class is Schnet: - return Schnet(n_atom_basis=128, n_interactions=3, n_filters=64) - else: - raise NotImplementedError +from .helper_functinos import ( + DATASETS, + MODELS_TO_TEST, + return_single_batch, + setup_simple_model, +) def test_BaseNNP(): - nnp = BaseNNP(dtype=torch.float32, device="cpu") - assert nnp.dtype == torch.float32 - assert str(nnp.device) == "cpu" + nnp = BaseNNP() @pytest.mark.parametrize("model_class", MODELS_TO_TEST) -def test_forward_pass(model_class): +@pytest.mark.parametrize("dataset", DATASETS) +def test_forward_pass(model_class, dataset): initialized_model = setup_simple_model(model_class) - inputs = default_input() - output = initialized_model.forward(inputs) + inputs = return_single_batch(dataset, mode="fit") + output = initialized_model(inputs) + print(output.energies.shape) + assert output.species.shape[0] == inputs["Z"].shape[0] + assert output.energies.shape[0] == inputs["Z"].shape[0] diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py index dd064c4e..13c77a5b 100644 --- a/modelforge/tests/test_schnet.py +++ b/modelforge/tests/test_schnet.py @@ -1,12 +1,29 @@ -import torch +from loguru import logger + from modelforge.potential.schnet import Schnet -from modelforge.utils import Inputs + +from .helper_functinos import methane_input +import torch def test_Schnet_init(): schnet = Schnet(128, 6, 2) - assert schnet.n_atom_basis == 128 - assert schnet.n_interactions == 6 + assert schnet is not None + + +def test_schnet_forward(): + model = Schnet(128, 3) + inputs = { + "Z": torch.tensor([[1, 2], [2, 3]]), + "R": torch.tensor( + [[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]] + ), + } + energy = model.calculate_energy(inputs) + assert energy.shape == ( + 2, + 1, + ) # Assuming energy is calculated per sample in the batch def test_calculate_energies_and_forces(): @@ -15,8 +32,7 @@ def test_calculate_energies_and_forces(): # energy and force calculatino on Methane schnet = Schnet(128, 6, 64) - Z = torch.tensor([1, 8], dtype=torch.int64) - R = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.96]], dtype=torch.float32) - inputs = Inputs(Z, R, torch.tensor([100])) - result = schnet.calculate_energies_and_forces(inputs) - assert result.shape[1] == 128 # Assuming n_atom_basis is 128 + methane_inputs = methane_input() + result = schnet.calculate_energy(methane_inputs) + logger.debug(result) + assert result.shape[0] == 1 # Assuming only one molecule diff --git a/modelforge/tests/test_training.py b/modelforge/tests/test_training.py new file mode 100644 index 00000000..a274a9d4 --- /dev/null +++ b/modelforge/tests/test_training.py @@ -0,0 +1,26 @@ +from typing import Optional + +import pytest + +from modelforge.potential.models import BaseNNP +from modelforge.potential.schnet import Schnet + +from .helper_functinos import ( + MODELS_TO_TEST, + DATASETS, + initialize_dataset, + setup_simple_model, +) + + +@pytest.mark.parametrize("model_class", MODELS_TO_TEST) +@pytest.mark.parametrize("dataset", DATASETS) +def test_forward_pass(dataset, model_class): + from lightning import Trainer + import torch + + model = setup_simple_model(model_class) + dataset = initialize_dataset(dataset, mode="fit") + trainer = Trainer(max_epochs=2) + model = model.to(torch.float32) + trainer.fit(model, dataset.train_dataloader(), dataset.val_dataloader()) diff --git a/modelforge/tests/test_utils.py b/modelforge/tests/test_utils.py index cea5d45a..b209fc60 100644 --- a/modelforge/tests/test_utils.py +++ b/modelforge/tests/test_utils.py @@ -1,7 +1,8 @@ import numpy as np -from modelforge.potential.utils import scatter_add, Dense, GaussianRBF +from modelforge.potential.utils import scatter_add, GaussianRBF import torch + def test_scatter_add(): x = torch.tensor([1, 4, 3, 2], dtype=torch.float32) idx_i = torch.tensor([0, 2, 2, 1], dtype=torch.int64) @@ -9,13 +10,6 @@ def test_scatter_add(): assert torch.equal(result, torch.tensor([1.0, 2.0, 7.0])) -def test_Dense(): - layer = Dense(2, 3) - x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) - y = layer(x) - assert y.shape == (2, 3) - - def test_GaussianRBF(): layer = GaussianRBF(10, 5.0) x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) diff --git a/modelforge/utils/__init__.py b/modelforge/utils/__init__.py index d7489235..81b63b83 100644 --- a/modelforge/utils/__init__.py +++ b/modelforge/utils/__init__.py @@ -1,3 +1,3 @@ """modelforge utilities.""" -from .prop import Properties, Inputs, SpeciesEnergies +from .prop import PropertyNames, Inputs, SpeciesEnergies diff --git a/modelforge/utils/prop.py b/modelforge/utils/prop.py index 30f4e428..66cbafcc 100644 --- a/modelforge/utils/prop.py +++ b/modelforge/utils/prop.py @@ -5,9 +5,10 @@ @dataclass -class Properties: - Z: str = "atomic_numbers" - R: str = "positions" +class PropertyNames: + Z: str + R: str + E: str @dataclass