Refactoring changes in the training routine #243

Merged
Merged 37 commits on Aug 31, 2024
Changes from 36 commits
Commits (37)
389ff8f
small modifications
wiederm Aug 21, 2024
5b75fed
add batchsize
wiederm Aug 21, 2024
2876384
update
wiederm Aug 22, 2024
b152bdf
Merge branch 'main' into ref-training
wiederm Aug 22, 2024
ba118eb
update optimizer, fix bug
wiederm Aug 22, 2024
0bd48b4
Merge branch 'ref-training' of https://github.com/choderalab/modelfor…
wiederm Aug 22, 2024
ae7ee91
use custom torchmetric to sync across different nodes
wiederm Aug 22, 2024
0d61ccb
make logging consistent and log also loss with torchmetric (necessary…
wiederm Aug 23, 2024
a62029c
skip if ANI and SPICE
wiederm Aug 23, 2024
ac9e853
include force training test
wiederm Aug 23, 2024
6ae7ca0
still issues with multiple GPUs
wiederm Aug 23, 2024
ebba646
make loss tensor's stride contiguous
wiederm Aug 23, 2024
a315724
sync log
wiederm Aug 23, 2024
c5b87be
stride is an issue in the backward pass through the forces, this m mi…
wiederm Aug 24, 2024
8da97be
still stride issue
wiederm Aug 24, 2024
e4ce8d6
add stride hook for backward
wiederm Aug 24, 2024
3f087ab
dict as module output for stride
wiederm Aug 24, 2024
5072098
avoid saving grad in val/test routine
wiederm Aug 25, 2024
5895aca
only linting changes
wiederm Aug 25, 2024
701122d
fix loss test
wiederm Aug 25, 2024
42dc20a
fix tests
wiederm Aug 26, 2024
14ed13f
Merge branch 'main' into ref-training
wiederm Aug 27, 2024
2c0b643
decorator for method locking
wiederm Aug 27, 2024
992b7bc
lock
wiederm Aug 27, 2024
b5f9888
typo
wiederm Aug 27, 2024
acd06fe
correct lock file mode
wiederm Aug 27, 2024
62003b7
linting
wiederm Aug 27, 2024
e483551
fix test failures
wiederm Aug 27, 2024
c2cfb5f
reasonable defaults for regeneration
wiederm Aug 27, 2024
a1daf2c
ha, didn't think about that
wiederm Aug 27, 2024
2bc8c3c
update parameter name
wiederm Aug 27, 2024
d3539e5
detach metric
wiederm Aug 29, 2024
7c374a5
Having test_train_lightning skip sake+forces when on github CI becaus…
chrisiacovella Aug 31, 2024
2ca2b0a
fixed capitalization error, test should now skip
chrisiacovella Aug 31, 2024
daaedff
in testing I removed the check to see if in github actions. This is r…
chrisiacovella Aug 31, 2024
34a7631
Merge branch 'main' into ref-training
wiederm Aug 31, 2024
4ffc4d5
update weights
wiederm Aug 31, 2024
4 changes: 4 additions & 0 deletions .gitignore
Expand Up @@ -190,3 +190,7 @@ lightning_logs/
*.hdf5
*/tb_logs/*
.vscode/settings.json
logs/*
cache/*
*/logs/*
*/cache/*
2 changes: 1 addition & 1 deletion docs/getting_started.rst
Expand Up @@ -149,7 +149,7 @@ Here is an example of a training routine definition:
remove_self_energies = true # Whether to remove self-energies from the dataset
batch_size = 128 # Number of samples per batch
lr = 1e-3 # Learning rate for the optimizer
monitor = "val/per_molecule_energy/rmse" # Metric to monitor for early stopping and checkpointing
monitor_for_checkpoint = "val/per_molecule_energy/rmse" # Metric to monitor for checkpointing


[training.experiment_logger]
Expand Down
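
The rename from monitor to monitor_for_checkpoint separates the metric used for checkpointing from the monitors configured for early stopping and the LR scheduler. As a rough, hypothetical sketch (not the actual modelforge wiring), a key like this is typically read from the training TOML and handed to PyTorch Lightning's ModelCheckpoint callback; the config path below is illustrative only:

import toml
from pytorch_lightning.callbacks import ModelCheckpoint

# read the training section of the config (path is illustrative)
config = toml.load("config.toml")
monitor_key = config["training"]["monitor_for_checkpoint"]  # e.g. "val/per_molecule_energy/rmse"

# keep the single best checkpoint according to the monitored metric
checkpoint_callback = ModelCheckpoint(monitor=monitor_key, mode="min", save_top_k=1)
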
68 changes: 51 additions & 17 deletions modelforge/dataset/dataset.py
Expand Up @@ -4,7 +4,7 @@

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, NamedTuple
from typing import TYPE_CHECKING, Dict, List, Literal, NamedTuple, Optional, Union

import numpy as np
import pytorch_lightning as pl
Expand All @@ -15,15 +15,15 @@

from modelforge.dataset.utils import RandomRecordSplittingStrategy, SplittingStrategy
from modelforge.utils.prop import PropertyNames
from modelforge.utils.misc import lock_with_attribute

if TYPE_CHECKING:
from modelforge.potential.processing import AtomicSelfEnergies


from pydantic import BaseModel, field_validator, ConfigDict, Field

from enum import Enum

from pydantic import BaseModel, ConfigDict, Field


class CaseInsensitiveEnum(str, Enum):
@classmethod
Expand Down Expand Up @@ -64,6 +64,7 @@ class DatasetParameters(BaseModel):
version_select: str
num_workers: int = Field(gt=0)
pin_memory: bool
regenerate_processed_cache: bool = False


@dataclass(frozen=False)
Expand Down Expand Up @@ -208,8 +209,9 @@ def as_jax_namedtuple(self) -> NamedTuple:
"""Export the dataclass fields and values as a named tuple.
Convert pytorch tensors to jax arrays."""

from dataclasses import dataclass, fields
import collections
from dataclasses import dataclass, fields

from modelforge.utils.io import import_

convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax
Expand Down Expand Up @@ -237,6 +239,9 @@ def to(
self.metadata = self.metadata.to(device=device, dtype=dtype)
return self

def batch_size(self):
return self.metadata.E.size(dim=0)


class TorchDataset(torch.utils.data.Dataset[BatchData]):
"""
Expand Down Expand Up @@ -1041,7 +1046,6 @@ def create_dataset(
return TorchDataset(data.numpy_data, data._property_names)


from torch import nn
from openff.units import unit


Expand All @@ -1067,6 +1071,7 @@ def __init__(
local_cache_dir: str = "./",
regenerate_cache: bool = False,
regenerate_dataset_statistic: bool = False,
regenerate_processed_cache: bool = True,
):
"""
Initializes a data module for PyTorch Lightning, handling data preparation and loading with the specified configuration.
Expand Down Expand Up @@ -1100,6 +1105,9 @@ def __init__(
regenerate_cache : bool, defaults to False
Whether to regenerate the cache.
"""
from modelforge.potential.models import Pairlist
import os

super().__init__()

self.name = name
Expand All @@ -1116,36 +1124,62 @@ def __init__(
self.train_dataset = None
self.test_dataset = None
self.val_dataset = None
import os

# make sure we can handle a path with a ~ in it
self.local_cache_dir = os.path.expanduser(local_cache_dir)
# create the local cache directory if it does not exist
os.makedirs(self.local_cache_dir, exist_ok=True)
self.regenerate_cache = regenerate_cache
from modelforge.potential.models import Pairlist
# Use a logical OR to ensure regenerate_processed_cache is True when
# regenerate_cache is True
self.regenerate_processed_cache = (
regenerate_processed_cache or self.regenerate_cache
)

self.pairlist = Pairlist()
self.dataset_statistic_filename = (
f"{self.local_cache_dir}/{self.name}_dataset_statistic.toml"
)
self.cache_processed_dataset_filename = (
f"{self.local_cache_dir}/{self.name}_{self.version_select}_processed.pt"
)
self.lock_file = f"{self.cache_processed_dataset_filename}.lockfile"

@lock_with_attribute("lock_file")
def prepare_data(
self,
) -> None:
"""
Prepares the dataset for use. This method is responsible for the initial processing of the data such as calculating self energies, atomic energy statistics, and splitting. It is executed only once per node.
Prepares the dataset for use. This method is responsible for the initial
processing of the data such as calculating self energies, atomic energy
statistics, and splitting. It is executed only once per node.
"""
# check if there is a filelock present, if so, wait until it is removed

# if the dataset has already been processed, skip this step
if (
os.path.exists(self.cache_processed_dataset_filename)
and not self.regenerate_processed_cache
):
if not os.path.exists(self.dataset_statistic_filename):
raise FileNotFoundError(
f"Dataset statistics file {self.dataset_statistic_filename} not found. Please regenerate the cache."
)
log.info('Processed dataset already exists. Skipping "prepare_data" step.')
return None

# if the dataset is not already processed, process it

from modelforge.dataset import _ImplementedDatasets
import toml

dataset_class = _ImplementedDatasets.get_dataset_class(self.name)
dataset_class = _ImplementedDatasets.get_dataset_class(str(self.name))
dataset = dataset_class(
force_download=self.force_download,
version_select=self.version_select,
local_cache_dir=self.local_cache_dir,
regenerate_cache=self.regenerate_cache,
)
torch_dataset = self._create_torch_dataset(dataset)

# if dataset statistics is present load it from disk
if (
os.path.exists(self.dataset_statistic_filename)
Expand Down Expand Up @@ -1283,16 +1317,16 @@ def _calculate_atomic_self_energies(

def _cache_dataset(self, torch_dataset):
"""Cache the dataset and its statistics using PyTorch's serialization."""
torch.save(torch_dataset, "torch_dataset.pt")
# sleep for 1 second to make sure that the dataset was written to disk
torch.save(torch_dataset, self.cache_processed_dataset_filename)
# sleep for 5 seconds to make sure that the dataset was written to disk
import time

time.sleep(1)
time.sleep(5)

def setup(self, stage: Optional[str] = None) -> None:
"""Sets up datasets for the train, validation, and test stages based on the stage argument."""

self.torch_dataset = torch.load("torch_dataset.pt")
self.torch_dataset = torch.load(self.cache_processed_dataset_filename)
(
self.train_dataset,
self.val_dataset,
Expand Down Expand Up @@ -1342,7 +1376,7 @@ def _per_datapoint_operations(
from tqdm import tqdm

# remove the self energies if requested
log.info("Precalculating pairlist for dataset")
log.info("Performing per datapoint operations in the dataset dataset")
if self.remove_self_energies:
log.info("Removing self energies from the dataset")

Expand Down
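
prepare_data is now guarded by @lock_with_attribute("lock_file"), so only one process per node regenerates the processed cache while the others wait. The implementation of lock_with_attribute in modelforge.utils.misc is not part of this diff; the following is a hypothetical sketch of how such an attribute-based lock could be written, assuming the third-party filelock package:

import functools
from filelock import FileLock

def lock_with_attribute(attribute_name: str):
    """Serialize calls to a method behind a file lock whose path is stored on
    the instance under the given attribute (e.g. self.lock_file)."""
    def decorator(method):
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            lock_path = getattr(self, attribute_name)
            with FileLock(lock_path):  # blocks until the lock can be acquired
                return method(self, *args, **kwargs)
        return wrapper
    return decorator
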
28 changes: 4 additions & 24 deletions modelforge/tests/conftest.py
Expand Up @@ -96,35 +96,15 @@ def single_batch(batch_size: int = 64, dataset_name="QM9"):


@pytest.fixture(scope="session")
def single_batch_with_batchsize_64():
def single_batch_with_batchsize():
"""
Utility fixture to create a single batch of data for testing.
"""
return single_batch(batch_size=64)

def _create_single_batch(batch_size: int, dataset_name: str):
return single_batch(batch_size=batch_size, dataset_name=dataset_name)

@pytest.fixture(scope="session")
def single_batch_with_batchsize_1():
"""
Utility fixture to create a single batch of data for testing.
"""
return single_batch(batch_size=1)


@pytest.fixture(scope="session")
def single_batch_with_batchsize_2_with_force():
"""
Utility fixture to create a single batch of data for testing.
"""
return single_batch(batch_size=2, dataset_name="PHALKETHOH")


@pytest.fixture(scope="session")
def single_batch_with_batchsize_16_with_force():
"""
Utility fixture to create a single batch of data for testing.
"""
return single_batch(batch_size=16, dataset_name="PHALKETHOH")
return _create_single_batch


def initialize_dataset(
Expand Down
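
The four fixed-size fixtures are collapsed into a single session-scoped factory fixture, so each test chooses the batch size and dataset at the call site. A short usage sketch (the test name and assertion are illustrative, not taken from this PR):

def test_force_batch_shape(single_batch_with_batchsize):
    # request a batch of 16 conformers from the PHALKETHOH dataset
    batch = single_batch_with_batchsize(16, "PHALKETHOH")
    # BatchData.batch_size() (added in this PR) reports the number of molecules
    assert batch.batch_size() == 16
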
2 changes: 1 addition & 1 deletion modelforge/tests/data/config.toml
Expand Up @@ -35,7 +35,7 @@ number_of_epochs = 2
remove_self_energies = true
batch_size = 128
lr = 1e-3
monitor = "val/per_molecule_energy/rmse"
monitor_for_checkpoint = "val/per_molecule_energy/rmse"

[training.experiment_logger]
logger_name = "tensorboard"
Expand Down
8 changes: 4 additions & 4 deletions modelforge/tests/data/training_defaults/default.toml
Expand Up @@ -3,7 +3,7 @@ number_of_epochs = 2
remove_self_energies = true
batch_size = 128
lr = 1e-3
monitor = "val/per_molecule_energy/rmse"
monitor_for_checkpoint = "val/per_molecule_energy/rmse"


[training.experiment_logger]
Expand Down Expand Up @@ -35,11 +35,11 @@ monitor = "val/per_molecule_energy/rmse"
interval = "epoch"

[training.loss_parameter]
loss_property = ['per_molecule_energy', 'per_atom_force'] # use
loss_property = ['per_molecule_energy'] #, 'per_atom_force'] # use

[training.loss_parameter.weight]
per_molecule_energy = 0.999 #NOTE: reciprocal units
per_atom_force = 0.001
per_molecule_energy = 1.0 #NOTE: reciprocal units
#per_atom_force = 0.001


[training.early_stopping]
Expand Down
54 changes: 54 additions & 0 deletions modelforge/tests/data/training_defaults/default_with_force.toml
@@ -0,0 +1,54 @@
[training]
number_of_epochs = 2
remove_self_energies = true
batch_size = 128
lr = 1e-3
monitor_for_checkpoint = "val/per_molecule_energy/rmse"


[training.experiment_logger]
logger_name = "tensorboard" # this will set which logger to use

# configuration for both loggers can be defined simultaneously, the logger_name variable defines which logger to use
[training.experiment_logger.tensorboard_configuration]
save_dir = "logs"

[training.experiment_logger.wandb_configuration]
save_dir = "logs"
project = "training_potentials"
group = "exp00"
log_model = true
job_type = "testing"
tags = ["v_0.1.0"]
notes = "testing training"

[training.lr_scheduler]
frequency = 1
mode = "min"
factor = 0.1
patience = 10
cooldown = 5
min_lr = 1e-8
threshold = 0.1
threshold_mode = "abs"
monitor = "val/per_molecule_energy/rmse"
interval = "epoch"

[training.loss_parameter]
loss_property = ['per_molecule_energy', 'per_atom_force'] # use

[training.loss_parameter.weight]
per_molecule_energy = 0.999 #NOTE: reciprocal units
per_atom_force = 0.001


[training.early_stopping]
verbose = true
monitor = "val/per_molecule_energy/rmse"
min_delta = 0.001
patience = 50

[training.splitting_strategy]
name = "random_record_splitting_strategy"
data_split = [0.8, 0.1, 0.1]
seed = 42
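
The [training.loss_parameter.weight] table assigns relative weights to the energy and force contributions of the loss (the comment notes that the weights effectively carry reciprocal units of their properties). A minimal sketch of how such weights are commonly combined into a single training loss; the function and argument names are illustrative, not modelforge's API:

import torch

def combined_loss(energy_loss: torch.Tensor, force_loss: torch.Tensor, weight: dict) -> torch.Tensor:
    # scale each per-property loss by its configured weight and sum the terms
    return (weight["per_molecule_energy"] * energy_loss
            + weight["per_atom_force"] * force_loss)
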
8 changes: 5 additions & 3 deletions modelforge/tests/test_dataset.py
Expand Up @@ -458,10 +458,11 @@ def test_data_item_format_of_datamodule(
@pytest.mark.parametrize(
"potential_name", _Implemented_NNPs.get_all_neural_network_names()
)
def test_dataset_neighborlist(potential_name, single_batch_with_batchsize_64):
def test_dataset_neighborlist(potential_name, single_batch_with_batchsize):
"""Test the neighborlist."""

nnp_input = single_batch_with_batchsize_64.nnp_input
batch = single_batch_with_batchsize(64, "QM9")
nnp_input = batch.nnp_input

# test that the neighborlist is correctly generated
from modelforge.tests.test_models import load_configs_into_pydantic_models
Expand Down Expand Up @@ -713,7 +714,8 @@ def test_numpy_dataset_assignment(dataset_name):


def test_energy_postprocessing():
# setup test dataset
# test that the mean and stddev of the dataset
# are correct
from modelforge.dataset.dataset import DataModule

# test the self energy calculation on the QM9 dataset
Expand Down