Refactoring Schnet and preparing datasets/models for PyTorch Lightning compatibility (#8)

* first outline of base NNP class

* reference schnet implementation

* LightningModule

* training loop

* pad with 0 for embedding

* updating pooling for batches


* updating docstrings
wiederm authored Sep 11, 2023
1 parent 202c372 commit f38e367
Showing 16 changed files with 647 additions and 426 deletions.
1 change: 0 additions & 1 deletion devtools/conda-envs/test_env.yaml
@@ -21,7 +21,6 @@ dependencies:
   - torchvision
   - openff-units
   - pint
-  - ase
 
   # Testing
   - pytest
88 changes: 53 additions & 35 deletions modelforge/dataset/dataset.py
@@ -8,6 +8,8 @@
 from loguru import logger
 from torch.utils.data import DataLoader
 
+from modelforge.utils.prop import PropertyNames
+
 from .transformation import default_transformation
 from .utils import RandomSplittingStrategy, SplittingStrategy

@@ -20,34 +22,28 @@ class TorchDataset(torch.utils.data.Dataset):
----------
dataset : np.ndarray
The underlying numpy dataset.
prop : List[str]
List of property names to extract from the dataset.
property_name : PropertyNames
Property names to extract from the dataset.
preloaded : bool, optional
If True, preconverts the properties to PyTorch tensors to save time during item fetching.
Default is False.
Examples
--------
>>> numpy_data = np.load("data_file.npz")
>>> properties = ["geometry", "atomic_numbers"]
>>> torch_dataset = TorchDataset(numpy_data, properties)
>>> data_point = torch_dataset[0]
"""

def __init__(
self,
dataset: np.ndarray,
prop: List[str],
property_name: PropertyNames,
preloaded: bool = False,
):
self.properties_of_interest = [dataset[p] for p in prop]
self.length = len(dataset[prop[0]])
self.preloaded = preloaded
self.properties_of_interest = {
"Z": dataset[property_name.Z],
"R": dataset[property_name.R],
"E": dataset[property_name.E],
}

if preloaded:
self.properties_of_interest = [
torch.tensor(p) for p in self.properties_of_interest
]
self.length = len(self.properties_of_interest["Z"])
self.preloaded = preloaded

def __len__(self) -> int:
"""
@@ -60,7 +56,7 @@ def __len__(self) -> int:
"""
return self.length

def __getitem__(self, idx: int) -> Tuple[torch.Tensor]:
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
Fetch a dictionary of the properties of interest for a given molecule index.
@@ -71,20 +67,19 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor]:
Returns
-------
Tuple[torch.Tensor]
Tuple containing tensors for properties of interest of the molecule.
dict, contains:
- 'Z': torch.Tensor, shape [n_atoms]
Atomic numbers for each atom in the molecule.
- 'R': torch.Tensor, shape [n_atoms, 3]
Coordinates for each atom in the molecule.
- 'E': torch.Tensor, shape []
Scalar energy value for the molecule.
Examples
--------
>>> data_point = torch_dataset[5]
>>> Z, R, E = data_point["Z"], data_point["R"], data_point["E"]
"""
if self.preloaded:
return tuple(prop[idx] for prop in self.properties_of_interest)
else:
return tuple(
torch.tensor(prop[idx]) for prop in self.properties_of_interest
)
Z = torch.tensor(self.properties_of_interest["Z"][idx], dtype=torch.int64)
R = torch.tensor(self.properties_of_interest["R"][idx], dtype=torch.float32)
E = torch.tensor(self.properties_of_interest["E"][idx], dtype=torch.float32)
return {"Z": Z, "R": R, "E": E}


class HDF5Dataset:
@@ -290,285 +285,10 @@ def create_dataset(

logger.info(f"Creating {data.dataset_name} dataset")
DatasetFactory._load_or_process_data(data, label_transform, transform)
return TorchDataset(data.numpy_data, data.properties_of_interest)
return TorchDataset(data.numpy_data, data._property_names)


class TorchDataModule(pl.LightningDataModule):

"""
A custom data module class to handle data loading and preparation for PyTorch Lightning training.
@@ -336,7 +330,7 @@ def prepare_data(self) -> None:
factory = DatasetFactory()
self.dataset = factory.create_dataset(self.data)

def setup(self, stage: str):
def setup(self, stage: str) -> None:
"""
Splits the data into training, validation, and test sets based on the stage.
@@ -355,11 +349,35 @@ def setup(self, stage: str):
_, _, test_dataset = self.SplittingStrategy().split(self.dataset)
self.test_dataset = test_dataset

def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
"""
Create a DataLoader for the training dataset.
Returns
-------
DataLoader
DataLoader containing the training dataset.
"""
return DataLoader(self.train_dataset, batch_size=self.batch_size)

def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
"""
Create a DataLoader for the validation dataset.
Returns
-------
DataLoader
DataLoader containing the validation dataset.
"""
return DataLoader(self.val_dataset, batch_size=self.batch_size)

def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
"""
Create a DataLoader for the test dataset.
Returns
-------
DataLoader
DataLoader containing the test dataset.
"""
return DataLoader(self.test_dataset, batch_size=self.batch_size)
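
With the three dataloader methods in place, the module plugs directly into a Lightning `Trainer`. A rough usage sketch, assuming default constructor arguments for `QM9Dataset` and `TorchDataModule` (neither signature is shown in full in this diff) and some `BaseNNP` subclass as `model`:

```python
import lightning as pl
from modelforge.dataset.dataset import TorchDataModule
from modelforge.dataset.qm9 import QM9Dataset

data = QM9Dataset()         # assumed default constructor
dm = TorchDataModule(data)  # assumed defaults for batch size and splitting strategy
dm.prepare_data()
dm.setup(stage="fit")

# `model` stands in for any BaseNNP subclass (see modelforge/potential/models.py below).
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, dm.train_dataloader(), dm.val_dataloader())
```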
8 changes: 8 additions & 0 deletions modelforge/dataset/qm9.py
@@ -25,6 +25,14 @@ class QM9Dataset(HDF5Dataset):
     >>> data._download()
     """
 
+    from modelforge.utils import PropertyNames
+
+    _property_names = PropertyNames(
+        "atomic_numbers",
+        "geometry",
+        "return_energy",
+    )
+
     _available_properties = [
         "geometry",
         "atomic_numbers",
2 changes: 1 addition & 1 deletion modelforge/dataset/utils.py
@@ -205,7 +205,7 @@ def pad_to_max_length(data: List[np.ndarray]) -> List[np.ndarray]:
     max_length = max(len(arr) for arr in data)
 
     return [
-        np.pad(arr, (0, max_length - len(arr)), "constant", constant_values=-1)
+        np.pad(arr, (0, max_length - len(arr)), "constant", constant_values=0)
         for arr in data
     ]
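
Padding with 0 instead of -1 matters because the padded atomic numbers feed an embedding lookup: -1 is not a valid `nn.Embedding` index, while 0 can be reserved as a dedicated padding slot. A short illustration; the `padding_idx=0` choice is an assumption about how the embedding is meant to consume these padded arrays:

```python
import torch
import torch.nn as nn

# Two molecules padded to the same length; 0 means "no atom".
Z_padded = torch.tensor([[1, 1, 8],
                         [6, 1, 0]])  # second molecule has only two real atoms

embedding = nn.Embedding(num_embeddings=100, embedding_dim=8, padding_idx=0)
atom_features = embedding(Z_padded)           # shape [2, 3, 8]
assert torch.all(atom_features[1, 2] == 0.0)  # padded slot stays a zero vector
```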

2 changes: 1 addition & 1 deletion modelforge/potential/__initi__.py
@@ -1,2 +1,2 @@
 from .schnet import Schnet
-from .utils import Dense, GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add
+from .utils import GaussianRBF, cosine_cutoff, shifted_softplus, scatter_add
99 changes: 72 additions & 27 deletions modelforge/potential/models.py
@@ -1,67 +1,112 @@
from typing import Dict, List, Optional
from typing import Dict

import lightning as pl
import torch
import torch.nn as nn
from torch.optim import AdamW

from modelforge.utils import Inputs, Properties, SpeciesEnergies
from modelforge.utils import SpeciesEnergies


class BaseNNP(nn.Module):
class BaseNNP(pl.LightningModule):
"""
Abstract base class for neural network potentials.
This class defines the overall structure and ensures that subclasses
implement the `calculate_energies_and_forces` method.
Methods
-------
forward(inputs: dict) -> SpeciesEnergies:
Forward pass for the neural network potential.
calculate_energy(inputs: dict) -> torch.Tensor:
Placeholder for the method that should calculate the energy.
training_step(batch, batch_idx) -> torch.Tensor:
Defines the train loop.
configure_optimizers() -> AdamW:
Configures the optimizer.
"""

def __init__(self, dtype: torch.dtype, device: torch.device):
def __init__(self):
"""
Initialize the NeuralNetworkPotential class.
Parameters
----------
dtype : torch.dtype
Data type for the PyTorch tensors.
device : torch.device
Device ("cpu" or "cuda") on which computations will be performed.
Initialize the NNP class.
"""
super().__init__()
self.dtype = dtype
self.device = device

def forward(
self,
inputs: Inputs,
) -> SpeciesEnergies:
def forward(self, inputs: Dict[str, torch.Tensor]) -> SpeciesEnergies:
"""
Forward pass for the neural network potential.
Parameters
----------
inputs : Inputs
An instance of the Inputs data class containing atomic numbers, positions, etc.
inputs : dict
A dictionary containing atomic numbers, positions, etc.
Returns
-------
SpeciesEnergies
An instance of the SpeciesEnergies data class containing species and calculated energies.
"""
assert isinstance(inputs, dict)
E = self.calculate_energy(inputs)
return SpeciesEnergies(inputs["Z"], E)

E = self.calculate_energies_and_forces(inputs)
return SpeciesEnergies(inputs.Z, E)

def calculate_energies_and_forces(self, inputs: Optional[Inputs] = None):
def calculate_energy(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Placeholder for the method that should calculate the energy.
This method should be implemented in subclasses.
Parameters
----------
inputs : dict
A dictionary containing atomic numbers, positions, etc.
Returns
-------
torch.Tensor
The calculated energy tensor.
Raises
------
NotImplementedError
If the method is not overridden in the subclass.
"""
raise NotImplementedError("Subclasses must implement this method.")

def training_step(
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> torch.Tensor:
"""
Defines the training loop.
Parameters
----------
batch : dict
Batch data.
batch_idx : int
Batch index.
Returns
-------
torch.Tensor
The loss tensor.
"""

E_hat = self.forward(batch)
loss = nn.functional.mse_loss(E_hat.energies, batch["E"])
# Logging to TensorBoard (if installed) by default
self.log("train_loss", loss)
return loss

def configure_optimizers(self) -> AdamW:
"""
Configures the optimizer for training.
Returns
-------
AdamW
The AdamW optimizer.
"""

optimizer = AdamW(self.parameters(), lr=1e-3)
return optimizer
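
Because `BaseNNP` now inherits `training_step` and `configure_optimizers` from `LightningModule`, a concrete potential only has to implement `calculate_energy`. A toy sketch; the per-element energy table below is purely illustrative and not the Schnet implementation:

```python
import torch
import torch.nn as nn
from modelforge.potential.models import BaseNNP

class ToyNNP(BaseNNP):
    """Hypothetical subclass: energy as a sum of learned per-element terms."""

    def __init__(self):
        super().__init__()
        self.atomic_energy = nn.Embedding(100, 1, padding_idx=0)  # index 0 = padding

    def calculate_energy(self, inputs: dict) -> torch.Tensor:
        per_atom = self.atomic_energy(inputs["Z"])  # [batch, n_atoms, 1]
        return per_atom.sum(dim=1).squeeze(-1)      # one energy per molecule

batch = {
    "Z": torch.tensor([[1, 1, 8]]),
    "R": torch.zeros(1, 3, 3),
    "E": torch.tensor([-76.4]),
}
model = ToyNNP()
species_energies = model(batch)  # SpeciesEnergies(Z, energies)
```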