Merge pull request IntelLabs#246 from melo-gonzo/experiments-refactor-branch

Experiments Running and Management
Showing 36 changed files with 1,275 additions and 5 deletions.
@@ -0,0 +1,127 @@
# Experiments in Open MatSciML Toolkit

Experimental workflows can be time consuming, repetitive, and complex to set up. Additionally, pytorch-lightning based CLI utilities may not be able to handle specific use cases such as multi-data or multi-task training in matsciml. The experiments module of MatSciML is meant to loosely mirror the functionality of the pytorch lightning CLI while allowing more flexibility in setting up complex experiments. Yaml files define the module parameters, and specific arguments may be changed via the command line if desired. A single command is used to launch training runs, which removes the complexity of writing a new script for each experiment type.

## Experiment Config
The starting point for defining an experiment is the experiment config. This is a yaml file that lays out which model, dataset(s), and task(s) will be used during training. An example config for single-task training (`single_task.yaml`) looks like this:
```yaml
model: egnn_dgl
dataset:
  oqmd:
    scalar_regression:
      - energy
```
In general, an experiment may then be launched by running:

`python experiments/training_script.py --experiment_config ./experiments/single_task.yaml`

* The `model` field points to a specific `model.yaml` file in `./experiments/models`.
* The `dataset` field is a dictionary specifying which datasets to use, as well as which tasks are associated with the parent dataset.
* Tasks are referred to by their class name:
```python
ScalarRegressionTask
ForceRegressionTask
BinaryClassificationTask
CrystalSymmetryClassificationTask
GradFreeForceRegressionTask
```
* A dataset may contain more than one task (single data, multi task learning).
* Multiple datasets may be provided, each containing their own tasks (multi data, multi task learning); see the sketch after this list.
* For a list of available datasets, tasks, and models, run `python training_script.py --options`.
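
As a sketch, a multi-data, multi-task experiment config combines several datasets, each carrying its own task list. The example below is transcribed from the `multi_data` test configuration further down in this diff; note that it spells tasks out with `task`/`targets` entries (matching the experiment yaml at the bottom of this diff), which differs slightly from the shorthand used in `single_task.yaml` above.
```yaml
model: faenet_pyg
dataset:
  oqmd:
    - task: ScalarRegressionTask
      targets:
        - energy
  is2re:
    - task: ScalarRegressionTask
      targets:
        - energy_init
        - energy_relaxed
```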

## Trainer Config
The trainer config contains a few sections used to define how experiments will be run. The debug tag sets parameters that should be used when debugging an experimental setup, or when working through bugs in a new model or dataset; these are helpful for quick end-to-end runs to make sure the pipeline is functional. The experiment tag defines parameters for full experiment runs. Finally, the generic tag defines parameters that are used regardless of whether a debug or full run is launched.

In addition to the run-type sections, any other parameters to be used with the pytorch lightning `Trainer` should be added here. In the example `trainer.yaml`, there are callbacks and a logger. These objects are set up by adding their `class_path` as well as any `init_args` they expect.
```yaml
generic:
  min_epochs: 15
  max_epochs: 100
debug:
  accelerator: cpu
  limit_train_batches: 10
  limit_val_batches: 10
  log_every_n_steps: 1
  max_epochs: 2
experiment:
  accelerator: gpu
  strategy: ddp_find_unused_parameters_true
callbacks:
  - class_path: pytorch_lightning.callbacks.EarlyStopping
    init_args:
      patience: 5
      monitor: val_energy
      mode: min
      verbose: True
      check_finite: False
  - class_path: pytorch_lightning.callbacks.ModelCheckpoint
    init_args:
      monitor: val_energy
      save_top_k: 3
  - class_path: matsciml.lightning.callbacks.GradientCheckCallback
  - class_path: matsciml.lightning.callbacks.SAM
loggers:
  - class_path: pytorch_lightning.loggers.CSVLogger # can omit init_args['save_dir'] for auto directory
```
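
For illustration, assuming the chosen run type's section is merged over `generic` in the same dict-update fashion the dataset configs use (see `data_module_config.py` further down in this diff), a `--debug` run with the file above would resolve to roughly the following `Trainer` arguments (the callback and logger entries, which are defined by `class_path` specs, are set up separately):
```yaml
# hypothetical resolved settings for a --debug run: debug values win over generic ones
min_epochs: 15
max_epochs: 2
accelerator: cpu
limit_train_batches: 10
limit_val_batches: 10
log_every_n_steps: 1
```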

## Dataset Config
Similar to the trainer config, the dataset config has sections for debug and full experiments. Dataset paths, batch size, number of workers, seed, and other relevant arguments may be set here, and the available target keys for training are listed. Other settings such as `normalize_kwargs` and `task_loss_scaling` may be set under the `task_args` tag.
```yaml
dataset: CMDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - energy
  - symmetry_number
  - symmetry_symbol
task_args:
  normalize_kwargs:
    energy_mean: 1
    energy_std: 1
  task_loss_scaling:
    energy: 1
```
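
A minimal sketch of how these sections are consumed, following the logic of `data_module_config.py` further down in this diff: the generic experiment defaults are copied first and then overridden by the chosen run type's section, and debug runs build the datamodule from the bundled devset, so the empty paths above are never touched. The `"carolina"` key below is an assumption standing in for this yaml file's stem (file names are not shown in this diff), and transform handling is omitted for brevity.
```python
from copy import deepcopy

from matsciml.lightning.data_utils import MatSciMLDataModule

from experiments.datasets import available_data  # generic defaults plus every dataset yaml

# "carolina" is a hypothetical key; available_data is keyed by each yaml file's stem
dset = deepcopy(available_data["carolina"])

# start from the shared experiment defaults, then let the run type's section win
dm_kwargs = deepcopy(available_data["generic"]["experiment"])
dm_kwargs.update(dset["debug"])  # -> {'batch_size': 4, 'num_workers': 0}

# debug runs pull the packaged devset, so train/val/test paths are not required
dm = MatSciMLDataModule.from_devset(dataset=dset["dataset"], **dm_kwargs)
```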

## Model Config
Models available in matsciml may be DGL, PyG, or point cloud based. Each model is named with its supported backend, as models may come in more than one variety. In some instances, similar to the `trainer.yaml` config, a `class_path` and `init_args` need to be specified. Additionally, modules may need to be specified without initialization, which may be done by using the `class_instance` tag. Finally, all transforms that a model should use should be included in the model config.
```yaml
encoder_class:
  class_path: matsciml.models.TensorNet
encoder_kwargs:
  element_types:
    class_path: matsciml.datasets.utils.element_types
  num_rbf: 32
  max_n: 3
  max_l: 3
output_kwargs:
  lazy: False
  input_dim: 64
  hidden_dim: 64
transforms:
  - class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform
    init_args:
      cutoff_radius: 6.5
      adaptive_cutoff: True
  - class_path: matsciml.datasets.transforms.PointCloudToGraphTransform
    init_args:
      backend: dgl
      cutoff_dist: 20.0
      node_keys:
        - "pos"
        - "atomic_numbers"
```
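
The `class_path`/`init_args` entries are resolved by `instantiate_arg_dict` in `experiments/utils/utils.py`, whose implementation is not included in this excerpt. As a rough illustration only (not the toolkit's actual code), resolving such an entry amounts to importing the named class and calling it with its arguments, recursing into nested specs; a `class_instance` entry would instead return the imported attribute without calling it.
```python
import importlib
from typing import Any


def build_from_spec(spec: Any) -> Any:
    """Illustrative helper: turn {'class_path': ..., 'init_args': {...}} entries into objects."""
    if isinstance(spec, dict) and "class_path" in spec:
        module_name, _, attr_name = spec["class_path"].rpartition(".")
        attr = getattr(importlib.import_module(module_name), attr_name)
        init_args = {k: build_from_spec(v) for k, v in spec.get("init_args", {}).items()}
        return attr(**init_args)
    if isinstance(spec, dict):
        return {k: build_from_spec(v) for k, v in spec.items()}
    if isinstance(spec, list):
        return [build_from_spec(v) for v in spec]
    return spec


# e.g. the first transform entry above becomes a live transform object:
transform = build_from_spec(
    {
        "class_path": "matsciml.datasets.transforms.PeriodicPropertiesTransform",
        "init_args": {"cutoff_radius": 6.5, "adaptive_cutoff": True},
    }
)
```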

## CLI Parameter Updates
Certain parameters may be updated from the command line. The `-c, --cli_args` argument may be used, and the parameter must be specified as `[config].parameter.value`, where the config may be `trainer`, `model`, or `dataset`. For example, to update the batch size for a debug run:

`python training_script.py --debug --cli_args dataset.debug.batch_size.16`

Only arguments that follow a dictionary mapping of `str`, `int`, or `float` values all the way down to the target value may be updated this way. Parameters that map to lists at any point, such as callbacks, loggers, and transforms, are not updatable through `cli_args`.
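
The actual parsing is done by `update_arg_dict` in `experiments/utils/utils.py`, which is not shown in this excerpt. The hypothetical helper below only illustrates why the dotted form works for nested dictionaries and not for list-valued entries:
```python
from typing import Any


def apply_cli_override(configs: dict[str, Any], dotted: str) -> None:
    """Illustration only: apply one '[config].parameter.value' override in place."""
    *keys, raw_value = dotted.split(".")
    # best-effort conversion of the trailing value: int, then float, otherwise keep the string
    for cast in (int, float):
        try:
            value: Any = cast(raw_value)
            break
        except ValueError:
            continue
    else:
        value = raw_value
    node = configs
    for key in keys[:-1]:
        node = node[key]  # every intermediate step must be a dict, hence the list limitation
    node[keys[-1]] = value


# apply_cli_override(configs, "dataset.debug.batch_size.16")
# -> configs["dataset"]["debug"]["batch_size"] == 16
```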
Empty file.
@@ -0,0 +1,21 @@
import yaml

from pathlib import Path

yaml_dir = Path(__file__).parent


# Generic defaults shared by every dataset; per-run-type values are merged on top of these.
available_data = {
    "generic": {
        "experiment": {"batch_size": 32, "num_workers": 16},
        "debug": {"batch_size": 4, "num_workers": 0},
    },
}


# Register every dataset yaml in this directory, keyed by its file stem.
for filename in yaml_dir.rglob("*.yaml"):
    file_path = yaml_dir.joinpath(filename)
    with open(file_path, "r") as file:
        content = yaml.safe_load(file)
    file_key = file_path.stem
    available_data[file_key] = content
@@ -0,0 +1,18 @@
dataset: CMDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - energy
  - symmetry_number
  - symmetry_symbol
task_args:
  normalize_kwargs:
    energy_mean: 1
    energy_std: 1
  task_loss_scaling:
    energy: 1
@@ -0,0 +1,86 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any
import sys

import pytorch_lightning as pl

from matsciml.lightning.data_utils import (
    MatSciMLDataModule,
)
from matsciml.datasets import *  # noqa: F401  (provides MultiDataset and the dataset classes looked up below)

from matsciml.lightning.data_utils import MultiDataModule

from experiments.datasets import available_data
from experiments.models import available_models
from experiments.utils.utils import instantiate_arg_dict, update_arg_dict


def setup_datamodule(config: dict[str, Any]) -> pl.LightningDataModule:
    model = config["model"]
    data_task_dict = config["dataset"]
    run_type = config["run_type"]
    model = instantiate_arg_dict(deepcopy(available_models[model]))
    model = update_arg_dict("model", model, config["cli_args"])
    datasets = list(data_task_dict.keys())
    if len(datasets) == 1:
        # Single dataset: build a MatSciMLDataModule, starting from the generic
        # experiment defaults and overriding with the selected run type's values.
        dset = deepcopy(available_data[datasets[0]])
        dset = update_arg_dict("dataset", dset, config["cli_args"])
        dm_kwargs = deepcopy(available_data["generic"]["experiment"])
        dm_kwargs.update(dset[run_type])
        if run_type == "debug":
            dm = MatSciMLDataModule.from_devset(
                dataset=dset["dataset"],
                dset_kwargs={"transforms": model["transforms"]},
                **dm_kwargs,
            )
        else:
            dm = MatSciMLDataModule(
                dataset=dset["dataset"],
                dset_kwargs={"transforms": model["transforms"]},
                **dm_kwargs,
            )
    else:
        # Multiple datasets: build each split's dataset list and wrap them in a MultiDataModule.
        dset_list = {"train": [], "val": [], "test": []}
        for dataset in datasets:
            dset = deepcopy(available_data[dataset])
            dset = update_arg_dict("dataset", dset, config["cli_args"])
            dm_kwargs = deepcopy(available_data["generic"]["experiment"])
            dset[run_type].pop("normalize_kwargs", None)
            dm_kwargs.update(dset[run_type])
            dataset_name = dset["dataset"]
            dataset = getattr(sys.modules[__name__], dataset_name)
            model_transforms = model["transforms"]
            if run_type == "debug":
                dset_list["train"].append(
                    dataset.from_devset(transforms=model_transforms)
                )
                dset_list["val"].append(
                    dataset.from_devset(transforms=model_transforms)
                )
                dset_list["test"].append(
                    dataset.from_devset(transforms=model_transforms)
                )
            else:
                if "train_path" in dm_kwargs:
                    dset_list["train"].append(
                        dataset(dm_kwargs["train_path"], transforms=model_transforms)
                    )
                if "val_split" in dm_kwargs:
                    dset_list["val"].append(
                        dataset(dm_kwargs["val_split"], transforms=model_transforms)
                    )
                if "test_split" in dm_kwargs:
                    dset_list["test"].append(
                        dataset(dm_kwargs["test_split"], transforms=model_transforms)
                    )
        dm = MultiDataModule(
            train_dataset=MultiDataset(dset_list["train"]),
            val_dataset=MultiDataset(dset_list["val"]),
            test_dataset=MultiDataset(dset_list["test"]),
            batch_size=dm_kwargs["batch_size"],
            num_workers=dm_kwargs["num_workers"],
        )
    return dm
@@ -0,0 +1,10 @@
dataset: IS2REDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  train_path: ''
  val_split: ''
target_keys:
  - energy_init
  - energy_relaxed
@@ -0,0 +1,11 @@
dataset: LiPSDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - energy
  - force
@@ -0,0 +1,20 @@
dataset: MaterialsProjectDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ""
  train_path: ""
  val_split: ""
target_keys:
  - is_magnetic
  - is_metal
  - is_stable
  - band_gap
  - efermi
  - energy_per_atom
  - formation_energy_per_atom
  - uncorrected_energy_per_atom
  - symmetry_number
  - symmetry_symbol
  - symmetry_group
@@ -0,0 +1,15 @@
dataset: NomadDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - spin_polarized
  - efermi
  - relative_energy
  - symmetry_number
  - symmetry_symbol
  - symmetry_group
@@ -0,0 +1,12 @@
dataset: OQMDDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - band_gap
  - energy
  - stability
@@ -0,0 +1,10 @@
dataset: S2EFDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  train_path: ''
  val_split: ''
target_keys:
  - energy
  - force
@@ -0,0 +1,41 @@
from __future__ import annotations


import pytest

import matsciml
import matsciml.datasets.transforms  # noqa: F401
from experiments.datasets.data_module_config import setup_datamodule


single_task = {
    "model": "egnn_dgl",
    "dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["band_gap"]}]},
}
multi_task = {
    "dataset": {
        "s2ef": [
            {"task": "ScalarRegressionTask", "targets": ["energy"]},
            {"task": "ForceRegressionTask", "targets": ["force"]},
        ]
    }
}
multi_data = {
    "model": "faenet_pyg",
    "dataset": {
        "oqmd": [{"task": "ScalarRegressionTask", "targets": ["energy"]}],
        "is2re": [
            {
                "task": "ScalarRegressionTask",
                "targets": ["energy_init", "energy_relaxed"],
            }
        ],
    },
}


@pytest.mark.parametrize("task_dict", [single_task, multi_task, multi_data])
def test_task_setup(task_dict):
    # run everything against the debug devsets; the "model" entry here overrides the models named above
    other_args = {"run_type": "debug", "model": "m3gnet_dgl", "cli_args": None}
    task_dict.update(other_args)
    setup_datamodule(config=task_dict)
@@ -0,0 +1,6 @@
model: faenet_pyg
dataset:
  oqmd:
    - task: ScalarRegressionTask
      targets:
        - energy