diff --git a/.gitignore b/.gitignore index 2ac2cee4..3c452056 100644 --- a/.gitignore +++ b/.gitignore @@ -190,3 +190,7 @@ lightning_logs/ *.hdf5 */tb_logs/* .vscode/settings.json +logs/* +cache/* +*/logs/* +*/cache/* diff --git a/docs/getting_started.rst b/docs/getting_started.rst index c533e47f..0dc1b6bd 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -149,7 +149,7 @@ Here is an example of a training routine definition: remove_self_energies = true # Whether to remove self-energies from the dataset batch_size = 128 # Number of samples per batch lr = 1e-3 # Learning rate for the optimizer - monitor = "val/per_molecule_energy/rmse" # Metric to monitor for early stopping and checkpointing + monitor_for_checkpoint = "val/per_molecule_energy/rmse" # Metric to monitor for checkpointing [training.experiment_logger] diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 07a80f69..b388b317 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -4,7 +4,7 @@ import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, NamedTuple +from typing import TYPE_CHECKING, Dict, List, Literal, NamedTuple, Optional, Union import numpy as np import pytorch_lightning as pl @@ -15,15 +15,15 @@ from modelforge.dataset.utils import RandomRecordSplittingStrategy, SplittingStrategy from modelforge.utils.prop import PropertyNames +from modelforge.utils.misc import lock_with_attribute if TYPE_CHECKING: from modelforge.potential.processing import AtomicSelfEnergies - -from pydantic import BaseModel, field_validator, ConfigDict, Field - from enum import Enum +from pydantic import BaseModel, ConfigDict, Field + class CaseInsensitiveEnum(str, Enum): @classmethod @@ -64,6 +64,7 @@ class DatasetParameters(BaseModel): version_select: str num_workers: int = Field(gt=0) pin_memory: bool + regenerate_processed_cache: bool = False @dataclass(frozen=False) @@ -208,8 +209,9 @@ def as_jax_namedtuple(self) -> NamedTuple: """Export the dataclass fields and values as a named tuple. Convert pytorch tensors to jax arrays.""" - from dataclasses import dataclass, fields import collections + from dataclasses import dataclass, fields + from modelforge.utils.io import import_ convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax @@ -237,6 +239,9 @@ def to( self.metadata = self.metadata.to(device=device, dtype=dtype) return self + def batch_size(self): + return self.metadata.E.size(dim=0) + class TorchDataset(torch.utils.data.Dataset[BatchData]): """ @@ -1041,7 +1046,6 @@ def create_dataset( return TorchDataset(data.numpy_data, data._property_names) -from torch import nn from openff.units import unit @@ -1067,6 +1071,7 @@ def __init__( local_cache_dir: str = "./", regenerate_cache: bool = False, regenerate_dataset_statistic: bool = False, + regenerate_processed_cache: bool = True, ): """ Initializes adData module for PyTorch Lightning handling data preparation and loading object with the specified configuration. @@ -1100,6 +1105,9 @@ def __init__( regenerate_cache : bool, defaults to False Whether to regenerate the cache. """ + from modelforge.potential.models import Pairlist + import os + super().__init__() self.name = name @@ -1116,28 +1124,55 @@ def __init__( self.train_dataset = None self.test_dataset = None self.val_dataset = None - import os # make sure we can handle a path with a ~ in it self.local_cache_dir = os.path.expanduser(local_cache_dir) + # create the local cache directory if it does not exist + os.makedirs(self.local_cache_dir, exist_ok=True) self.regenerate_cache = regenerate_cache - from modelforge.potential.models import Pairlist + # Use a logical OR to ensure regenerate_processed_cache is True when + # regenerate_cache is True + self.regenerate_processed_cache = ( + regenerate_processed_cache or self.regenerate_cache + ) self.pairlist = Pairlist() self.dataset_statistic_filename = ( f"{self.local_cache_dir}/{self.name}_dataset_statistic.toml" ) + self.cache_processed_dataset_filename = ( + f"{self.local_cache_dir}/{self.name}_{self.version_select}_processed.pt" + ) + self.lock_file = f"{self.cache_processed_dataset_filename}.lockfile" + @lock_with_attribute("lock_file") def prepare_data( self, ) -> None: """ - Prepares the dataset for use. This method is responsible for the initial processing of the data such as calculating self energies, atomic energy statistics, and splitting. It is executed only once per node. + Prepares the dataset for use. This method is responsible for the initial + processing of the data such as calculating self energies, atomic energy + statistics, and splitting. It is executed only once per node. """ + # check if there is a filelock present, if so, wait until it is removed + + # if the dataset has already been processed, skip this step + if ( + os.path.exists(self.cache_processed_dataset_filename) + and not self.regenerate_processed_cache + ): + if not os.path.exists(self.dataset_statistic_filename): + raise FileNotFoundError( + f"Dataset statistics file {self.dataset_statistic_filename} not found. Please regenerate the cache." + ) + log.info('Processed dataset already exists. Skipping "prepare_data" step.') + return None + + # if the dataset is not already processed, process it + from modelforge.dataset import _ImplementedDatasets - import toml - dataset_class = _ImplementedDatasets.get_dataset_class(self.name) + dataset_class = _ImplementedDatasets.get_dataset_class(str(self.name)) dataset = dataset_class( force_download=self.force_download, version_select=self.version_select, @@ -1145,7 +1180,6 @@ def prepare_data( regenerate_cache=self.regenerate_cache, ) torch_dataset = self._create_torch_dataset(dataset) - # if dataset statistics is present load it from disk if ( os.path.exists(self.dataset_statistic_filename) @@ -1283,16 +1317,16 @@ def _calculate_atomic_self_energies( def _cache_dataset(self, torch_dataset): """Cache the dataset and its statistics using PyTorch's serialization.""" - torch.save(torch_dataset, "torch_dataset.pt") - # sleep for 1 second to make sure that the dataset was written to disk + torch.save(torch_dataset, self.cache_processed_dataset_filename) + # sleep for 5 second to make sure that the dataset was written to disk import time - time.sleep(1) + time.sleep(5) def setup(self, stage: Optional[str] = None) -> None: """Sets up datasets for the train, validation, and test stages based on the stage argument.""" - self.torch_dataset = torch.load("torch_dataset.pt") + self.torch_dataset = torch.load(self.cache_processed_dataset_filename) ( self.train_dataset, self.val_dataset, @@ -1342,7 +1376,7 @@ def _per_datapoint_operations( from tqdm import tqdm # remove the self energies if requested - log.info("Precalculating pairlist for dataset") + log.info("Performing per datapoint operations in the dataset dataset") if self.remove_self_energies: log.info("Removing self energies from the dataset") diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index 7e8e475f..f7ccd939 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -96,35 +96,15 @@ def single_batch(batch_size: int = 64, dataset_name="QM9"): @pytest.fixture(scope="session") -def single_batch_with_batchsize_64(): +def single_batch_with_batchsize(): """ Utility fixture to create a single batch of data for testing. """ - return single_batch(batch_size=64) + def _create_single_batch(batch_size: int, dataset_name: str): + return single_batch(batch_size=batch_size, dataset_name=dataset_name) -@pytest.fixture(scope="session") -def single_batch_with_batchsize_1(): - """ - Utility fixture to create a single batch of data for testing. - """ - return single_batch(batch_size=1) - - -@pytest.fixture(scope="session") -def single_batch_with_batchsize_2_with_force(): - """ - Utility fixture to create a single batch of data for testing. - """ - return single_batch(batch_size=2, dataset_name="PHALKETHOH") - - -@pytest.fixture(scope="session") -def single_batch_with_batchsize_16_with_force(): - """ - Utility fixture to create a single batch of data for testing. - """ - return single_batch(batch_size=16, dataset_name="PHALKETHOH") + return _create_single_batch def initialize_dataset( diff --git a/modelforge/tests/data/config.toml b/modelforge/tests/data/config.toml index c323c354..53e408cf 100644 --- a/modelforge/tests/data/config.toml +++ b/modelforge/tests/data/config.toml @@ -35,7 +35,7 @@ number_of_epochs = 2 remove_self_energies = true batch_size = 128 lr = 1e-3 -monitor = "val/per_molecule_energy/rmse" +monitor_for_checkpoint = "val/per_molecule_energy/rmse" [training.experiment_logger] logger_name = "tensorboard" diff --git a/modelforge/tests/data/training_defaults/default.toml b/modelforge/tests/data/training_defaults/default.toml index 19c0d3b7..c8d96fc2 100644 --- a/modelforge/tests/data/training_defaults/default.toml +++ b/modelforge/tests/data/training_defaults/default.toml @@ -3,7 +3,7 @@ number_of_epochs = 2 remove_self_energies = true batch_size = 128 lr = 1e-3 -monitor = "val/per_molecule_energy/rmse" +monitor_for_checkpoint = "val/per_molecule_energy/rmse" [training.experiment_logger] @@ -35,11 +35,11 @@ monitor = "val/per_molecule_energy/rmse" interval = "epoch" [training.loss_parameter] -loss_property = ['per_molecule_energy', 'per_atom_force'] # use +loss_property = ['per_molecule_energy'] #, 'per_atom_force'] # use [training.loss_parameter.weight] -per_molecule_energy = 0.999 #NOTE: reciprocal units -per_atom_force = 0.001 +per_molecule_energy = 1.0 #NOTE: reciprocal units +#per_atom_force = 0.001 [training.early_stopping] diff --git a/modelforge/tests/data/training_defaults/default_with_force.toml b/modelforge/tests/data/training_defaults/default_with_force.toml new file mode 100644 index 00000000..4ba07f8b --- /dev/null +++ b/modelforge/tests/data/training_defaults/default_with_force.toml @@ -0,0 +1,54 @@ +[training] +number_of_epochs = 2 +remove_self_energies = true +batch_size = 128 +lr = 1e-3 +monitor_for_checkpoint = "val/per_molecule_energy/rmse" + + +[training.experiment_logger] +logger_name = "tensorboard" # this will set which logger to use + +# configuration for both loggers can be defined simultaneously, the logger_name variable defines which logger to use +[training.experiment_logger.tensorboard_configuration] +save_dir = "logs" + +[training.experiment_logger.wandb_configuration] +save_dir = "logs" +project = "training_potentials" +group = "exp00" +log_model = true +job_type = "testing" +tags = ["v_0.1.0"] +notes = "testing training" + +[training.lr_scheduler] +frequency = 1 +mode = "min" +factor = 0.1 +patience = 10 +cooldown = 5 +min_lr = 1e-8 +threshold = 0.1 +threshold_mode = "abs" +monitor = "val/per_molecule_energy/rmse" +interval = "epoch" + +[training.loss_parameter] +loss_property = ['per_molecule_energy', 'per_atom_force'] # use + +[training.loss_parameter.weight] +per_molecule_energy = 0.999 #NOTE: reciprocal units +per_atom_force = 0.001 + + +[training.early_stopping] +verbose = true +monitor = "val/per_molecule_energy/rmse" +min_delta = 0.001 +patience = 50 + +[training.splitting_strategy] +name = "random_record_splitting_strategy" +data_split = [0.8, 0.1, 0.1] +seed = 42 diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index 7702892b..91f963a3 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -458,10 +458,11 @@ def test_data_item_format_of_datamodule( @pytest.mark.parametrize( "potential_name", _Implemented_NNPs.get_all_neural_network_names() ) -def test_dataset_neighborlist(potential_name, single_batch_with_batchsize_64): +def test_dataset_neighborlist(potential_name, single_batch_with_batchsize): """Test the neighborlist.""" - nnp_input = single_batch_with_batchsize_64.nnp_input + batch = single_batch_with_batchsize(64, "QM9") + nnp_input = batch.nnp_input # test that the neighborlist is correctly generated from modelforge.tests.test_models import load_configs_into_pydantic_models @@ -713,7 +714,8 @@ def test_numpy_dataset_assignment(dataset_name): def test_energy_postprocessing(): - # setup test dataset + # test that the mean and stddev of the dataset + # are correct from modelforge.dataset.dataset import DataModule # test the self energy calculation on the QM9 dataset diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 1faecbf6..4c9889e3 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -54,11 +54,13 @@ def load_configs_into_pydantic_models(potential_name: str, dataset_name: str): @pytest.mark.parametrize( "potential_name", _Implemented_NNPs.get_all_neural_network_names() ) -def test_JAX_wrapping(potential_name, single_batch_with_batchsize_64): +def test_JAX_wrapping(potential_name, single_batch_with_batchsize): from modelforge.potential.models import ( NeuralNetworkPotentialFactory, ) + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + # read default parameters config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") @@ -70,7 +72,7 @@ def test_JAX_wrapping(potential_name, single_batch_with_batchsize_64): ) assert "JAX" in str(type(model)) - nnp_input = single_batch_with_batchsize_64.nnp_input.as_jax_namedtuple() + nnp_input = batch.nnp_input.as_jax_namedtuple() out = model(nnp_input)["per_molecule_energy"] import jax @@ -329,13 +331,14 @@ def test_dataset_statistic(potential_name): "potential_name", _Implemented_NNPs.get_all_neural_network_names() ) def test_energy_between_simulation_environments( - potential_name, single_batch_with_batchsize_64 + potential_name, single_batch_with_batchsize ): # compare that the energy is the same for the JAX and PyTorch Model import numpy as np import torch - nnp_input = single_batch_with_batchsize_64.nnp_input + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + nnp_input = batch.nnp_input # test the forward pass through each of the models # cast input and model to torch.float64 # read default parameters @@ -416,20 +419,24 @@ def test_forward_pass_with_all_datasets( assert torch.all(pair_list[0, 1:] >= pair_list[0, :-1]) +@pytest.mark.parametrize("dataset_name", ["QM9", "SPICE2"]) @pytest.mark.parametrize( "potential_name", _Implemented_NNPs.get_all_neural_network_names() ) @pytest.mark.parametrize("simulation_environment", ["JAX", "PyTorch"]) def test_forward_pass( - potential_name, simulation_environment, single_batch_with_batchsize_64 + dataset_name, potential_name, simulation_environment, single_batch_with_batchsize ): # this test sends a single batch from different datasets through the model import torch - nnp_input = single_batch_with_batchsize_64.nnp_input + batch = batch = single_batch_with_batchsize(batch_size=6, dataset_name=dataset_name) + nnp_input = batch.nnp_input # read default parameters - config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") + config = load_configs_into_pydantic_models( + f"{potential_name.lower()}", dataset_name + ) nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] # test the forward pass through each of the models @@ -450,8 +457,13 @@ def test_forward_pass( # which have chemically equivalent hydrogens at the minimum geometry. # This has to be reflected in the atomic energies E_i, which # have to be equal for all hydrogens - if "JAX" not in str(type(model)): - from loguru import logger as log + if "JAX" not in str(type(model)) and dataset_name == "QM9": + # make sure that we are correctly reducing + ref = torch.zeros_like(output["per_molecule_energy"]).scatter_add_( + 0, nnp_input.atomic_subsystem_indices.long(), output["per_atom_energy"] + ) + + assert torch.allclose(ref, output["per_molecule_energy"]) # assert that the following tensor has equal values for dim=0 index 1 to 4 and 6 to 8 @@ -478,17 +490,18 @@ def test_forward_pass( @pytest.mark.parametrize( "potential_name", _Implemented_NNPs.get_all_neural_network_names() ) -def test_calculate_energies_and_forces(potential_name, single_batch_with_batchsize_64): +def test_calculate_energies_and_forces(potential_name, single_batch_with_batchsize): """ Test the calculation of energies and forces for a molecule. """ import torch + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") # read default parameters config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") # get batch - nnp_input = single_batch_with_batchsize_64.nnp_input + nnp_input = batch.nnp_input # test the pass through each of the models model_training = NeuralNetworkPotentialFactory.generate_potential( @@ -536,7 +549,7 @@ def test_calculate_energies_and_forces(potential_name, single_batch_with_batchsi "potential_name", _Implemented_NNPs.get_all_neural_network_names() ) def test_calculate_energies_and_forces_with_jax( - potential_name, single_batch_with_batchsize_64 + potential_name, single_batch_with_batchsize ): """ Test the calculation of energies and forces for a molecule. @@ -545,8 +558,8 @@ def test_calculate_energies_and_forces_with_jax( # read default parameters config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") - - nnp_input = single_batch_with_batchsize_64.nnp_input + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + nnp_input = batch.nnp_input # test the backward pass through each of the models nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] nr_of_atoms_per_batch = nnp_input.atomic_subsystem_indices.shape[0] @@ -976,11 +989,11 @@ def test_pairlist_on_dataset(): @pytest.mark.parametrize( "potential_name", _Implemented_NNPs.get_all_neural_network_names() ) -def test_casting(potential_name, single_batch_with_batchsize_64): +def test_casting(potential_name, single_batch_with_batchsize): # test dtype casting import torch - batch = single_batch_with_batchsize_64 + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") batch_ = batch.to(dtype=torch.float64) assert batch_.nnp_input.positions.dtype == torch.float64 batch_ = batch_.to(dtype=torch.float32) @@ -1025,7 +1038,7 @@ def test_casting(potential_name, single_batch_with_batchsize_64): def test_equivariant_energies_and_forces( potential_name, simulation_environment, - single_batch_with_batchsize_64, + single_batch_with_batchsize, equivariance_utils, ): """ @@ -1048,7 +1061,8 @@ def test_equivariant_energies_and_forces( translation, rotation, reflection = equivariance_utils # define the tolerance atol = 1e-3 - nnp_input = single_batch_with_batchsize_64.nnp_input + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + nnp_input = batch.nnp_input # initialize the models model = model.to(dtype=torch.float64) @@ -1056,7 +1070,9 @@ def test_equivariant_energies_and_forces( # ------------------- # # start the test # reference values - nnp_input = single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float64) + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + + nnp_input = batch.nnp_input.to(dtype=torch.float64) reference_result = model(nnp_input)["per_molecule_energy"].to(dtype=torch.float64) reference_forces = -torch.autograd.grad( reference_result.sum(), diff --git a/modelforge/tests/test_nn.py b/modelforge/tests/test_nn.py index e399550a..a5b0b276 100644 --- a/modelforge/tests/test_nn.py +++ b/modelforge/tests/test_nn.py @@ -1,14 +1,16 @@ from .test_models import load_configs_into_pydantic_models -def test_embedding(single_batch_with_batchsize_64): +def test_embedding(single_batch_with_batchsize): # test the input featurization, including: # - nuclear charge embedding # - total charge mixing - import torch + import torch # noqa: F401 - nnp_input = single_batch_with_batchsize_64.nnp_input + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + + nnp_input = batch.nnp_input model_name = "SchNet" # read default parameters and extract featurization config = load_configs_into_pydantic_models(f"{model_name.lower()}", "qm9") diff --git a/modelforge/tests/test_painn.py b/modelforge/tests/test_painn.py index 647d9901..3dc60d48 100644 --- a/modelforge/tests/test_painn.py +++ b/modelforge/tests/test_painn.py @@ -2,7 +2,7 @@ from modelforge.potential.painn import PaiNN -def test_forward(single_batch_with_batchsize_64): +def test_forward(single_batch_with_batchsize): """Test initialization of the PaiNN neural network potential.""" # read default parameters from modelforge.tests.test_models import load_configs_into_pydantic_models @@ -17,8 +17,9 @@ def test_forward(single_batch_with_batchsize_64): ], ) assert painn is not None, "PaiNN model should be initialized." + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") - nnp_input = single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float32) + nnp_input = batch.nnp_input.to(dtype=torch.float32) energy = painn(nnp_input)["per_molecule_energy"] nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] @@ -27,11 +28,13 @@ def test_forward(single_batch_with_batchsize_64): ) # Assuming energy is calculated per sample in the batch -def test_equivariance(single_batch_with_batchsize_64): +def test_equivariance(single_batch_with_batchsize): from modelforge.potential.painn import PaiNN from dataclasses import replace import torch + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + from modelforge.tests.test_models import load_configs_into_pydantic_models # read default parameters @@ -50,7 +53,7 @@ def test_equivariance(single_batch_with_batchsize_64): ], ).double() - methane_input = single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float64) + methane_input = batch.nnp_input.to(dtype=torch.float64) perturbed_methane_input = replace(methane_input) perturbed_methane_input.positions = torch.matmul( methane_input.positions, rotation_matrix diff --git a/modelforge/tests/test_parameter_models.py b/modelforge/tests/test_parameter_models.py index ce97537d..2c26e2cd 100644 --- a/modelforge/tests/test_parameter_models.py +++ b/modelforge/tests/test_parameter_models.py @@ -142,6 +142,9 @@ def test_training_parameter_model(): with pytest.raises(ValidationError): training_parameters.splitting_strategy.dataset_split = [0.7, 0.1, 0.1, 0.1] - # this will throw an error because the datafile has 2 entries for the loss_property dictionary + # this will throw an error because the datafile has 1 entries for the loss_property dictionary with pytest.raises(ValidationError): - training_parameters.loss_parameter.loss_property = ["per_molecule_energy"] + training_parameters.loss_parameter.loss_property = [ + "per_molecule_energy", + "per_atom_force", + ] diff --git a/modelforge/tests/test_physnet.py b/modelforge/tests/test_physnet.py index 8aadca05..1601654d 100644 --- a/modelforge/tests/test_physnet.py +++ b/modelforge/tests/test_physnet.py @@ -15,7 +15,7 @@ def test_init(): ) -def test_forward(single_batch_with_batchsize_64): +def test_forward(single_batch_with_batchsize): import torch from modelforge.potential.physnet import PhysNet @@ -37,7 +37,9 @@ def test_forward(single_batch_with_batchsize_64): ) model = model.to(torch.float32) print(model) - yhat = model(single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float32)) + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + + yhat = model(batch.nnp_input.to(dtype=torch.float32)) def test_compare_representation(): diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 8fff65ba..994b7112 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -33,12 +33,13 @@ def test_init(): from openff.units import unit -def test_forward(single_batch_with_batchsize_64): +def test_forward(single_batch_with_batchsize): """ Test the forward pass of the SAKE model. """ # get methane input - methane = single_batch_with_batchsize_64.nnp_input + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + methane = batch.nnp_input from modelforge.tests.test_models import load_configs_into_pydantic_models @@ -91,7 +92,7 @@ def test_interaction_forward(): @pytest.mark.parametrize("eq_atol", [3e-1]) @pytest.mark.parametrize("h_atol", [8e-2]) -def test_layer_equivariance(h_atol, eq_atol, single_batch_with_batchsize_64): +def test_layer_equivariance(h_atol, eq_atol, single_batch_with_batchsize): import torch from modelforge.potential.sake import SAKE from dataclasses import replace @@ -118,7 +119,9 @@ def test_layer_equivariance(h_atol, eq_atol, single_batch_with_batchsize_64): ) # get methane input - methane = single_batch_with_batchsize_64.nnp_input + batch = batch = single_batch_with_batchsize(batch_size=64, dataset_name="QM9") + + methane = batch.nnp_input perturbed_methane_input = replace(methane) perturbed_methane_input.positions = torch.matmul(methane.positions, rotation_matrix) @@ -424,7 +427,7 @@ def test_sake_layer_against_reference(include_self_pairs, v_is_none): # FIXME: this test is currently failing @pytest.mark.xfail -def test_model_against_reference(single_batch_with_batchsize_1): +def test_model_against_reference(single_batch_with_batchsize): nr_heads = 5 key = jax.random.PRNGKey(1884) torch.manual_seed(1884) @@ -462,7 +465,8 @@ def test_model_against_reference(single_batch_with_batchsize_1): ) # get methane input - methane = single_batch_with_batchsize_1.nnp_input + batch = single_batch_with_batchsize(batch_size=1) + methane = batch.nnp_input pairlist_output = mf_sake.compute_interacting_pairs.prepare_inputs(methane) prepared_methane = mf_sake.core_module._model_specific_input_preparation( methane, pairlist_output @@ -605,7 +609,7 @@ def test_model_against_reference(single_batch_with_batchsize_1): # assert torch.allclose(mf_out.E, torch.from_numpy(onp.array(ref_out[0]))) -def test_model_invariance(single_batch_with_batchsize_1): +def test_model_invariance(single_batch_with_batchsize): from dataclasses import replace from modelforge.tests.test_models import load_configs_into_pydantic_models @@ -620,7 +624,8 @@ def test_model_invariance(single_batch_with_batchsize_1): ], ) # get methane input - methane = single_batch_with_batchsize_1.nnp_input + batch = single_batch_with_batchsize(batch_size=1, dataset_name="QM9") + methane = batch.nnp_input rotation_matrix = torch.tensor([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) perturbed_methane_input = replace(methane) diff --git a/modelforge/tests/test_training.py b/modelforge/tests/test_training.py index 3b514847..72ee5e15 100644 --- a/modelforge/tests/test_training.py +++ b/modelforge/tests/test_training.py @@ -10,7 +10,9 @@ from modelforge.potential import NeuralNetworkPotentialFactory, _Implemented_NNPs -def load_configs_into_pydantic_models(potential_name: str, dataset_name: str): +def load_configs_into_pydantic_models( + potential_name: str, dataset_name: str, training_toml: str +): from importlib import resources import toml @@ -26,7 +28,7 @@ def load_configs_into_pydantic_models(potential_name: str, dataset_name: str): resources.files(potential_defaults) / f"{potential_name.lower()}.toml" ) dataset_path = resources.files(dataset_defaults) / f"{dataset_name.lower()}.toml" - training_path = resources.files(training_defaults) / "default.toml" + training_path = resources.files(training_defaults) / f"{training_toml}.toml" runtime_path = resources.files(runtime_defaults) / "runtime.toml" training_config_dict = toml.load(training_path) @@ -58,8 +60,10 @@ def load_configs_into_pydantic_models(potential_name: str, dataset_name: str): } -def get_trainer(potential_name: str, dataset_name: str): - config = load_configs_into_pydantic_models(potential_name, dataset_name) +def get_trainer(potential_name: str, dataset_name: str, training_toml: str): + config = load_configs_into_pydantic_models( + potential_name, dataset_name, training_toml + ) # Extract parameters potential_parameter = config["potential"] @@ -80,18 +84,30 @@ def get_trainer(potential_name: str, dataset_name: str): @pytest.mark.parametrize( "potential_name", _Implemented_NNPs.get_all_neural_network_names() ) -@pytest.mark.parametrize("dataset_name", ["QM9"]) -def test_train_with_lightning(potential_name, dataset_name): +@pytest.mark.parametrize("dataset_name", ["PHALKETHOH"]) +@pytest.mark.parametrize("training", ["with_force", "without_force"]) +def test_train_with_lightning(training, potential_name, dataset_name): """ Test that we can train, save and load checkpoints. """ - - get_trainer(potential_name, dataset_name).train_potential().save_checkpoint( + # get correct training toml + training_toml = "default_with_force" if training == "with_force" else "default" + # SKIP if potential is ANI and dataset is SPICE2 + if "ANI" in potential_name and dataset_name == "SPICE2": + pytest.skip("ANI potential is not compatible with SPICE2 dataset") + if IN_GITHUB_ACTIONS and potential_name == "SAKE" and training == "with_force": + pytest.skip( + "Skipping Sake training with forces on GitHub Actions because it allocates too much memory" + ) + # train potential + get_trainer( + potential_name, dataset_name, training_toml + ).train_potential().save_checkpoint( "test.chp" ) # save checkpoint # continue training from checkpoint - get_trainer(potential_name, dataset_name).train_potential() + get_trainer(potential_name, dataset_name, training_toml).train_potential() def test_train_from_single_toml_file(): @@ -105,15 +121,17 @@ def test_train_from_single_toml_file(): read_config_and_train(config_path) -def test_error_calculation(single_batch_with_batchsize_16_with_force): +def test_error_calculation(single_batch_with_batchsize): # test the different Loss classes from modelforge.train.training import ( - FromPerAtomToPerMoleculeMeanSquaredError, - PerMoleculeMeanSquaredError, + FromPerAtomToPerMoleculeSquaredError, + PerMoleculeSquaredError, ) # generate data - data = single_batch_with_batchsize_16_with_force + batch = single_batch_with_batchsize(batch_size=16, dataset_name="PHALKETHOH") + + data = batch true_E = data.metadata.E true_F = data.metadata.F @@ -122,7 +140,7 @@ def test_error_calculation(single_batch_with_batchsize_16_with_force): predicted_F = true_F + torch.rand_like(true_F) * 10 # test error for property with shape (nr_of_molecules, 1) - error = PerMoleculeMeanSquaredError() + error = PerMoleculeSquaredError() E_error = error(predicted_E, true_E, data) # compare output (mean squared error scaled by number of atoms in the molecule) @@ -132,10 +150,10 @@ def test_error_calculation(single_batch_with_batchsize_16_with_force): 1 ) # FIXME : fi reference_E_error = torch.mean(scale_squared_error) - assert torch.allclose(E_error, reference_E_error) + assert torch.allclose(torch.mean(E_error), reference_E_error) # test error for property with shape (nr_of_atoms, 3) - error = FromPerAtomToPerMoleculeMeanSquaredError() + error = FromPerAtomToPerMoleculeSquaredError() F_error = error(predicted_F, true_F, data) # compare error (mean squared error scaled by number of atoms in the molecule) @@ -155,13 +173,13 @@ def test_error_calculation(single_batch_with_batchsize_16_with_force): reference_F_error = torch.mean( per_mol_error / (3 * data.metadata.atomic_subsystem_counts.unsqueeze(1)) ) - assert torch.allclose(F_error, reference_F_error) + assert torch.allclose(torch.mean(F_error), reference_F_error) -def test_loss(single_batch_with_batchsize_16_with_force): +def test_loss(single_batch_with_batchsize): from modelforge.train.training import Loss - batch = single_batch_with_batchsize_16_with_force + batch = single_batch_with_batchsize(batch_size=16, dataset_name="PHALKETHOH") loss_porperty = ["per_molecule_energy", "per_atom_force"] loss_weights = {"per_molecule_energy": 0.5, "per_atom_force": 0.5} @@ -169,8 +187,17 @@ def test_loss(single_batch_with_batchsize_16_with_force): assert loss is not None # get trainer - trainer = get_trainer("schnet", "QM9") - prediction = trainer.model.calculate_predictions(batch, trainer.model.potential) + trainer = get_trainer("schnet", "QM9", "default_with_force") + prediction = trainer.model.calculate_predictions( + batch, trainer.model.potential, train_mode=True + ) # train_mode=True is required for gradients in force prediction + + assert prediction["per_molecule_energy_predict"].size( + dim=0 + ) == batch.metadata.E.size(dim=0) + assert prediction["per_atom_force_predict"].size(dim=0) == batch.metadata.F.size( + dim=0 + ) # pass prediction through loss module loss_output = loss(prediction, batch) @@ -188,9 +215,12 @@ def test_loss(single_batch_with_batchsize_16_with_force): prediction["per_molecule_energy_predict"] - prediction["per_molecule_energy_true"] ).pow(2) + / batch.metadata.atomic_subsystem_counts.unsqueeze(1) ) ) - assert torch.allclose(loss_output["per_molecule_energy/mse"], E_loss) + # compare to referenc evalue obtained from Loos class + ref = torch.mean(loss_output["per_molecule_energy"]) + assert torch.allclose(ref, E_loss) # --------------------------------------------- # # now calculate F_loss @@ -215,15 +245,15 @@ def test_loss(single_batch_with_batchsize_16_with_force): ) per_atom_force_mse = torch.mean(per_molecule_squared_error) - assert torch.allclose(loss_output["per_atom_force/mse"], per_atom_force_mse) + assert torch.allclose(torch.mean(loss_output["per_atom_force"]), per_atom_force_mse) # --------------------------------------------- # # let's double check that the loss is calculated correctly # calculate the total loss assert torch.allclose( - loss_weights["per_molecule_energy"] * loss_output["per_molecule_energy/mse"] - + loss_weights["per_atom_force"] * loss_output["per_atom_force/mse"], + loss_weights["per_molecule_energy"] * loss_output["per_molecule_energy"] + + loss_weights["per_atom_force"] * loss_output["per_atom_force"], loss_output["total_loss"].to(torch.float32), ) diff --git a/modelforge/train/parameters.py b/modelforge/train/parameters.py index 8fe6c2c9..f6daf0d2 100644 --- a/modelforge/train/parameters.py +++ b/modelforge/train/parameters.py @@ -262,7 +262,7 @@ def ensure_logger_configuration(self) -> "ExperimentLogger": remove_self_energies: bool batch_size: int lr: float - monitor: str + monitor_for_checkpoint: str lr_scheduler: Optional[SchedulerConfig] = None loss_parameter: LossParameter early_stopping: Optional[EarlyStopping] = None @@ -270,7 +270,8 @@ def ensure_logger_configuration(self) -> "ExperimentLogger": stochastic_weight_averaging: Optional[StochasticWeightAveraging] = None experiment_logger: ExperimentLogger verbose: bool = False - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam + optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW + min_number_of_epochs: Union[int, None] = None ### Runtime Parameters diff --git a/modelforge/train/training.py b/modelforge/train/training.py index 53f5df4b..23ab1a68 100644 --- a/modelforge/train/training.py +++ b/modelforge/train/training.py @@ -2,33 +2,39 @@ This module contains classes and functions for training neural network potentials using PyTorch Lightning. """ -from torch.optim.lr_scheduler import ReduceLROnPlateau -from typing import Any, Union, Dict, Type, Optional, List +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Type, Union + +import lightning.pytorch as pL import torch -from loguru import logger as log -from modelforge.dataset.dataset import BatchData, NNPInput import torchmetrics +from lightning import Trainer +from loguru import logger as log from torch import nn -from abc import ABC, abstractmethod -from modelforge.dataset.dataset import DatasetParameters +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from modelforge.dataset.dataset import ( + BatchData, + DataModule, + DatasetParameters, + NNPInput, +) from modelforge.potential.parameters import ( ANI2xParameters, - PhysNetParameters, - SchNetParameters, PaiNNParameters, + PhysNetParameters, SAKEParameters, + SchNetParameters, TensorNetParameters, ) -from lightning import Trainer -import lightning.pytorch as pL -from modelforge.dataset.dataset import DataModule +from modelforge.train.parameters import RuntimeParameters, TrainingParameters __all__ = [ "Error", - "FromPerAtomToPerMoleculeMeanSquaredError", + "FromPerAtomToPerMoleculeSquaredError", "Loss", "LossFactory", - "PerMoleculeMeanSquaredError", + "PerMoleculeSquaredError", "ModelTrainer", "create_error_metrics", "ModelTrainer", @@ -77,7 +83,8 @@ def calculate_squared_error( torch.Tensor The calculated error. """ - error = (predicted_tensor - reference_tensor).pow(2).sum(dim=1, keepdim=True) + squared_diff = (predicted_tensor - reference_tensor).pow(2) + error = squared_diff.sum(dim=1, keepdim=True) return error @staticmethod @@ -107,7 +114,7 @@ def scale_by_number_of_atoms( return scaled_by_number_of_atoms -class FromPerAtomToPerMoleculeMeanSquaredError(Error): +class FromPerAtomToPerMoleculeSquaredError(Error): """ Calculates the per-atom error and aggregates it to per-molecule mean squared error. """ @@ -163,18 +170,18 @@ def forward( 0, batch.nnp_input.atomic_subsystem_indices.long().unsqueeze(1), per_atom_squared_error, - ) + ).contiguous() # divide by number of atoms per_molecule_square_error_scaled = self.scale_by_number_of_atoms( per_molecule_squared_error, batch.metadata.atomic_subsystem_counts, prefactor=per_atom_prediction.shape[-1], - ) - # return the average - return torch.mean(per_molecule_square_error_scaled) + ).contiguous() + return per_molecule_square_error_scaled -class PerMoleculeMeanSquaredError(Error): + +class PerMoleculeSquaredError(Error): """ Calculates the per-molecule mean squared error. """ @@ -214,8 +221,7 @@ def forward( batch.metadata.atomic_subsystem_counts, ) - # return the average - return torch.mean(per_molecule_square_error_scaled) + return per_molecule_square_error_scaled def calculate_error( self, @@ -259,22 +265,24 @@ def __init__(self, loss_porperty: List[str], weight: Dict[str, float]): for prop, w in weight.items(): if prop in self._SUPPORTED_PROPERTIES: if prop == "per_atom_force": - self.loss[prop] = FromPerAtomToPerMoleculeMeanSquaredError( + self.loss[prop] = FromPerAtomToPerMoleculeSquaredError( scale_by_number_of_atoms=True ) elif prop == "per_atom_energy": - self.loss[prop] = PerMoleculeMeanSquaredError( + self.loss[prop] = PerMoleculeSquaredError( scale_by_number_of_atoms=True ) # FIXME: this is currently not working elif prop == "per_molecule_energy": - self.loss[prop] = PerMoleculeMeanSquaredError( - scale_by_number_of_atoms=False + self.loss[prop] = PerMoleculeSquaredError( + scale_by_number_of_atoms=True ) self.register_buffer(prop, torch.tensor(w)) else: raise NotImplementedError(f"Loss type {prop} not implemented.") - def forward(self, predict_target: Dict[str, torch.Tensor], batch): + def forward( + self, predict_target: Dict[str, torch.Tensor], batch + ) -> Dict[str, torch.Tensor]: """ Calculates the combined loss for the specified properties. @@ -305,7 +313,7 @@ def forward(self, predict_target: Dict[str, torch.Tensor], batch): # add total loss loss = loss + (self.weight[prop] * loss_) # save loss - loss_dict[f"{prop}/mse"] = loss_ + loss_dict[f"{prop}"] = loss_ # add total loss to results dict and return loss_dict["total_loss"] = loss @@ -338,11 +346,11 @@ def create_loss(loss_property: List[str], weight: Dict[str, float]) -> Type[Loss return Loss(loss_property, weight) -from torch.optim import Optimizer from torch.nn import ModuleDict +from torch.optim import Optimizer -def create_error_metrics(loss_properties: List[str]) -> ModuleDict: +def create_error_metrics(loss_properties: List[str], loss: bool = False) -> ModuleDict: """ Creates a ModuleDict of MetricCollections for the given loss properties. @@ -350,50 +358,52 @@ def create_error_metrics(loss_properties: List[str]) -> ModuleDict: ---------- loss_properties : List[str] List of loss properties for which to create the metrics. - + loss : bool, optional + If True, only the loss metric is created, by default False. Returns ------- ModuleDict A dictionary where keys are loss properties and values are MetricCollections. """ - from torchmetrics.regression import ( - MeanAbsoluteError, - MeanSquaredError, - ) from torchmetrics import MetricCollection + from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError + from torchmetrics.aggregation import MeanMetric - return ModuleDict( - { - prop: MetricCollection( - [MeanAbsoluteError(), MeanSquaredError(squared=False)] - ) - for prop in loss_properties - } - ) - - -from modelforge.train.parameters import RuntimeParameters, TrainingParameters + if loss: + metric_dict = ModuleDict( + {prop: MetricCollection([MeanMetric()]) for prop in loss_properties} + ) + metric_dict["total_loss"] = MetricCollection([MeanMetric()]) + return metric_dict + else: + metric_dict = ModuleDict( + { + prop: MetricCollection( + [MeanAbsoluteError(), MeanSquaredError(squared=False)] + ) + for prop in loss_properties + } + ) + return metric_dict class CalculateProperties(torch.nn.Module): - def __init__(self): + def __init__(self, requested_properties: List[str]): """ A utility class for calculating properties such as energies and forces from batches using a neural network model. - Methods - ------- - _get_forces(batch: BatchData, energies: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor] - Computes the forces from a given batch using the model. - _get_energies(batch: BatchData, model: Type[torch.nn.Module]) -> Dict[str, torch.Tensor] - Computes the energies from a given batch using the model. - forward(batch: BatchData, model: Type[torch.nn.Module]) -> Dict[str, torch.Tensor] - Computes the energies and forces from a given batch using the model. + Parameters + """ super().__init__() + self.requested_properties = requested_properties + self.include_force = False + if "per_atom_force" in self.requested_properties: + self.include_force = True def _get_forces( - self, batch: "BatchData", energies: Dict[str, torch.Tensor] + self, batch: "BatchData", energies: Dict[str, torch.Tensor], train_mode: bool ) -> Dict[str, torch.Tensor]: """ Computes the forces from a given batch using the model. @@ -426,16 +436,22 @@ def _get_forces( # Compute the gradient (forces) from the predicted energies grad = torch.autograd.grad( - per_molecule_energy_predict.sum(), + per_molecule_energy_predict, nnp_input.positions, - create_graph=True, - retain_graph=True, + grad_outputs=torch.ones_like(per_molecule_energy_predict), + create_graph=train_mode, + retain_graph=train_mode, + allow_unused=True, )[0] + + if grad is None: + raise RuntimeWarning("Force calculation did not return a gradient") + per_atom_force_predict = -1 * grad # Forces are the negative gradient of energy return { "per_atom_force_true": per_atom_force_true, - "per_atom_force_predict": per_atom_force_predict, + "per_atom_force_predict": per_atom_force_predict.contiguous(), } def _get_energies( @@ -473,7 +489,7 @@ def _get_energies( } def forward( - self, batch: "BatchData", model: Type[torch.nn.Module] + self, batch: "BatchData", model: Type[torch.nn.Module], train_mode: bool = False ) -> Dict[str, torch.Tensor]: """ Computes the energies and forces from a given batch using the model. @@ -491,7 +507,10 @@ def forward( The true and predicted energies and forces from the dataset and the model. """ energies = self._get_energies(batch, model) - forces = self._get_forces(batch, energies) + if self.include_force: + forces = self._get_forces(batch, energies, train_mode) + else: + forces = {} return {**energies, **forces} @@ -543,7 +562,42 @@ def __init__( potential_seed=potential_seed, ) - self.calculate_predictions = CalculateProperties() + def check_strides(module, grad_input, grad_output): + print(f"Layer: {module.__class__.__name__}") + for i, grad in enumerate(grad_input): + if grad is not None: + print( + f"Grad input {i}: size {grad.size()}, strides {grad.stride()}" + ) + # Handle grad_output + if isinstance(grad_output, tuple) and isinstance(grad_output[0], dict): + # If the output is a dict wrapped in a tuple, extract the dict + grad_output = grad_output[0] + if isinstance(grad_output, dict): + for key, grad in grad_output.items(): + if grad is not None: + print( + f"Grad output [{key}]: size {grad.size()}, strides {grad.stride()}" + ) + else: + for i, grad in enumerate(grad_output): + if grad is not None: + print( + f"Grad output {i}: size {grad.size()}, strides {grad.stride()}" + ) + + # Register the full backward hook + if training_parameter.verbose is True: + for module in self.potential.modules(): + module.register_full_backward_hook(check_strides) + + self.include_force = False + if "per_atom_force" in training_parameter.loss_parameter.loss_property: + self.include_force = True + + self.calculate_predictions = CalculateProperties( + training_parameter.loss_parameter.loss_property + ) self.optimizer = training_parameter.optimizer self.learning_rate = training_parameter.lr self.lr_scheduler = training_parameter.lr_scheduler @@ -572,6 +626,11 @@ def __init__( training_parameter.loss_parameter.loss_property ) + # Initialize loss metric + self.loss_metric = create_error_metrics( + training_parameter.loss_parameter.loss_property, loss=True + ) + def forward(self, batch: "BatchData") -> Dict[str, torch.Tensor]: """ Computes the energies and forces from a given batch using the model. @@ -638,29 +697,23 @@ def training_step(self, batch: "BatchData", batch_idx: int) -> torch.Tensor: The loss tensor computed for the current training step. """ - # calculate energy and forces - predict_target = self.calculate_predictions(batch, self.potential) + # calculate energy and forces, Note that `predict_target` is a + # dictionary containing the predicted and true values for energy and + # force` + predict_target = self.calculate_predictions( + batch, self.potential, self.training + ) - # calculate the loss + # Calculate the loss loss_dict = self.loss(predict_target, batch) - # Update and log training error - self._update_metrics(self.train_error, predict_target) + # Update the loss metric with the different loss components + for key, metric in loss_dict.items(): + self.loss_metric[key].update(metric.clone().detach(), batch.batch_size()) - # log the loss (this includes the individual contributions that the loss contains) - for key, loss in loss_dict.items(): - self.log( - f"loss/{key}", - torch.mean(loss), - on_step=False, - prog_bar=True, - on_epoch=True, - batch_size=1, - ) # batch size is 1 because the mean of the batch is logged + loss = torch.mean(loss_dict["total_loss"]) + return loss.contiguous() - return loss_dict["total_loss"] - - @torch.enable_grad() def validation_step(self, batch: "BatchData", batch_idx: int) -> None: """ Validation step to compute the RMSE/MAE across epochs. @@ -679,14 +732,15 @@ def validation_step(self, batch: "BatchData", batch_idx: int) -> None: # Ensure positions require gradients for force calculation batch.nnp_input.positions.requires_grad_(True) - # calculate energy and forces - predict_target = self.calculate_predictions(batch, self.potential) - # calculate the loss - loss = self.loss(predict_target, batch) - # log the loss + with torch.inference_mode(False): + + # calculate energy and forces + predict_target = self.calculate_predictions( + batch, self.potential, self.training + ) + self._update_metrics(self.val_error, predict_target) - @torch.enable_grad() def test_step(self, batch: "BatchData", batch_idx: int) -> None: """ Test step to compute the RMSE loss for a given batch. @@ -709,7 +763,10 @@ def test_step(self, batch: "BatchData", batch_idx: int) -> None: # Ensure positions require gradients for force calculation batch.nnp_input.positions.requires_grad_(True) # calculate energy and forces - predict_target = self.calculate_predictions(batch, self.potential) + with torch.inference_mode(False): + predict_target = self.calculate_predictions( + batch, self.potential, self.training + ) # Update and log metrics self._update_metrics(self.test_error, predict_target) @@ -763,6 +820,7 @@ def _log_on_epoch(self, log_mode: str = "train"): conv = { "MeanAbsoluteError": "mae", "MeanSquaredError": "rmse", + "MeanMetric": "mse", # NOTE: MeanMetric is the MSE since we accumulate the squared error } # NOTE: MeanSquaredError(squared=False) is RMSE # Log all accumulated metrics for train and val phases @@ -770,6 +828,7 @@ def _log_on_epoch(self, log_mode: str = "train"): errors = [ ("train", self.train_error), ("val", self.val_error), + ("loss", self.loss_metric), ] elif log_mode == "test": errors = [ @@ -783,14 +842,11 @@ def _log_on_epoch(self, log_mode: str = "train"): if phase == "train" and not self.log_on_training_step: continue - metrics = {} for property, metrics_dict in error_dict.items(): for name, metric in metrics_dict.items(): - name = f"{phase}/{property}/{conv[name]}" - metrics[name] = metric.compute() + name = f"{phase}/{property}/{conv.get(name, name)}" + self.log(name, metric.compute(), prog_bar=True, sync_dist=True) metric.reset() - # log dict, print val metrics to console - self.log_dict(metrics, on_epoch=True, prog_bar=(phase == "val")) def configure_optimizers(self): """ @@ -947,8 +1003,8 @@ def setup_datamodule(self) -> DataModule: DataModule Configured DataModule instance. """ - from modelforge.dataset.utils import REGISTERED_SPLITTING_STRATEGIES from modelforge.dataset.dataset import DataModule + from modelforge.dataset.utils import REGISTERED_SPLITTING_STRATEGIES dm = DataModule( name=self.dataset_parameter.dataset_name, @@ -962,6 +1018,7 @@ def setup_datamodule(self) -> DataModule: seed=self.training_parameter.splitting_strategy.seed, split=self.training_parameter.splitting_strategy.data_split, ), + regenerate_processed_cache=self.dataset_parameter.regenerate_processed_cache, ) dm.prepare_data() dm.setup() @@ -1047,8 +1104,8 @@ def setup_callbacks(self) -> List[Any]: List of configured callbacks. """ from lightning.pytorch.callbacks import ( - ModelCheckpoint, EarlyStopping, + ModelCheckpoint, StochasticWeightAveraging, ) @@ -1071,7 +1128,7 @@ def setup_callbacks(self) -> List[Any]: ) checkpoint_callback = ModelCheckpoint( save_top_k=2, - monitor=self.training_parameter.monitor, + monitor=self.training_parameter.monitor_for_checkpoint, filename=checkpoint_filename, ) callbacks.append(checkpoint_callback) @@ -1088,8 +1145,13 @@ def setup_trainer(self) -> Trainer: """ from lightning import Trainer + # if devices is a list + if isinstance(self.runtime_parameter.devices, list): + strategy = "ddp" + trainer = Trainer( max_epochs=self.training_parameter.number_of_epochs, + min_epochs=self.training_parameter.min_number_of_epochs, num_nodes=self.runtime_parameter.number_of_nodes, devices=self.runtime_parameter.devices, accelerator=self.runtime_parameter.accelerator, @@ -1282,9 +1344,9 @@ def read_config( runtime_config_dict[key] = value # Load and instantiate the data classes with the merged configuration - from modelforge.potential import _Implemented_NNP_Parameters from modelforge.dataset.dataset import DatasetParameters - from modelforge.train.parameters import TrainingParameters, RuntimeParameters + from modelforge.potential import _Implemented_NNP_Parameters + from modelforge.train.parameters import RuntimeParameters, TrainingParameters potential_name = potential_config_dict["potential_name"] PotentialParameters = ( @@ -1382,9 +1444,7 @@ def read_config_and_train( log_every_n_steps=log_every_n_steps, simulation_environment=simulation_environment, ) - from modelforge.potential.models import ( - NeuralNetworkPotentialFactory, - ) + from modelforge.potential.models import NeuralNetworkPotentialFactory model = NeuralNetworkPotentialFactory.generate_potential( use="training", diff --git a/modelforge/utils/__init__.py b/modelforge/utils/__init__.py index 4e57d2f1..605aea59 100644 --- a/modelforge/utils/__init__.py +++ b/modelforge/utils/__init__.py @@ -1,3 +1,4 @@ """Module of general modelforge utilities.""" from .prop import SpeciesEnergies, PropertyNames +from .misc import lock_with_attribute diff --git a/modelforge/utils/misc.py b/modelforge/utils/misc.py index f56b2d8d..f4e918f2 100644 --- a/modelforge/utils/misc.py +++ b/modelforge/utils/misc.py @@ -2,15 +2,18 @@ Module of miscellaneous utilities. """ -from typing import Literal +from typing import Literal, TYPE_CHECKING import torch from loguru import logger -from modelforge.dataset.dataset import DataModule + +# import DataModule for typing hint +if TYPE_CHECKING: + from modelforge.dataset.dataset import DataModule def visualize_model( - dm: DataModule, + dm: "DataModule", potential_name: Literal["ANI2x", "PhysNet", "SchNet", "PaiNN", "SAKE"], ): # visualize the compute graph @@ -314,3 +317,58 @@ def __exit__(self, *args): # fcntl.flock(self._file_handle.fileno(), fcntl.LOCK_UN) unlock_file(self._file_handle) self._file_handle.close() + + +import os +from functools import wraps + + +def lock_with_attribute(attribute_name): + """ + Decorator for locking a method using a lock file path stored in an instance + attribute. The attribute is accessed on the instance (`self`) at runtime. + + Parameters + ---------- + attribute_name : str + The name of the instance attribute that contains the lock file path. + + Examples + -------- + >>> from modelforge.utils.misc import lock_with_attribute + >>> + >>> class MyClass: + >>> def __init__(self, lock_file): + >>> self.method_lock = lock_file + >>> + >>> @lock_with_attribute('method_lock') + >>> def critical_section(self): + >>> print("Executing critical section") + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Retrieve the instance (`self`) + instance = args[0] + # Get the lock file path from the specified attribute + lock_file_path = getattr(instance, attribute_name) + with open(lock_file_path, "w+") as f: + # Lock the file + lock_file(f) + + try: + # Execute the wrapped function + result = func(*args, **kwargs) + finally: + # Unlock the file + unlock_file(f) + + # Optionally, remove the lock file + os.remove(lock_file_path) + + return result + + return wrapper + + return decorator diff --git a/scripts/config.toml b/scripts/config.toml index 32e3018b..45e8d64c 100644 --- a/scripts/config.toml +++ b/scripts/config.toml @@ -1,41 +1,39 @@ [potential] -potential_name = "SchNet" +potential_name = "ANI2x" [potential.core_parameter] -number_of_radial_basis_functions = 20 -maximum_interaction_radius = "5.0 angstrom" -number_of_interaction_modules = 3 -number_of_filters = 32 -shared_interactions = false +angle_sections = 4 +maximum_interaction_radius = "5.1 angstrom" +minimum_interaction_radius = "0.8 angstrom" +number_of_radial_basis_functions = 16 +maximum_interaction_radius_for_angular_features = "3.5 angstrom" +minimum_interaction_radius_for_angular_features = "0.8 angstrom" +angular_dist_divisions = 8 [potential.core_parameter.activation_function_parameter] -activation_function_name = "ShiftedSoftplus" +activation_function_name = "CeLU" # for the original ANI behavior please stick with CeLu since the alpha parameter is currently hard coded and might lead to different behavior when another activation function is used. -[potential.core_parameter.featurization] -properties_to_featurize = ['atomic_number'] -maximum_atomic_number = 101 -number_of_per_atom_features = 32 +[potential.core_parameter.activation_function_parameter.activation_function_arguments] +alpha = 0.1 [potential.postprocessing_parameter] [potential.postprocessing_parameter.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true -[potential.postprocessing_parameter.general_postprocessing_operation] -calculate_molecular_self_energy = true [dataset] -dataset_name = "QM9" +dataset_name = "PHALKETHOH" version_select = "nc_1000_v0" num_workers = 4 pin_memory = true [training] -number_of_epochs = 4 +number_of_epochs = 1000 remove_self_energies = true -batch_size = 128 -lr = 1e-3 -monitor = "val/per_molecule_energy/rmse" +batch_size = 16 +lr = 0.5e-3 +monitor_for_checkpoint = "val/per_molecule_energy/rmse" [training.experiment_logger] logger_name = "tensorboard" @@ -62,7 +60,6 @@ loss_property = ['per_molecule_energy', 'per_atom_force'] # use per_molecule_energy = 0.999 #NOTE: reciprocal units per_atom_force = 0.001 - [training.early_stopping] verbose = true monitor = "val/per_molecule_energy/rmse" @@ -75,12 +72,12 @@ data_split = [0.8, 0.1, 0.1] seed = 42 [runtime] -save_dir = "lightning_logs" +save_dir = "test_setup" experiment_name = "{potential_name}_{dataset_name}" local_cache_dir = "./cache" accelerator = "cpu" number_of_nodes = 1 -devices = 1 #[0,1,2,3] +devices = 1 #[0,1,2,3] checkpoint_path = "None" simulation_environment = "PyTorch" log_every_n_steps = 1