Skip to content

Commit

Permalink
Merging ModelTrainer and TrainingAdapter class (#225)
Browse files Browse the repository at this point in the history
* add tags to wandb logger
* replace experimental_name placeholder
* combine Trainer and TrainerAdaptor class
* update tests, names and class logic
* update docstrings, removed unused methods
* change `model_name` to `potential_name`
* Fixed minor naming error; replaced Fire automatic command line generation with manually set up command line interface using argparse.
* changing model_seed to potential_seed; added docstrings
* setting random number seeds
* If neither runtime parameters nor simulation environment is provided, default to PyTorch
* Added potential_seed to potential parameters

---------

Co-authored-by: chrisiacovella <[email protected]>
  • Loading branch information
wiederm and chrisiacovella authored Aug 13, 2024
1 parent 620b509 commit 951a53b
Show file tree
Hide file tree
Showing 19 changed files with 944 additions and 634 deletions.
4 changes: 4 additions & 0 deletions modelforge/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,8 @@ class ANI2x(BaseNetwork):
Configuration for postprocessing parameters.
dataset_statistic : Optional[Dict[str, float]], optional
Statistics of the dataset, by default None.
potential_seed : Optional[int], optional
Seed for the random number generator, default None.
"""

def __init__(
Expand All @@ -664,6 +666,7 @@ def __init__(
activation_function_parameter: Dict,
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic: Optional[Dict[str, float]] = None,
potential_seed: Optional[int] = None,
) -> None:

from modelforge.utils.units import _convert_str_to_unit
Expand All @@ -674,6 +677,7 @@ def __init__(
dataset_statistic=dataset_statistic,
postprocessing_parameter=postprocessing_parameter,
maximum_interaction_radius=_convert_str_to_unit(maximum_interaction_radius),
potential_seed=potential_seed,
)

activation_function = activation_function_parameter["activation_function"]
Expand Down
102 changes: 81 additions & 21 deletions modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,19 +540,41 @@ def apply_bwd(res, grads):
return apply, model_params, model_buffer


from modelforge.potential.parameters import (
ANI2xParameters,
PhysNetParameters,
SchNetParameters,
PaiNNParameters,
SAKEParameters,
TensorNetParameters,
)
from modelforge.train.parameters import TrainingParameters, RuntimeParameters
from modelforge.dataset.dataset import DatasetParameters


class NeuralNetworkPotentialFactory:
"""
Factory class for creating instances of neural network potentials for training/inference.
"""

@staticmethod
def generate_model(
def generate_potential(
*,
use: Literal["training", "inference"],
model_parameter: Dict[str, Union[str, Any]],
simulation_environment: Literal["PyTorch", "JAX"] = "PyTorch",
training_parameter: Optional[Dict[str, Any]] = None,
potential_parameter: Union[
ANI2xParameters,
SAKEParameters,
SchNetParameters,
PhysNetParameters,
PaiNNParameters,
TensorNetParameters,
],
runtime_parameter: Optional[RuntimeParameters] = None,
training_parameter: Optional[TrainingParameters] = None,
dataset_parameter: Optional[DatasetParameters] = None,
dataset_statistic: Optional[Dict[str, float]] = None,
potential_seed: Optional[int] = None,
simulation_environment: Optional[Literal["PyTorch", "JAX"]] = None,
) -> Union[Type[torch.nn.Module], Type[JAXModel], Type[pl.LightningModule]]:
"""
Creates an NNP instance of the specified type, configured either for training or inference.
Expand All @@ -561,14 +583,21 @@ def generate_model(
----------
use : Literal["training", "inference"]
The use case for the model instance, either 'training' or 'inference'.
simulation_environment : Literal["PyTorch", "JAX"]
The ML framework to use, either 'PyTorch' or 'JAX'.
model_parameter : Dict[str, Union[str, Any]]
Parameters specific to the model.
training_parameter : Optional[Dict[str, Any]], optional
potential_parameter : Union[ANI2xParameters, SAKEParameters, SchNetParameters, PhysNetParameters, PaiNNParameters, TensorNetParameters]
Parameters specific to the potential.
runtime_parameter : Optional[RuntimeParameters], optional
Parameters for configuring the runtime environment.
training_parameter : Optional[TrainingParameters], optional
Parameters for configuring the training.
dataset_parameter : Optional[DatasetParameters], optional
Parameters for configuring the dataset.
dataset_statistic : Optional[Dict[str, float]], optional
Statistics of the dataset for normalization purposes.
potential_seed : Optional[int], optional
Seed for the random number generator.
simulation_environment : Optional[Literal["PyTorch", "JAX"]], optional, None
The simulation environment to use for training/inference. Will override the runtime parameter if provided.
Returns
-------
Expand All @@ -584,33 +613,44 @@ def generate_model(
"""

from modelforge.potential import _Implemented_NNPs
from modelforge.train.training import TrainingAdapter
from modelforge.train.training import ModelTrainer

log.debug(f"{training_parameter=}")
log.debug(f"{model_parameter=}")
log.debug(f"{potential_parameter=}")
log.debug(f"{dataset_parameter=}")

if simulation_environment is None:
if runtime_parameter is None:
log.warning(
"No runtime paramters or simulation_environment specified, defaulting to PyTorch"
)

simulation_environment = "PyTorch"
else:
simulation_environment = runtime_parameter.simulation_environment
# obtain model for training
if use == "training":
if simulation_environment == "JAX":
log.warning(
"Training in JAX is not available. Falling back to PyTorch."
)
model = TrainingAdapter(
model_parameter=model_parameter,
lr_scheduler=training_parameter["lr_scheduler"],
lr=training_parameter["lr"],
loss_parameter=training_parameter["loss_parameter"],
dataset_statistic=dataset_statistic,
model = ModelTrainer(
potential_parameter=potential_parameter,
training_parameter=training_parameter,
dataset_parameter=dataset_parameter,
runtime_parameter=runtime_parameter,
potential_seed=potential_seed,
)
return model
# obtain model for inference
elif use == "inference":
model_type = model_parameter["potential_name"]
model_type = potential_parameter.potential_name
nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_type)
model = nnp_class(
**model_parameter["core_parameter"],
postprocessing_parameter=model_parameter["postprocessing_parameter"],
**potential_parameter.core_parameter.model_dump(),
postprocessing_parameter=potential_parameter.postprocessing_parameter.model_dump(),
dataset_statistic=dataset_statistic,
potential_seed=potential_seed,
)
if simulation_environment == "JAX":
return PyTorch2JAXConverter().convert_to_jax_model(model)
Expand Down Expand Up @@ -954,6 +994,7 @@ def __init__(
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic: Optional[Dict[str, float]],
maximum_interaction_radius: unit.Quantity,
potential_seed: Optional[int] = None,
):
"""
Initialize the BaseNetwork.
Expand All @@ -966,11 +1007,30 @@ def __init__(
Dataset statistics for normalization.
maximum_interaction_radius : unit.Quantity
cutoff radius.
potential_seed : Optional[int], optional
Value used for torch.manual_seed, by default None.
"""

super().__init__()
from modelforge.utils.units import _convert_str_to_unit

if potential_seed:
import torch

torch.manual_seed(potential_seed)

# according to https://docs.ray.io/en/latest/tune/faq.html#how-can-i-reproduce-experiments
# we should also set the same seed for numpy.random and python random module
# when doing hyperparameter optimization with ray tune. E.g., the ASHA scheduler relies on numpy.random
# and doesn't take a seed as an argument, so we need to set it here.
import numpy as np

np.random.seed(potential_seed)

import random

random.seed(potential_seed)

self.postprocessing = PostProcessing(
postprocessing_parameter, dataset_statistic
)
Expand Down Expand Up @@ -1012,7 +1072,7 @@ def load_state_dict(
"""

# Prefix to remove
prefix = "model."
prefix = "potential."
excluded_keys = ["loss.per_molecule_energy", "loss.per_atom_force"]

# Create a new dictionary without the prefix in the keys if prefix exists
Expand Down
5 changes: 4 additions & 1 deletion modelforge/potential/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def __init__(
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic: Optional[Dict[str, float]] = None,
epsilon: float = 1e-8,
potential_seed: Optional[int] = None,
) -> None:
"""
Initialize the PaiNN network.
Expand All @@ -582,7 +583,8 @@ def __init__(
shared_filters : bool
Whether to share filters across modules.
epsilon=epsilon,
)
potential_seed : Optional[int], optional
Seed for the random number generator, default None.
"""

from modelforge.utils.units import _convert_str_to_unit
Expand All @@ -593,6 +595,7 @@ def __init__(
dataset_statistic=dataset_statistic,
postprocessing_parameter=postprocessing_parameter,
maximum_interaction_radius=_convert_str_to_unit(maximum_interaction_radius),
potential_seed=potential_seed,
)

activation_function = activation_function_parameter["activation_function"]
Expand Down
6 changes: 6 additions & 0 deletions modelforge/potential/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class PostProcessingParameter(ParametersBase):
potential_name: str = "ANI2x"
core_parameter: CoreParameter
postprocessing_parameter: PostProcessingParameter
potential_seed: Optional[int] = None


class SchNetParameters(ParametersBase):
Expand Down Expand Up @@ -183,6 +184,7 @@ class PostProcessingParameter(ParametersBase):
potential_name: str = "SchNet"
core_parameter: CoreParameter
postprocessing_parameter: PostProcessingParameter
potential_seed: Optional[int] = None


class TensorNetParameters(ParametersBase):
Expand Down Expand Up @@ -214,6 +216,7 @@ class PostProcessingParameter(ParametersBase):
potential_name: str = "TensorNet"
core_parameter: CoreParameter
postprocessing_parameter: PostProcessingParameter
potential_seed: Optional[int] = None


class PaiNNParameters(ParametersBase):
Expand Down Expand Up @@ -244,6 +247,7 @@ class PostProcessingParameter(ParametersBase):
potential_name: str = "PaiNN"
core_parameter: CoreParameter
postprocessing_parameter: PostProcessingParameter
potential_seed: Optional[int] = None


class PhysNetParameters(ParametersBase):
Expand Down Expand Up @@ -273,6 +277,7 @@ class PostProcessingParameter(ParametersBase):
potential_name: str = "PhysNet"
core_parameter: CoreParameter
postprocessing_parameter: PostProcessingParameter
potential_seed: Optional[int] = None


class SAKEParameters(ParametersBase):
Expand Down Expand Up @@ -302,3 +307,4 @@ class PostProcessingParameter(ParametersBase):
potential_name: str = "SAKE"
core_parameter: CoreParameter
postprocessing_parameter: PostProcessingParameter
potential_seed: Optional[int] = None
4 changes: 4 additions & 0 deletions modelforge/potential/physnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,8 @@ class PhysNet(BaseNetwork):
Configuration for postprocessing parameters.
dataset_statistic : Optional[Dict[str, float]], optional
Statistics of the dataset, by default None.
potential_seed : Optional[int], optional
Seed for the random number generator, default None.
"""

def __init__(
Expand All @@ -687,13 +689,15 @@ def __init__(
activation_function_parameter: Dict,
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic: Optional[Dict[str, float]] = None,
potential_seed: Optional[int] = None,
) -> None:

self.only_unique_pairs = False # NOTE: for pairlist
super().__init__(
dataset_statistic=dataset_statistic,
postprocessing_parameter=postprocessing_parameter,
maximum_interaction_radius=_convert_str_to_unit(maximum_interaction_radius),
potential_seed=potential_seed,
)
activation_function = activation_function_parameter["activation_function"]

Expand Down
2 changes: 2 additions & 0 deletions modelforge/potential/sake.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def __init__(
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic: Optional[Dict[str, float]] = None,
epsilon: float = 1e-8,
potential_seed: Optional[int] = None,
):
from modelforge.utils.units import _convert_str_to_unit

Expand All @@ -610,6 +611,7 @@ def __init__(
dataset_statistic=dataset_statistic,
postprocessing_parameter=postprocessing_parameter,
maximum_interaction_radius=_convert_str_to_unit(maximum_interaction_radius),
potential_seed=potential_seed,
)
activation_function = activation_function_parameter["activation_function"]

Expand Down
4 changes: 4 additions & 0 deletions modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@ class SchNet(BaseNetwork):
Configuration for postprocessing parameters.
dataset_statistic : Optional[Dict[str, float]], default=None
Statistics of the dataset.
potential_seed : Optional[int], optional
Seed for the random number generator, default None.
"""

def __init__(
Expand All @@ -441,6 +443,7 @@ def __init__(
shared_interactions: bool,
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic: Optional[Dict[str, float]] = None,
potential_seed: Optional[int] = None,
) -> None:

self.only_unique_pairs = False # NOTE: need to be set before super().__init__
Expand All @@ -449,6 +452,7 @@ def __init__(
dataset_statistic=dataset_statistic,
postprocessing_parameter=postprocessing_parameter,
maximum_interaction_radius=_convert_str_to_unit(maximum_interaction_radius),
potential_seed=potential_seed,
)

activation_function = activation_function_parameter["activation_function"]
Expand Down
4 changes: 4 additions & 0 deletions modelforge/potential/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ class TensorNet(BaseNetwork):
Postprocessing parameters.
dataset_statistic : Optional[Dict[str, float]]
Dataset statistics.
potential_seed : Optional[int], optional
Seed for the random number generator, default None.
"""

def __init__(
Expand All @@ -217,6 +219,7 @@ def __init__(
activation_function_parameter: Dict,
postprocessing_parameter: Dict[str, Dict[str, bool]],
dataset_statistic: Optional[Dict[str, float]] = None,
potential_seed: Optional[int] = None,
) -> None:

activation_function = activation_function_parameter["activation_function"]
Expand All @@ -226,6 +229,7 @@ def __init__(
dataset_statistic=dataset_statistic,
postprocessing_parameter=postprocessing_parameter,
maximum_interaction_radius=_convert_str_to_unit(maximum_interaction_radius),
potential_seed=potential_seed,
)

self.core_module = TensorNetCore(
Expand Down
4 changes: 2 additions & 2 deletions modelforge/tests/data/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ seed = 42

[runtime]
save_dir = "lightning_logs"
experiment_name = "test_exp"
experiment_name = "{potential_name}_{dataset_name}"
local_cache_dir = "./cache"
accelerator = "cpu"
number_of_nodes = 1
devices = 1 #[0,1,2,3]
checkpoint_path = "None"
simulation_environment = "PyTorch"
log_every_n_steps = 50
log_every_n_steps = 1
2 changes: 1 addition & 1 deletion modelforge/tests/data/runtime_defaults/runtime.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[runtime]
save_dir = "lightning_logs"
experiment_name = "test_exp2"
experiment_name = "{potential_name}_{dataset_name}"
local_cache_dir = "./cache"
accelerator = "cpu"
number_of_nodes = 1
Expand Down
Loading

0 comments on commit 951a53b

Please sign in to comment.