
Commit

Merge pull request #243 from choderalab/ref-training
Refactoring changes in the training routine
wiederm authored Aug 31, 2024
2 parents 5551bea + 4ffc4d5 commit 769d6a8
Showing 20 changed files with 492 additions and 240 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -190,3 +190,7 @@ lightning_logs/
*.hdf5
*/tb_logs/*
.vscode/settings.json
logs/*
cache/*
*/logs/*
*/cache/*
2 changes: 1 addition & 1 deletion docs/getting_started.rst
@@ -149,7 +149,7 @@ Here is an example of a training routine definition:
remove_self_energies = true # Whether to remove self-energies from the dataset
batch_size = 128 # Number of samples per batch
lr = 1e-3 # Learning rate for the optimizer
monitor = "val/per_molecule_energy/rmse" # Metric to monitor for early stopping and checkpointing
monitor_for_checkpoint = "val/per_molecule_energy/rmse" # Metric to monitor for checkpointing
[training.experiment_logger]
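The renamed key separates the checkpointing monitor from the monitors used for early stopping and the learning-rate scheduler. As a rough illustration (not the repository's actual wiring), a value like this could be handed to a PyTorch Lightning ModelCheckpoint callback:

# Hypothetical consumer of monitor_for_checkpoint; only the metric string comes
# from the config above, the callback wiring is assumed.
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="val/per_molecule_energy/rmse",  # value of training.monitor_for_checkpoint
    mode="min",    # lower RMSE is better
    save_top_k=1,  # keep only the best checkpoint
)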
68 changes: 51 additions & 17 deletions modelforge/dataset/dataset.py
@@ -4,7 +4,7 @@

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, NamedTuple
from typing import TYPE_CHECKING, Dict, List, Literal, NamedTuple, Optional, Union

import numpy as np
import pytorch_lightning as pl
@@ -15,15 +15,15 @@

from modelforge.dataset.utils import RandomRecordSplittingStrategy, SplittingStrategy
from modelforge.utils.prop import PropertyNames
from modelforge.utils.misc import lock_with_attribute

if TYPE_CHECKING:
from modelforge.potential.processing import AtomicSelfEnergies


from pydantic import BaseModel, field_validator, ConfigDict, Field

from enum import Enum

from pydantic import BaseModel, ConfigDict, Field


class CaseInsensitiveEnum(str, Enum):
@classmethod
@@ -64,6 +64,7 @@ class DatasetParameters(BaseModel):
version_select: str
num_workers: int = Field(gt=0)
pin_memory: bool
regenerate_processed_cache: bool = False


@dataclass(frozen=False)
@@ -208,8 +209,9 @@ def as_jax_namedtuple(self) -> NamedTuple:
"""Export the dataclass fields and values as a named tuple.
Convert pytorch tensors to jax arrays."""

from dataclasses import dataclass, fields
import collections
from dataclasses import dataclass, fields

from modelforge.utils.io import import_

convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax
@@ -237,6 +239,9 @@ def to(
self.metadata = self.metadata.to(device=device, dtype=dtype)
return self

def batch_size(self):
return self.metadata.E.size(dim=0)


class TorchDataset(torch.utils.data.Dataset[BatchData]):
"""
@@ -1041,7 +1046,6 @@ def create_dataset(
return TorchDataset(data.numpy_data, data._property_names)


from torch import nn
from openff.units import unit


@@ -1067,6 +1071,7 @@ def __init__(
local_cache_dir: str = "./",
regenerate_cache: bool = False,
regenerate_dataset_statistic: bool = False,
regenerate_processed_cache: bool = True,
):
"""
Initializes a DataModule for PyTorch Lightning that handles data preparation and loading with the specified configuration.
@@ -1100,6 +1105,9 @@ def __init__(
regenerate_cache : bool, defaults to False
Whether to regenerate the cache.
"""
from modelforge.potential.models import Pairlist
import os

super().__init__()

self.name = name
@@ -1116,36 +1124,62 @@ def __init__(
self.train_dataset = None
self.test_dataset = None
self.val_dataset = None
import os

# make sure we can handle a path with a ~ in it
self.local_cache_dir = os.path.expanduser(local_cache_dir)
# create the local cache directory if it does not exist
os.makedirs(self.local_cache_dir, exist_ok=True)
self.regenerate_cache = regenerate_cache
from modelforge.potential.models import Pairlist
# Use a logical OR to ensure regenerate_processed_cache is True when
# regenerate_cache is True
self.regenerate_processed_cache = (
regenerate_processed_cache or self.regenerate_cache
)

self.pairlist = Pairlist()
self.dataset_statistic_filename = (
f"{self.local_cache_dir}/{self.name}_dataset_statistic.toml"
)
self.cache_processed_dataset_filename = (
f"{self.local_cache_dir}/{self.name}_{self.version_select}_processed.pt"
)
self.lock_file = f"{self.cache_processed_dataset_filename}.lockfile"

@lock_with_attribute("lock_file")
def prepare_data(
self,
) -> None:
"""
Prepares the dataset for use. This method is responsible for the initial processing of the data such as calculating self energies, atomic energy statistics, and splitting. It is executed only once per node.
Prepares the dataset for use. This method is responsible for the initial
processing of the data such as calculating self energies, atomic energy
statistics, and splitting. It is executed only once per node.
"""
# check if there is a filelock present; if so, wait until it is removed

# if the dataset has already been processed, skip this step
if (
os.path.exists(self.cache_processed_dataset_filename)
and not self.regenerate_processed_cache
):
if not os.path.exists(self.dataset_statistic_filename):
raise FileNotFoundError(
f"Dataset statistics file {self.dataset_statistic_filename} not found. Please regenerate the cache."
)
log.info('Processed dataset already exists. Skipping "prepare_data" step.')
return None

# if the dataset is not already processed, process it

from modelforge.dataset import _ImplementedDatasets
import toml

dataset_class = _ImplementedDatasets.get_dataset_class(self.name)
dataset_class = _ImplementedDatasets.get_dataset_class(str(self.name))
dataset = dataset_class(
force_download=self.force_download,
version_select=self.version_select,
local_cache_dir=self.local_cache_dir,
regenerate_cache=self.regenerate_cache,
)
torch_dataset = self._create_torch_dataset(dataset)

# if dataset statistics is present load it from disk
if (
os.path.exists(self.dataset_statistic_filename)
@@ -1283,16 +1317,16 @@ def _calculate_atomic_self_energies(

def _cache_dataset(self, torch_dataset):
"""Cache the dataset and its statistics using PyTorch's serialization."""
torch.save(torch_dataset, "torch_dataset.pt")
# sleep for 1 second to make sure that the dataset was written to disk
torch.save(torch_dataset, self.cache_processed_dataset_filename)
# sleep for 5 seconds to make sure that the dataset was written to disk
import time

time.sleep(1)
time.sleep(5)

def setup(self, stage: Optional[str] = None) -> None:
"""Sets up datasets for the train, validation, and test stages based on the stage argument."""

self.torch_dataset = torch.load("torch_dataset.pt")
self.torch_dataset = torch.load(self.cache_processed_dataset_filename)
(
self.train_dataset,
self.val_dataset,
@@ -1342,7 +1376,7 @@ def _per_datapoint_operations(
from tqdm import tqdm

# remove the self energies if requested
log.info("Precalculating pairlist for dataset")
log.info("Performing per datapoint operations in the dataset dataset")
if self.remove_self_energies:
log.info("Removing self energies from the dataset")

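prepare_data is now guarded by @lock_with_attribute("lock_file"), so that when several processes share a cache directory only one of them regenerates the processed dataset while the others wait on the lock file. The actual decorator lives in modelforge.utils.misc; a minimal sketch of how such a decorator could work, assuming it is built on the third-party filelock package, is:

# Sketch only: the real lock_with_attribute implementation may differ.
import functools
from filelock import FileLock  # third-party "filelock" package (assumed dependency)

def lock_with_attribute(attribute_name):
    """Serialize a method across processes using a lock file whose path is stored
    on the instance, e.g. self.lock_file for @lock_with_attribute("lock_file")."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # block until the lock file can be acquired, then run the method
            with FileLock(getattr(self, attribute_name)):
                return func(self, *args, **kwargs)
        return wrapper
    return decorator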
28 changes: 4 additions & 24 deletions modelforge/tests/conftest.py
@@ -96,35 +96,15 @@ def single_batch(batch_size: int = 64, dataset_name="QM9"):


@pytest.fixture(scope="session")
def single_batch_with_batchsize_64():
def single_batch_with_batchsize():
"""
Utility fixture to create a single batch of data for testing.
"""
return single_batch(batch_size=64)

def _create_single_batch(batch_size: int, dataset_name: str):
return single_batch(batch_size=batch_size, dataset_name=dataset_name)

@pytest.fixture(scope="session")
def single_batch_with_batchsize_1():
"""
Utility fixture to create a single batch of data for testing.
"""
return single_batch(batch_size=1)


@pytest.fixture(scope="session")
def single_batch_with_batchsize_2_with_force():
"""
Utility fixture to create a single batch of data for testing.
"""
return single_batch(batch_size=2, dataset_name="PHALKETHOH")


@pytest.fixture(scope="session")
def single_batch_with_batchsize_16_with_force():
"""
Utility fixture to create a single batch of data for testing.
"""
return single_batch(batch_size=16, dataset_name="PHALKETHOH")
return _create_single_batch


def initialize_dataset(
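The batch-size-specific fixtures are collapsed into a single factory fixture that returns a callable, so each test requests the batch size and dataset it needs. A hypothetical test using the new fixture might look like:

# Illustrative only; the dataset name and assertion are assumptions.
def test_single_batch_factory(single_batch_with_batchsize):
    batch = single_batch_with_batchsize(2, "QM9")
    assert batch.batch_size() == 2  # batch_size() was added to BatchData in this commit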
2 changes: 1 addition & 1 deletion modelforge/tests/data/config.toml
@@ -35,7 +35,7 @@ number_of_epochs = 2
remove_self_energies = true
batch_size = 128
lr = 1e-3
monitor = "val/per_molecule_energy/rmse"
monitor_for_checkpoint = "val/per_molecule_energy/rmse"

[training.experiment_logger]
logger_name = "tensorboard"
8 changes: 4 additions & 4 deletions modelforge/tests/data/training_defaults/default.toml
@@ -3,7 +3,7 @@ number_of_epochs = 2
remove_self_energies = true
batch_size = 128
lr = 1e-3
monitor = "val/per_molecule_energy/rmse"
monitor_for_checkpoint = "val/per_molecule_energy/rmse"


[training.experiment_logger]
@@ -35,11 +35,11 @@ monitor = "val/per_molecule_energy/rmse"
interval = "epoch"

[training.loss_parameter]
loss_property = ['per_molecule_energy', 'per_atom_force'] # use
loss_property = ['per_molecule_energy'] #, 'per_atom_force'] # use

[training.loss_parameter.weight]
per_molecule_energy = 0.999 #NOTE: reciprocal units
per_atom_force = 0.001
per_molecule_energy = 1.0 #NOTE: reciprocal units
#per_atom_force = 0.001


[training.early_stopping]
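The default configuration now trains on energies only; the force term and its weight are commented out rather than removed, and the energy weight is set to 1.0. Conceptually, the weights combine the per-property loss terms into a single scalar, roughly as in this sketch (the helper and property names are illustrative):

# Sketch of how [training.loss_parameter.weight] entries could be combined.
def combine_losses(per_property_losses: dict, weights: dict):
    # per_property_losses -> {"per_molecule_energy": <tensor>, ...}
    # weights             -> {"per_molecule_energy": 1.0}  (force term disabled here)
    return sum(weights[name] * per_property_losses[name] for name in weights)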
54 changes: 54 additions & 0 deletions modelforge/tests/data/training_defaults/default_with_force.toml
@@ -0,0 +1,54 @@
[training]
number_of_epochs = 2
remove_self_energies = true
batch_size = 128
lr = 1e-3
monitor_for_checkpoint = "val/per_molecule_energy/rmse"


[training.experiment_logger]
logger_name = "tensorboard" # this will set which logger to use

# configuration for both loggers can be defined simultaneously; the logger_name variable defines which logger to use
[training.experiment_logger.tensorboard_configuration]
save_dir = "logs"

[training.experiment_logger.wandb_configuration]
save_dir = "logs"
project = "training_potentials"
group = "exp00"
log_model = true
job_type = "testing"
tags = ["v_0.1.0"]
notes = "testing training"

[training.lr_scheduler]
frequency = 1
mode = "min"
factor = 0.1
patience = 10
cooldown = 5
min_lr = 1e-8
threshold = 0.1
threshold_mode = "abs"
monitor = "val/per_molecule_energy/rmse"
interval = "epoch"

[training.loss_parameter]
loss_property = ['per_molecule_energy', 'per_atom_force'] # use

[training.loss_parameter.weight]
per_molecule_energy = 0.999 #NOTE: reciprocal units
per_atom_force = 0.001


[training.early_stopping]
verbose = true
monitor = "val/per_molecule_energy/rmse"
min_delta = 0.001
patience = 50

[training.splitting_strategy]
name = "random_record_splitting_strategy"
data_split = [0.8, 0.1, 0.1]
seed = 42
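
Most keys in [training.lr_scheduler] mirror the arguments of PyTorch's ReduceLROnPlateau; monitor, frequency, and interval are consumed on the Lightning side rather than passed to the scheduler. A sketch of how the block above could be translated, with a placeholder model and optimizer, is:

# Illustrative mapping of the TOML scheduler block onto torch.optim.lr_scheduler;
# the model and optimizer are stand-ins, not modelforge objects.
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # lr from [training]
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=10,
    cooldown=5,
    min_lr=1e-8,
    threshold=0.1,
    threshold_mode="abs",
)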
8 changes: 5 additions & 3 deletions modelforge/tests/test_dataset.py
@@ -458,10 +458,11 @@ def test_data_item_format_of_datamodule(
@pytest.mark.parametrize(
"potential_name", _Implemented_NNPs.get_all_neural_network_names()
)
def test_dataset_neighborlist(potential_name, single_batch_with_batchsize_64):
def test_dataset_neighborlist(potential_name, single_batch_with_batchsize):
"""Test the neighborlist."""

nnp_input = single_batch_with_batchsize_64.nnp_input
batch = single_batch_with_batchsize(64, "QM9")
nnp_input = batch.nnp_input

# test that the neighborlist is correctly generated
from modelforge.tests.test_models import load_configs_into_pydantic_models
@@ -713,7 +714,8 @@ def test_numpy_dataset_assignment(dataset_name):


def test_energy_postprocessing():
# setup test dataset
# test that the mean and stddev of the dataset
# are correct
from modelforge.dataset.dataset import DataModule

# test the self energy calculation on the QM9 dataset