Commit

Merge pull request IntelLabs#246 from melo-gonzo/experiments-refactor…
…-branch

Experiments Running and Management
laserkelvin authored Jul 15, 2024
2 parents 0f58001 + cb51078 commit 651a2b5
Showing 36 changed files with 1,275 additions and 5 deletions.
127 changes: 127 additions & 0 deletions experiments/README.md
@@ -0,0 +1,127 @@
# Experiments in Open MatSciML Toolkit

Experimental workflows can be time-consuming, repetitive, and complex to set up. Additionally, pytorch-lightning based CLI utilities may not be able to handle specific use cases such as multi-data or multi-task training in matsciml. The experiments module of MatSciML is meant to loosely mirror the functionality of the pytorch lightning CLI while allowing more flexibility in setting up complex experiments. YAML files define the module parameters, and specific arguments may be changed via the command line if desired. A single command launches a training run, which removes the need to write a new script for each experiment type.

## Experiment Config
The starting point for defining an experiment is the experiment config. This is a YAML file that lays out which model, dataset(s), and task(s) will be used during training. An example config for single task training (`single_task.yaml`) looks like this:
```yaml
model: egnn_dgl
dataset:
  oqmd:
    scalar_regression:
      - energy
```
In general, an experiment may then be launched by running:
`python experiments/training_script.py --experiment_config ./experiments/single_task.yaml`


* The `model` field points to a specific `model.yaml` file in `./experiments/models`.
* The `dataset` field is a dictionary specifying which datasets to use, as well as which tasks are associated with the parent dataset.
* Tasks are referred to by their class name:
```python
ScalarRegressionTask
ForceRegressionTask
BinaryClassificationTask
CrystalSymmetryClassificationTask
GradFreeForceRegressionTask
```
* A dataset may contain more than one task (single data, multi task learning)
* Multiple datasets can be provided, each containing their own tasks (multi data, multi task learning)
* For a list of available datasets, tasks, and models run `python training_script.py --options`.
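
The single-task, multi-task, and multi-data cases listed above can be told apart from the shape of the `dataset` field alone. The following is a minimal sketch, assuming the config has already been loaded into a dictionary in the `task`/`targets` list form used by `experiments/experiment_config.yaml` and the tests in this commit. `describe_experiment` is a hypothetical helper for illustration and is not part of matsciml.

```python
from typing import Any


def describe_experiment(config: dict[str, Any]) -> str:
    """Classify a loaded experiment config by its dataset/task layout.

    Hypothetical helper for illustration; not part of matsciml.
    """
    datasets = config["dataset"]
    num_datasets = len(datasets)
    # Each dataset maps to a list of {"task": ..., "targets": [...]} entries.
    num_tasks = sum(len(tasks) for tasks in datasets.values())
    data_part = "single data" if num_datasets == 1 else "multi data"
    task_part = "single task" if num_tasks == 1 else "multi task"
    return f"{data_part}, {task_part}"


single_task = {
    "model": "egnn_dgl",
    "dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["energy"]}]},
}
multi_task = {
    "model": "egnn_dgl",
    "dataset": {
        "s2ef": [
            {"task": "ScalarRegressionTask", "targets": ["energy"]},
            {"task": "ForceRegressionTask", "targets": ["force"]},
        ]
    },
}
multi_data = {
    "model": "faenet_pyg",
    "dataset": {
        "oqmd": [{"task": "ScalarRegressionTask", "targets": ["energy"]}],
        "is2re": [{"task": "ScalarRegressionTask", "targets": ["energy_init"]}],
    },
}

for cfg in (single_task, multi_task, multi_data):
    print(describe_experiment(cfg))
```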

## Trainer Config
The trainer config contains a few sections that define how experiments will be run. The `debug` tag sets parameters to use when debugging an experimental setup, or when working through bugs in setting up a new model or dataset. These parameters are helpful for quick end-to-end runs that check the pipeline is functional. The `experiment` tag defines parameters for full experiment runs. Finally, the `generic` tag defines parameters that are used regardless of whether the run is a debug or a full experiment run.

In addition to the experiment types, any other parameters to be used with the pytorch lightning `Trainer` should be added here. In the example `trainer.yaml`, there are callbacks and a logger. These objects are set up by adding their `class_path` as well as any `init_args` they expect.
```yaml
generic:
  min_epochs: 15
  max_epochs: 100
debug:
  accelerator: cpu
  limit_train_batches: 10
  limit_val_batches: 10
  log_every_n_steps: 1
  max_epochs: 2
experiment:
  accelerator: gpu
  strategy: ddp_find_unused_parameters_true
callbacks:
  - class_path: pytorch_lightning.callbacks.EarlyStopping
    init_args:
      patience: 5
      monitor: val_energy
      mode: min
      verbose: True
      check_finite: False
  - class_path: pytorch_lightning.callbacks.ModelCheckpoint
    init_args:
      monitor: val_energy
      save_top_k: 3
  - class_path: matsciml.lightning.callbacks.GradientCheckCallback
  - class_path: matsciml.lightning.callbacks.SAM
loggers:
  - class_path: pytorch_lightning.loggers.CSVLogger # can omit init_args['save_dir'] for auto directory
```
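
One way to read this layout is as two steps: merge `generic` with the section for the chosen run type, then build each callback and logger from its `class_path` and `init_args`. The sketch below illustrates that reading under stated assumptions and is not the toolkit's implementation: `instantiate` and `build_trainer_kwargs` are hypothetical names, run-type values are assumed to override `generic` values, and standard-library classes stand in for the Lightning callbacks and logger so the example runs without extra dependencies.

```python
import importlib
from copy import deepcopy
from typing import Any


def instantiate(spec: dict[str, Any]) -> Any:
    """Import `class_path` and call it with `init_args` (hypothetical helper)."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


def build_trainer_kwargs(trainer_cfg: dict[str, Any], run_type: str) -> dict[str, Any]:
    """Merge `generic` with the run-type section, then build listed objects."""
    kwargs = deepcopy(trainer_cfg.get("generic", {}))
    # Assumption: values from `debug` / `experiment` win over `generic`.
    kwargs.update(trainer_cfg.get(run_type, {}))
    kwargs["callbacks"] = [instantiate(s) for s in trainer_cfg.get("callbacks", [])]
    kwargs["logger"] = [instantiate(s) for s in trainer_cfg.get("loggers", [])]
    return kwargs


# Stand-in classes from the standard library replace the Lightning objects.
trainer_cfg = {
    "generic": {"min_epochs": 15, "max_epochs": 100},
    "debug": {"accelerator": "cpu", "max_epochs": 2},
    "experiment": {"accelerator": "gpu"},
    "callbacks": [
        {"class_path": "datetime.timedelta", "init_args": {"minutes": 5}},
    ],
    "loggers": [
        {"class_path": "logging.Logger", "init_args": {"name": "csv"}},
    ],
}

debug_kwargs = build_trainer_kwargs(trainer_cfg, "debug")
full_kwargs = build_trainer_kwargs(trainer_cfg, "experiment")
print(debug_kwargs["max_epochs"], full_kwargs["max_epochs"])
```

With real entries, the resulting dictionary would be the kind of thing passed to the Lightning `Trainer` as keyword arguments.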



## Dataset Config
Similar to the trainer config, the dataset config has sections for debug and full experiments. Dataset paths, batch size, number of workers, seed, and other relevant arguments may be set here. The available target keys for training are listed under `target_keys`. Other settings such as `normalize_kwargs` and `task_loss_scaling` may be set under the `task_args` tag.
```yaml
dataset: CMDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - energy
  - symmetry_number
  - symmetry_symbol
task_args:
  normalize_kwargs:
    energy_mean: 1
    energy_std: 1
  task_loss_scaling:
    energy: 1
```
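
In `experiments/datasets/data_module_config.py` from this commit, the data module arguments start from the `generic` experiment defaults and are then updated with the dataset's section for the current run type. The sketch below reproduces only that merge on plain dictionaries; `datamodule_kwargs` is a hypothetical name, and the default values are copied from `experiments/datasets/__init__.py`.

```python
from copy import deepcopy
from typing import Any

# Defaults copied from `available_data["generic"]` in experiments/datasets/__init__.py.
GENERIC = {
    "experiment": {"batch_size": 32, "num_workers": 16},
    "debug": {"batch_size": 4, "num_workers": 0},
}


def datamodule_kwargs(dataset_cfg: dict[str, Any], run_type: str) -> dict[str, Any]:
    """Overlay the dataset's run-type section on the generic experiment defaults."""
    kwargs = deepcopy(GENERIC["experiment"])
    kwargs.update(dataset_cfg[run_type])
    return kwargs


carolina = {
    "dataset": "CMDataset",
    "debug": {"batch_size": 4, "num_workers": 0},
    "experiment": {"test_split": "", "train_path": "", "val_split": ""},
}

debug_kwargs = datamodule_kwargs(carolina, "debug")
full_kwargs = datamodule_kwargs(carolina, "experiment")
print(debug_kwargs)
print(full_kwargs)
```

Because the `experiment` section of the dataset config sets no batch size or worker count, a full run in this sketch keeps the generic values of 32 and 16, while a debug run uses 4 and 0.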

## Model Config
Models available in matsciml may be DGL, PyG, or point cloud based. Each model is named with its supported backend, as models may have more than one variety. In some instances, similar to the `trainer.yaml` config, a `class_path` and `init_args` need to be specified. Additionally, modules may need to be specified without initialization, which may be done by using the `class_instance` tag. Finally, all transforms that a model should use should be included in the model config.
```yaml
encoder_class:
  class_path: matsciml.models.TensorNet
encoder_kwargs:
  element_types:
    class_path: matsciml.datasets.utils.element_types
  num_rbf: 32
  max_n: 3
  max_l: 3
output_kwargs:
  lazy: False
  input_dim: 64
  hidden_dim: 64
transforms:
  - class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform
    init_args:
      cutoff_radius: 6.5
      adaptive_cutoff: True
  - class_path: matsciml.datasets.transforms.PointCloudToGraphTransform
    init_args:
      backend: dgl
      cutoff_dist: 20.0
      node_keys:
        - "pos"
        - "atomic_numbers"
```
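
The difference between `class_path` and `class_instance` can be illustrated with a small recursive resolver. This is a sketch under assumptions, not the behavior of `instantiate_arg_dict` in this repository: `import_object` and `resolve` are hypothetical names, a `class_instance` entry is imported but not called, and a `class_path` entry is called with its `init_args`. How the toolkit treats a `class_path` that has no `init_args` (as in `encoder_class` above) is not stated in this README, so the sketch simply calls it with no arguments. Standard-library objects stand in for matsciml classes.

```python
import importlib
from typing import Any


def import_object(dotted_path: str) -> Any:
    """Import `package.module.Name` and return the named attribute."""
    module_name, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def resolve(node: Any) -> Any:
    """Recursively resolve `class_path` / `class_instance` entries (hypothetical)."""
    if isinstance(node, list):
        return [resolve(item) for item in node]
    if not isinstance(node, dict):
        return node
    if "class_instance" in node:
        # Imported only: the class itself is returned, without initialization.
        return import_object(node["class_instance"])
    if "class_path" in node:
        init_args = resolve(node.get("init_args", {}))
        return import_object(node["class_path"])(**init_args)
    return {key: resolve(value) for key, value in node.items()}


# Stand-in config: standard-library names replace matsciml classes.
model_cfg = {
    "encoder_class": {"class_instance": "collections.OrderedDict"},
    "output_kwargs": {"lazy": False, "hidden_dim": 64},
    "transforms": [
        {"class_path": "datetime.timedelta", "init_args": {"seconds": 30}},
    ],
}

resolved = resolve(model_cfg)
print(resolved["encoder_class"], resolved["transforms"])
```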

## CLI Parameter Updates
Certain parameters may be updated using the cli. The `-c, --cli_args` argument may be used, and the parameter must be specified as `[config].parameter.value`. The config may be `trainer`, `model`, or `dataset`. For example, to update the batch size for a debug run:

`python training_script.py --debug --cli_args dataset.debug.batch_size.16`

Only arguments whose path maps through dictionaries all the way to a `str`, `int`, or `float` target value may be updated. Parameters that map to a list at any point, for example callbacks, loggers, and transforms, cannot be updated through `cli_args`.
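
The `[config].parameter.value` convention and the list restriction can be sketched as a walk over nested dictionaries. The code below is an illustration only and is not the `update_arg_dict` function used by the training script: `parse_value` and `apply_cli_override` are hypothetical names, and the value is assumed to be the final dot-separated token.

```python
from __future__ import annotations

from typing import Any


def parse_value(text: str) -> int | float | str:
    """Convert the trailing token to int or float when possible, else keep str."""
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            continue
    return text


def apply_cli_override(configs: dict[str, dict[str, Any]], arg: str) -> None:
    """Apply one `[config].parameter.value` override in place (hypothetical)."""
    *path, raw_value = arg.split(".")
    if len(path) < 2:
        raise ValueError(f"expected [config].parameter.value, got {arg!r}")
    node: Any = configs
    for key in path[:-1]:
        node = node[key]
        # Anything other than a dict on the way down (e.g. a list) is rejected.
        if not isinstance(node, dict):
            raise TypeError(f"{key!r} does not map to a dict; cannot update {arg!r}")
    node[path[-1]] = parse_value(raw_value)


configs = {
    "dataset": {"debug": {"batch_size": 4, "num_workers": 0}},
    "trainer": {"callbacks": [{"class_path": "some.Callback"}]},
}

apply_cli_override(configs, "dataset.debug.batch_size.16")
print(configs["dataset"]["debug"]["batch_size"])

try:
    apply_cli_override(configs, "trainer.callbacks.patience.3")
    list_update_rejected = False
except TypeError:
    list_update_rejected = True
print(list_update_rejected)
```

Splitting on every dot means this sketch cannot express a value that itself contains a dot, such as `6.5`; a real implementation would need a different separator or parsing rule for that case.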
Empty file added experiments/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions experiments/datasets/__init__.py
@@ -0,0 +1,21 @@
import yaml

from pathlib import Path

yaml_dir = Path(__file__).parent


available_data = {
    "generic": {
        "experiment": {"batch_size": 32, "num_workers": 16},
        "debug": {"batch_size": 4, "num_workers": 0},
    },
}


for filename in yaml_dir.rglob("*.yaml"):
    file_path = yaml_dir.joinpath(filename)
    with open(file_path, "r") as file:
        content = yaml.safe_load(file)
        file_key = file_path.stem
        available_data[file_key] = content
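
Two properties of this loading loop are worth checking in isolation, since they come from `pathlib` rather than from matsciml. The standalone sketch below uses a temporary directory with made-up file contents: it shows that paths returned by `rglob` already include the search root, so joining them onto an absolute root changes nothing, and that keying on `stem` gives files with the same name in different subdirectories the same key.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "nested").mkdir()
    (root / "oqmd.yaml").write_text("dataset: OQMDDataset\n")
    (root / "nested" / "oqmd.yaml").write_text("dataset: OtherDataset\n")

    found = sorted(root.rglob("*.yaml"))
    # Joining an absolute path onto `root` returns that absolute path unchanged.
    rejoined = [root.joinpath(path) for path in found]
    # Both files share the stem "oqmd", so a dict keyed on stem keeps only one.
    stems = [path.stem for path in found]
    registry = {path.stem: path for path in found}

print(rejoined == found)
print(stems, len(registry))
```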
18 changes: 18 additions & 0 deletions experiments/datasets/carolina.yaml
@@ -0,0 +1,18 @@
dataset: CMDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - energy
  - symmetry_number
  - symmetry_symbol
task_args:
  normalize_kwargs:
    energy_mean: 1
    energy_std: 1
  task_loss_scaling:
    energy: 1
86 changes: 86 additions & 0 deletions experiments/datasets/data_module_config.py
@@ -0,0 +1,86 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any
import sys

import pytorch_lightning as pl

from matsciml.lightning.data_utils import (
    MatSciMLDataModule,
)
from matsciml.datasets import * # noqa: F401

from matsciml.lightning.data_utils import MultiDataModule

from experiments.datasets import available_data
from experiments.models import available_models
from experiments.utils.utils import instantiate_arg_dict, update_arg_dict


def setup_datamodule(config: dict[str, Any]) -> pl.LightningDataModule:
    """Build a data module for one dataset, or a multi-dataset module for several."""
    model = config["model"]
    data_task_dict = config["dataset"]
    run_type = config["run_type"]
    # The model config is needed here because it carries the dataset transforms.
    model = instantiate_arg_dict(deepcopy(available_models[model]))
    model = update_arg_dict("model", model, config["cli_args"])
    datasets = list(data_task_dict.keys())
    if len(datasets) == 1:
        dset = deepcopy(available_data[datasets[0]])
        dset = update_arg_dict("dataset", dset, config["cli_args"])
        # Start from the generic defaults, then overlay the run-type section.
        dm_kwargs = deepcopy(available_data["generic"]["experiment"])
        dm_kwargs.update(dset[run_type])
        if run_type == "debug":
            dm = MatSciMLDataModule.from_devset(
                dataset=dset["dataset"],
                dset_kwargs={"transforms": model["transforms"]},
                **dm_kwargs,
            )
        else:
            dm = MatSciMLDataModule(
                dataset=dset["dataset"],
                dset_kwargs={"transforms": model["transforms"]},
                **dm_kwargs,
            )
    else:
        dset_list = {"train": [], "val": [], "test": []}
        for dataset in datasets:
            dset = deepcopy(available_data[dataset])
            dset = update_arg_dict("dataset", dset, config["cli_args"])
            dm_kwargs = deepcopy(available_data["generic"]["experiment"])
            dset[run_type].pop("normalize_kwargs", None)
            dm_kwargs.update(dset[run_type])
            dataset_name = dset["dataset"]
            # Resolve the dataset class by name from this module's namespace,
            # which is populated by `from matsciml.datasets import *`.
            dataset = getattr(sys.modules[__name__], dataset_name)
            model_transforms = model["transforms"]
            if run_type == "debug":
                dset_list["train"].append(
                    dataset.from_devset(transforms=model_transforms)
                )
                dset_list["val"].append(
                    dataset.from_devset(transforms=model_transforms)
                )
                dset_list["test"].append(
                    dataset.from_devset(transforms=model_transforms)
                )
            else:
                if "train_path" in dm_kwargs:
                    dset_list["train"].append(
                        dataset(dm_kwargs["train_path"], transforms=model_transforms)
                    )
                if "val_split" in dm_kwargs:
                    dset_list["val"].append(
                        dataset(dm_kwargs["val_split"], transforms=model_transforms)
                    )
                if "test_split" in dm_kwargs:
                    dset_list["test"].append(
                        dataset(dm_kwargs["test_split"], transforms=model_transforms)
                    )
        # Batch size and worker count come from the last dataset processed.
        dm = MultiDataModule(
            train_dataset=MultiDataset(dset_list["train"]),
            val_dataset=MultiDataset(dset_list["val"]),
            test_dataset=MultiDataset(dset_list["test"]),
            batch_size=dm_kwargs["batch_size"],
            num_workers=dm_kwargs["num_workers"],
        )
    return dm
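
The lookup `getattr(sys.modules[__name__], dataset_name)` above works because the star import places every dataset class in this module's namespace. An alternative design, shown here purely as a sketch and not as what this commit does, is an explicit name-to-class registry that fails with a readable message for unknown names. The two classes below are empty stand-ins defined only for the example; `DATASET_REGISTRY` and `lookup_dataset` are hypothetical names.

```python
class OQMDDataset:
    """Empty stand-in for illustration; not the matsciml class."""


class IS2REDataset:
    """Empty stand-in for illustration; not the matsciml class."""


# Explicit registry keyed by class name, as used in the `dataset:` YAML field.
DATASET_REGISTRY = {cls.__name__: cls for cls in (OQMDDataset, IS2REDataset)}


def lookup_dataset(name: str) -> type:
    """Return the registered dataset class, or raise a descriptive KeyError."""
    try:
        return DATASET_REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(DATASET_REGISTRY))
        raise KeyError(f"unknown dataset {name!r}; known datasets: {known}") from None


print(lookup_dataset("OQMDDataset").__name__)
```

The trade-off is that a registry must be kept in sync by hand, whereas the namespace lookup picks up any class exported by the package.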
10 changes: 10 additions & 0 deletions experiments/datasets/is2re.yaml
@@ -0,0 +1,10 @@
dataset: IS2REDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  train_path: ''
  val_split: ''
target_keys:
  - energy_init
  - energy_relaxed
11 changes: 11 additions & 0 deletions experiments/datasets/lips.yaml
@@ -0,0 +1,11 @@
dataset: LiPSDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - energy
  - force
20 changes: 20 additions & 0 deletions experiments/datasets/materials_project.yaml
@@ -0,0 +1,20 @@
dataset: MaterialsProjectDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ""
  train_path: ""
  val_split: ""
target_keys:
  - is_magnetic
  - is_metal
  - is_stable
  - band_gap
  - efermi
  - energy_per_atom
  - formation_energy_per_atom
  - uncorrected_energy_per_atom
  - symmetry_number
  - symmetry_symbol
  - symmetry_group
15 changes: 15 additions & 0 deletions experiments/datasets/nomad.yaml
@@ -0,0 +1,15 @@
dataset: NomadDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - spin_polarized
  - efermi
  - relative_energy
  - symmetry_number
  - symmetry_symbol
  - symmetry_group
12 changes: 12 additions & 0 deletions experiments/datasets/oqmd.yaml
@@ -0,0 +1,12 @@
dataset: OQMDDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  test_split: ''
  train_path: ''
  val_split: ''
target_keys:
  - band_gap
  - energy
  - stability
10 changes: 10 additions & 0 deletions experiments/datasets/s2ef.yaml
@@ -0,0 +1,10 @@
dataset: S2EFDataset
debug:
  batch_size: 4
  num_workers: 0
experiment:
  train_path: ''
  val_split: ''
target_keys:
  - energy
  - force
41 changes: 41 additions & 0 deletions experiments/datasets/tests/test_data_module_creation.py
@@ -0,0 +1,41 @@
from __future__ import annotations


import pytest

import matsciml
import matsciml.datasets.transforms # noqa: F401
from experiments.datasets.data_module_config import setup_datamodule


single_task = {
    "model": "egnn_dgl",
    "dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["band_gap"]}]},
}
multi_task = {
    "dataset": {
        "s2ef": [
            {"task": "ScalarRegressionTask", "targets": ["energy"]},
            {"task": "ForceRegressionTask", "targets": ["force"]},
        ]
    }
}
multi_data = {
    "model": "faenet_pyg",
    "dataset": {
        "oqmd": [{"task": "ScalarRegressionTask", "targets": ["energy"]}],
        "is2re": [
            {
                "task": "ScalarRegressionTask",
                "targets": ["energy_init", "energy_relaxed"],
            }
        ],
    },
}


@pytest.mark.parametrize("task_dict", [single_task, multi_task, multi_data])
def test_task_setup(task_dict):
    other_args = {"run_type": "debug", "model": "m3gnet_dgl", "cli_args": None}
    task_dict.update(other_args)
    setup_datamodule(config=task_dict)
6 changes: 6 additions & 0 deletions experiments/experiment_config.yaml
@@ -0,0 +1,6 @@
model: faenet_pyg
dataset:
  oqmd:
    - task: ScalarRegressionTask
      targets:
        - energy