
Commit

Merge remote-tracking branch 'origin/qm9dataset-update' into qm9dataset-update
chrisiacovella committed Sep 14, 2023
2 parents e8bf0dd + a2d65dc commit ef7d93d
Showing 17 changed files with 674 additions and 440 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -1,9 +1,12 @@
modelforge
==============================
[//]: # (Badges)
[![GitHub Actions Build Status](https://github.com/REPLACE_WITH_OWNER_ACCOUNT/modelforge/workflows/CI/badge.svg)](https://github.com/REPLACE_WITH_OWNER_ACCOUNT/modelforge/actions?query=workflow%3ACI)
[![codecov](https://codecov.io/gh/REPLACE_WITH_OWNER_ACCOUNT/modelforge/branch/main/graph/badge.svg)](https://codecov.io/gh/REPLACE_WITH_OWNER_ACCOUNT/modelforge/branch/main)

[![GitHub Actions Build Status](https://github.com/choderalab/modelforge/workflows/CI/badge.svg)](https://github.com/choderalab/modelforge/actions?query=workflow%3ACI)
[![codecov](https://codecov.io/gh/choderalab/modelforge/branch/main/graph/badge.svg)](https://codecov.io/gh/choderalab/modelforge/branch/main)
[![Github release](https://badgen.net/github/release/choderalab/modelforge)](https://github.com/choderalab/modelforge/)
[![GitHub license](https://img.shields.io/github/license/choderalab/modelforge?color=green)](https://github.com/choderalab/modelforge/blob/main/LICENSE)
[![GitHub issues](https://img.shields.io/github/issues/choderalab/modelforge?style=flat)](https://github.com/choderalab/modelforge/issues)
[![GitHub stars](https://img.shields.io/github/stars/choderalab/modelforge)](https://github.com/choderalab/modelforge/stargazers)

Infrastructure to implement and train NNPs

1 change: 0 additions & 1 deletion devtools/conda-envs/test_env.yaml
@@ -21,7 +21,6 @@ dependencies:
- torchvision
- openff-units
- pint
- ase

# Testing
- pytest
91 changes: 55 additions & 36 deletions modelforge/dataset/dataset.py
@@ -8,6 +8,8 @@
from loguru import logger
from torch.utils.data import DataLoader

from modelforge.utils.prop import PropertyNames

from .transformation import default_transformation
from .utils import RandomSplittingStrategy, SplittingStrategy

@@ -20,34 +22,28 @@ class TorchDataset(torch.utils.data.Dataset):
----------
dataset : np.ndarray
The underlying numpy dataset.
prop : List[str]
List of property names to extract from the dataset.
property_name : PropertyNames
Property names to extract from the dataset.
preloaded : bool, optional
If True, preconverts the properties to PyTorch tensors to save time during item fetching.
Default is False.
Examples
--------
>>> numpy_data = np.load("data_file.npz")
>>> properties = ["geometry", "atomic_numbers"]
>>> torch_dataset = TorchDataset(numpy_data, properties)
>>> data_point = torch_dataset[0]
"""

def __init__(
self,
dataset: np.ndarray,
prop: List[str],
property_name: PropertyNames,
preloaded: bool = False,
):
self.properties_of_interest = [dataset[p] for p in prop]
self.length = len(dataset[prop[0]])
self.preloaded = preloaded
self.properties_of_interest = {
"Z": dataset[property_name.Z],
"R": dataset[property_name.R],
"E": dataset[property_name.E],
}

if preloaded:
self.properties_of_interest = [
torch.tensor(p) for p in self.properties_of_interest
]
self.length = len(self.properties_of_interest["Z"])
self.preloaded = preloaded

def __len__(self) -> int:
"""
@@ -60,7 +56,7 @@ def __len__(self) -> int:
"""
return self.length

def __getitem__(self, idx: int) -> Tuple[torch.Tensor]:
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
Fetch the values of the properties of interest for a given molecule index.
@@ -71,20 +67,20 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor]:
Returns
-------
Tuple[torch.Tensor]
Tuple containing tensors for properties of interest of the molecule.
Examples
--------
>>> data_point = torch_dataset[5]
>>> geometry, atomic_numbers = data_point
dict, contains:
- 'Z': torch.Tensor, shape [n_atoms]
Atomic numbers for each atom in the molecule.
- 'R': torch.Tensor, shape [n_atoms, 3]
Coordinates for each atom in the molecule.
- 'E': torch.Tensor, shape []
Scalar energy value for the molecule.
- 'idx': int
Index of the molecule in the dataset.
"""
if self.preloaded:
return tuple(prop[idx] for prop in self.properties_of_interest)
else:
return tuple(
torch.tensor(prop[idx]) for prop in self.properties_of_interest
)
Z = torch.tensor(self.properties_of_interest["Z"][idx], dtype=torch.int64)
R = torch.tensor(self.properties_of_interest["R"][idx], dtype=torch.float32)
E = torch.tensor(self.properties_of_interest["E"][idx], dtype=torch.float32)
return {"Z": Z, "R": R, "E": E, "idx": idx}


class HDF5Dataset:
@@ -300,11 +296,10 @@ def create_dataset(

logger.info(f"Creating {data.dataset_name} dataset")
DatasetFactory._load_or_process_data(data, label_transform, transform)
return TorchDataset(data.numpy_data, data.properties_of_interest)
return TorchDataset(data.numpy_data, data._property_names)


class TorchDataModule(pl.LightningDataModule):

"""
A custom data module class to handle data loading and preparation for PyTorch Lightning training.
@@ -346,7 +341,7 @@ def prepare_data(self) -> None:
factory = DatasetFactory()
self.dataset = factory.create_dataset(self.data)

def setup(self, stage: str):
def setup(self, stage: str) -> None:
"""
Splits the data into training, validation, and test sets based on the stage.
@@ -365,11 +360,35 @@ def setup(self, stage: str):
_, _, test_dataset = self.SplittingStrategy().split(self.dataset)
self.test_dataset = test_dataset

def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
"""
Create a DataLoader for the training dataset.
Returns
-------
DataLoader
DataLoader containing the training dataset.
"""
return DataLoader(self.train_dataset, batch_size=self.batch_size)

def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
"""
Create a DataLoader for the validation dataset.
Returns
-------
DataLoader
DataLoader containing the validation dataset.
"""
return DataLoader(self.val_dataset, batch_size=self.batch_size)

def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
"""
Create a DataLoader for the test dataset.
Returns
-------
DataLoader
DataLoader containing the test dataset.
"""
return DataLoader(self.test_dataset, batch_size=self.batch_size)
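
A sketch of how these hooks are driven in a typical Lightning workflow; the constructor arguments for both classes are assumptions, since only the hook methods appear in this diff:

```python
# Assumed wiring for the hooks defined above; both constructor signatures
# are guesses, as only the hook methods are shown in this diff.
from modelforge.dataset.dataset import TorchDataModule
from modelforge.dataset.qm9 import QM9Dataset

dm = TorchDataModule(QM9Dataset())  # hypothetical constructor calls
dm.prepare_data()                   # builds the underlying TorchDataset
dm.setup(stage="fit")               # populates the train/val splits
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
```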
8 changes: 8 additions & 0 deletions modelforge/dataset/qm9.py
@@ -25,6 +25,14 @@ class QM9Dataset(HDF5Dataset):
>>> data._download()
"""

from modelforge.utils import PropertyNames

_property_names = PropertyNames(
"atomic_numbers",
"geometry",
"return_energy",
)

_available_properties = [
"geometry",
"atomic_numbers",
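
`PropertyNames` itself is not shown in this diff; judging from the positional construction above and the `.Z`/`.R`/`.E` attribute access in `dataset.py`, it behaves like a simple record. A sketch of the assumed shape:

```python
# Assumed shape of PropertyNames, inferred from its usage in this commit;
# the real definition lives in modelforge.utils.prop and may differ.
from dataclasses import dataclass

@dataclass
class PropertyNames:
    Z: str  # dataset key for atomic numbers, e.g. "atomic_numbers"
    R: str  # dataset key for coordinates, e.g. "geometry"
    E: str  # dataset key for energies, e.g. "return_energy"
```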
2 changes: 1 addition & 1 deletion modelforge/dataset/utils.py
@@ -205,7 +205,7 @@ def pad_to_max_length(data: List[np.ndarray]) -> List[np.ndarray]:
max_length = max(len(arr) for arr in data)

return [
np.pad(arr, (0, max_length - len(arr)), "constant", constant_values=-1)
np.pad(arr, (0, max_length - len(arr)), "constant", constant_values=0)
for arr in data
]

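
The behavioral change here is only the fill value (-1 to 0); a small illustration of the padding semantics:

```python
# Illustration of pad_to_max_length semantics with the new fill value.
import numpy as np

data = [np.array([1, 6, 8]), np.array([1, 1])]
max_length = max(len(arr) for arr in data)
padded = [
    np.pad(arr, (0, max_length - len(arr)), "constant", constant_values=0)
    for arr in data
]
# padded == [array([1, 6, 8]), array([1, 1, 0])]; before this commit the
# second entry would have been array([1, 1, -1]).
```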
2 changes: 1 addition & 1 deletion modelforge/potential/__initi__.py
@@ -1,2 +1,2 @@
from .schnet import Schnet
from .utils import Dense, GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add
from .utils import GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add
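
After this change `Dense` is no longer re-exported; downstream code would import the surviving helpers, all of which are named in the diff line above:

```python
# The re-exports that remain after this change, imported from the
# utils module directly.
from modelforge.potential.utils import (
    GaussianRBF,
    cosine_cutoff,
    shifted_softplus,
    scatter_add,
)
```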
142 changes: 115 additions & 27 deletions modelforge/potential/models.py
@@ -1,67 +1,155 @@
from typing import Dict, List, Optional
from typing import Dict

import lightning as pl
import torch
import torch.nn as nn
from torch.optim import AdamW

from modelforge.utils import Inputs, Properties, SpeciesEnergies
from modelforge.utils import SpeciesEnergies


class BaseNNP(nn.Module):
class PairList(nn.Module):
def __init__(self, cutoff: float = 5.0):
"""
Initialize the PairList class.
"""
super().__init__()
from .utils import neighbor_pairs_nopbc

self.calculate_neighbors = neighbor_pairs_nopbc
self.cutoff = cutoff

def compute_distance(
self, atom_index12: torch.Tensor, R: torch.Tensor
) -> torch.Tensor:
"""
Compute distances based on atom indices and coordinates.
Parameters
----------
atom_index12 : torch.Tensor, shape [2, n_pairs]
Atom indices for pairs of atoms.
R : torch.Tensor, shape [batch_size, n_atoms, n_dims]
Atom coordinates.
Returns
-------
torch.Tensor, shape [n_pairs]
Computed distances.
"""

coordinates = R.flatten(0, 1)
selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view(
2, -1, 3
)
vec = selected_coordinates[0] - selected_coordinates[1]
return vec.norm(2, -1)

def forward(self, mask, R) -> Dict[str, torch.Tensor]:
atom_index12 = self.calculate_neighbors(mask, R, self.cutoff)
d_ij = self.compute_distance(atom_index12, R)
return {"atom_index12": atom_index12, "d_ij": d_ij}


class BaseNNP(pl.LightningModule):
"""
Abstract base class for neural network potentials.
This class defines the overall structure and ensures that subclasses
implement the `calculate_energy` method.
Methods
-------
forward(inputs: dict) -> SpeciesEnergies:
Forward pass for the neural network potential.
calculate_energy(inputs: dict) -> torch.Tensor:
Placeholder for the method that should calculate energies and forces.
training_step(batch, batch_idx) -> torch.Tensor:
Defines the train loop.
configure_optimizers() -> AdamW:
Configures the optimizer.
"""

def __init__(self, dtype: torch.dtype, device: torch.device):
def __init__(self):
"""
Initialize the NeuralNetworkPotential class.
Parameters
----------
dtype : torch.dtype
Data type for the PyTorch tensors.
device : torch.device
Device ("cpu" or "cuda") on which computations will be performed.
Initialize the NNP class.
"""
super().__init__()
self.dtype = dtype
self.device = device

def forward(
self,
inputs: Inputs,
) -> SpeciesEnergies:
def forward(self, inputs: Dict[str, torch.Tensor]) -> SpeciesEnergies:
"""
Forward pass for the neural network potential.
Parameters
----------
inputs : Inputs
An instance of the Inputs data class containing atomic numbers, positions, etc.
inputs : dict
A dictionary containing atomic numbers, positions, etc.
Returns
-------
SpeciesEnergies
An instance of the SpeciesEnergies data class containing species and calculated energies.
"""
assert isinstance(inputs, Dict)
E = self.calculate_energy(inputs)
return SpeciesEnergies(inputs["Z"], E)

E = self.calculate_energies_and_forces(inputs)
return SpeciesEnergies(inputs.Z, E)

def calculate_energies_and_forces(self, inputs: Optional[Inputs] = None):
def calculate_energy(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Placeholder for the method that should calculate energies and forces.
This method should be implemented in subclasses.
Parameters
----------
inputs : dict
A dictionary containing atomic numbers, positions, etc.
Returns
-------
torch.Tensor
The calculated energy tensor.
Raises
------
NotImplementedError
If the method is not overridden in the subclass.
"""
raise NotImplementedError("Subclasses must implement this method.")

def training_step(
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> torch.Tensor:
"""
Defines the training loop.
Parameters
----------
batch : dict
Batch data.
batch_idx : int
Batch index.
Returns
-------
torch.Tensor
The loss tensor.
"""

E_hat = self.forward(batch) # wrap_vals_from_dataloader(batch))
loss = nn.functional.mse_loss(E_hat.energies, batch["E"])
# Logging to TensorBoard (if installed) by default
self.log("train_loss", loss)
return loss

def configure_optimizers(self) -> AdamW:
"""
Configures the optimizer for training.
Returns
-------
AdamW
The AdamW optimizer.
"""

optimizer = AdamW(self.parameters(), lr=1e-3)
return optimizer
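
A minimal, hypothetical subclass showing the contract `BaseNNP` now expects: implement `calculate_energy` and the Lightning plumbing above does the rest. The toy embedding model is a deliberate stand-in, not the SchNet potential in this repository:

```python
# Hypothetical BaseNNP subclass; the embedding "model" is a trivial
# placeholder, not the SchNet implementation in this repository.
import torch
import torch.nn as nn
from modelforge.potential.models import BaseNNP

class ToyNNP(BaseNNP):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(100, 8)  # atomic number -> features
        self.readout = nn.Linear(8, 1)         # features -> atomic energy

    def calculate_energy(self, inputs):
        feats = self.embedding(inputs["Z"])         # [batch, n_atoms, 8]
        atomic_E = self.readout(feats).squeeze(-1)  # [batch, n_atoms]
        return atomic_E.sum(dim=-1)                 # [batch]

model = ToyNNP()
batch = {
    "Z": torch.randint(1, 10, (4, 5)),  # 4 molecules, 5 atoms each
    "R": torch.randn(4, 5, 3),
    "E": torch.randn(4),
}
out = model(batch)  # SpeciesEnergies(species=batch["Z"], energies of shape [4])
```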