diff --git a/experiments/README.md b/experiments/README.md
new file mode 100644
index 00000000..e7029260
--- /dev/null
+++ b/experiments/README.md
@@ -0,0 +1,127 @@
+# Experiments in Open MatSciML Toolkit
+
+Experimental workflows may be time-consuming, repetitive, and complex to set up. Additionally, pytorch-lightning based CLI utilities may not be able to handle specific use cases such as multi-data or multi-task training in matsciml. The experiments module of MatSciML is meant to loosely mirror the functionality of the pytorch lightning CLI while allowing more flexibility in setting up complex experiments. Yaml files define the module parameters, and specific arguments may be changed via the command line if desired. A single command is used to launch training runs, which removes the complexity of writing a new script for each experiment type.
+
+## Experiment Config
+The starting point for defining an experiment is the experiment config. This is a yaml file that lays out which model, dataset(s), and task(s) will be used during training. An example config for single-task training (`single_data_single_task.yaml`) looks like this:
+```yaml
+model: egnn_dgl
+dataset:
+  oqmd:
+    - task: ScalarRegressionTask
+      targets:
+        - band_gap
+```
+
+In general, an experiment may then be launched by running:
+`python experiments/training_script.py --experiment_config ./experiments/single_data_single_task.yaml`
+
+* The `model` field points to a specific `model.yaml` file in `./experiments/models`.
+* The `dataset` field is a dictionary specifying which datasets to use, as well as which tasks are associated with the parent dataset.
+  * Tasks are referred to by their class name:
+    ```python
+    ScalarRegressionTask
+    ForceRegressionTask
+    BinaryClassificationTask
+    CrystalSymmetryClassificationTask
+    GradFreeForceRegressionTask
+    ```
+* A dataset may contain more than one task (single-data, multi-task learning).
+* Multiple datasets can be provided, each containing their own tasks (multi-data, multi-task learning).
+* For a list of available datasets, tasks, and models, run `python training_script.py --options`.
+
+## Trainer Config
+The trainer config contains a few sections used to define how experiments will be run. The debug tag sets parameters that should be used when debugging an experimental setup, or when working through bugs in a new model or dataset; these parameters are helpful for quick end-to-end runs to make sure the pipeline is functional. The experiment tag defines parameters for the full experiment runs. Finally, the generic tag defines parameters that are used regardless of whether a debug or full experimental run is launched.
+
+In addition to these run types, any other parameters to be used with the pytorch lightning `Trainer` should be added here. In the example `trainer.yaml`, there are callbacks and a logger. These objects are set up by adding their `class_path` as well as any `init_args` they expect.
+```yaml
+generic:
+  min_epochs: 15
+  max_epochs: 100
+debug:
+  accelerator: cpu
+  limit_train_batches: 10
+  limit_val_batches: 10
+  log_every_n_steps: 1
+  max_epochs: 2
+experiment:
+  accelerator: gpu
+  strategy: ddp_find_unused_parameters_true
+callbacks:
+  - class_path: pytorch_lightning.callbacks.EarlyStopping
+    init_args:
+      patience: 5
+      monitor: val_energy
+      mode: min
+      verbose: True
+      check_finite: False
+  - class_path: pytorch_lightning.callbacks.ModelCheckpoint
+    init_args:
+      monitor: val_energy
+      save_top_k: 3
+  - class_path: matsciml.lightning.callbacks.GradientCheckCallback
+  - class_path: matsciml.lightning.callbacks.SAM
+loggers:
+  - class_path: pytorch_lightning.loggers.CSVLogger # can omit init_args['save_dir'] for auto directory
+```
+
+## Dataset Config
+Similar to the trainer config, the dataset config has sections for debug and full experiments. Dataset paths, batch size, number of workers, seed, and other relevant arguments may be set here. The available target keys for training are listed under `target_keys`. Other settings such as `normalize_kwargs` and `task_loss_scaling` may be set under the `task_args` tag.
+```yaml
+dataset: CMDataset
+debug:
+  batch_size: 4
+  num_workers: 0
+experiment:
+  test_split: ''
+  train_path: ''
+  val_split: ''
+target_keys:
+- energy
+- symmetry_number
+- symmetry_symbol
+task_args:
+  normalize_kwargs:
+    energy_mean: 1
+    energy_std: 1
+  task_loss_scaling:
+    energy: 1
+```
+
+## Model Config
+Models available in matsciml may be DGL, PyG, or point-cloud based. Each model is named with its supported backend, as models may come in more than one variety. In some instances, similar to the `trainer.yaml` config, a `class_path` and `init_args` need to be specified. Additionally, modules may need to be specified without being initialized, which may be done by using the `class_instance` tag. Finally, all transforms that a model should use must be included in the model config.
+```yaml
+encoder_class:
+  class_path: matsciml.models.TensorNet
+encoder_kwargs:
+  element_types:
+    class_path: matsciml.datasets.utils.element_types
+  num_rbf: 32
+  max_n: 3
+  max_l: 3
+output_kwargs:
+  lazy: False
+  input_dim: 64
+  hidden_dim: 64
+transforms:
+  - class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform
+    init_args:
+      cutoff_radius: 6.5
+      adaptive_cutoff: True
+  - class_path: matsciml.datasets.transforms.PointCloudToGraphTransform
+    init_args:
+      backend: dgl
+      cutoff_dist: 20.0
+      node_keys:
+        - "pos"
+        - "atomic_numbers"
+```
+
+## CLI Parameter Updates
+Certain parameters may be updated from the CLI. The `-c, --cli_args` argument may be used, and each parameter must be specified as `[config].parameter.value`, where the config may be `trainer`, `model`, or `dataset`. For example, to update the batch size for a debug run:
+
+`python training_script.py --debug --cli_args dataset.debug.batch_size.16`
+
+Only parameters reachable through a chain of `dict` keys mapping to `str`, `int`, or `float` values all the way down to the target may be updated. Parameters which map to lists at any point, for example callbacks, loggers, and transforms, are not updatable through `cli_args`.
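For reference, a minimal sketch of how these dotted strings are applied, using `update_arg_dict` from `experiments/utils/utils.py` introduced later in this patch; the `dataset_config` contents here are illustrative only:

```python
# Sketch of the cli_args update path (mirrors what training_script.py and
# experiments/utils/utils.py do); dataset_config values are illustrative.
from experiments.utils.utils import update_arg_dict

raw_cli_args = ["dataset.debug.batch_size.16"]
# training_script.py splits each cli argument on "." before passing it along
parsed = [arg.split(".") for arg in raw_cli_args]

dataset_config = {"dataset": "OQMDDataset", "debug": {"batch_size": 4, "num_workers": 0}}
updated = update_arg_dict("dataset", dataset_config, parsed)
print(updated["debug"]["batch_size"])  # 16 (string converted to int)
```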
diff --git a/experiments/__init__.py b/experiments/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experiments/datasets/__init__.py b/experiments/datasets/__init__.py new file mode 100644 index 00000000..dbdc92b5 --- /dev/null +++ b/experiments/datasets/__init__.py @@ -0,0 +1,21 @@ +import yaml + +from pathlib import Path + +yaml_dir = Path(__file__).parent + + +available_data = { + "generic": { + "experiment": {"batch_size": 32, "num_workers": 16}, + "debug": {"batch_size": 4, "num_workers": 0}, + }, +} + + +for filename in yaml_dir.rglob("*.yaml"): + file_path = yaml_dir.joinpath(filename) + with open(file_path, "r") as file: + content = yaml.safe_load(file) + file_key = file_path.stem + available_data[file_key] = content diff --git a/experiments/datasets/carolina.yaml b/experiments/datasets/carolina.yaml new file mode 100644 index 00000000..15376eb3 --- /dev/null +++ b/experiments/datasets/carolina.yaml @@ -0,0 +1,18 @@ +dataset: CMDataset +debug: + batch_size: 4 + num_workers: 0 +experiment: + test_split: '' + train_path: '' + val_split: '' +target_keys: +- energy +- symmetry_number +- symmetry_symbol +task_args: + normalize_kwargs: + energy_mean: 1 + energy_std: 1 + task_loss_scaling: + energy: 1 diff --git a/experiments/datasets/data_module_config.py b/experiments/datasets/data_module_config.py new file mode 100644 index 00000000..08dc7273 --- /dev/null +++ b/experiments/datasets/data_module_config.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any +import sys + +import pytorch_lightning as pl + +from matsciml.lightning.data_utils import ( + MatSciMLDataModule, +) +from matsciml.datasets import * # noqa: F401 + +from matsciml.lightning.data_utils import MultiDataModule + +from experiments.datasets import available_data +from experiments.models import available_models +from experiments.utils.utils import instantiate_arg_dict, update_arg_dict + + +def setup_datamodule(config: dict[str, Any]) -> pl.LightningModule: + model = config["model"] + data_task_dict = config["dataset"] + run_type = config["run_type"] + model = instantiate_arg_dict(deepcopy(available_models[model])) + model = update_arg_dict("model", model, config["cli_args"]) + datasets = list(data_task_dict.keys()) + if len(datasets) == 1: + dset = deepcopy(available_data[datasets[0]]) + dset = update_arg_dict("dataset", dset, config["cli_args"]) + dm_kwargs = deepcopy(available_data["generic"]["experiment"]) + dm_kwargs.update(dset[run_type]) + if run_type == "debug": + dm = MatSciMLDataModule.from_devset( + dataset=dset["dataset"], + dset_kwargs={"transforms": model["transforms"]}, + **dm_kwargs, + ) + else: + dm = MatSciMLDataModule( + dataset=dset["dataset"], + dset_kwargs={"transforms": model["transforms"]}, + **dm_kwargs, + ) + else: + dset_list = {"train": [], "val": [], "test": []} + for dataset in datasets: + dset = deepcopy(available_data[dataset]) + dset = update_arg_dict("dataset", dset, config["cli_args"]) + dm_kwargs = deepcopy(available_data["generic"]["experiment"]) + dset[run_type].pop("normalize_kwargs", None) + dm_kwargs.update(dset[run_type]) + dataset_name = dset["dataset"] + dataset = getattr(sys.modules[__name__], dataset_name) + model_transforms = model["transforms"] + if run_type == "debug": + dset_list["train"].append( + dataset.from_devset(transforms=model_transforms) + ) + dset_list["val"].append( + dataset.from_devset(transforms=model_transforms) + ) + dset_list["test"].append( + 
dataset.from_devset(transforms=model_transforms) + ) + else: + if "train_path" in dm_kwargs: + dset_list["train"].append( + dataset(dm_kwargs["train_path"], transforms=model_transforms) + ) + if "val_split" in dm_kwargs: + dset_list["val"].append( + dataset(dm_kwargs["val_split"], transforms=model_transforms) + ) + if "test_split" in dm_kwargs: + dset_list["test"].append( + dataset(dm_kwargs["test_split"], transforms=model_transforms) + ) + dm = MultiDataModule( + train_dataset=MultiDataset(dset_list["train"]), + val_dataset=MultiDataset(dset_list["val"]), + test_dataset=MultiDataset(dset_list["test"]), + batch_size=dm_kwargs["batch_size"], + num_workers=dm_kwargs["num_workers"], + ) + return dm diff --git a/experiments/datasets/is2re.yaml b/experiments/datasets/is2re.yaml new file mode 100644 index 00000000..298cc449 --- /dev/null +++ b/experiments/datasets/is2re.yaml @@ -0,0 +1,10 @@ +dataset: IS2REDataset +debug: + batch_size: 4 + num_workers: 0 +experiment: + train_path: '' + val_split: '' +target_keys: +- energy_init +- energy_relaxed diff --git a/experiments/datasets/lips.yaml b/experiments/datasets/lips.yaml new file mode 100644 index 00000000..5583a692 --- /dev/null +++ b/experiments/datasets/lips.yaml @@ -0,0 +1,11 @@ +dataset: LiPSDataset +debug: + batch_size: 4 + num_workers: 0 +experiment: + test_split: '' + train_path: '' + val_split: '' +target_keys: +- energy +- force diff --git a/experiments/datasets/materials_project.yaml b/experiments/datasets/materials_project.yaml new file mode 100644 index 00000000..9abb5b72 --- /dev/null +++ b/experiments/datasets/materials_project.yaml @@ -0,0 +1,20 @@ +dataset: MaterialsProjectDataset +debug: + batch_size: 4 + num_workers: 0 +experiment: + test_split: "" + train_path: "" + val_split: "" +target_keys: +- is_magnetic +- is_metal +- is_stable +- band_gap +- efermi +- energy_per_atom +- formation_energy_per_atom +- uncorrected_energy_per_atom +- symmetry_number +- symmetry_symbol +- symmetry_group diff --git a/experiments/datasets/nomad.yaml b/experiments/datasets/nomad.yaml new file mode 100644 index 00000000..803cac9a --- /dev/null +++ b/experiments/datasets/nomad.yaml @@ -0,0 +1,15 @@ +dataset: NomadDataset +debug: + batch_size: 4 + num_workers: 0 +experiment: + test_split: '' + train_path: '' + val_split: '' +target_keys: +- spin_polarized +- efermi +- relative_energy +- symmetry_number +- symmetry_symbol +- symmetry_group diff --git a/experiments/datasets/oqmd.yaml b/experiments/datasets/oqmd.yaml new file mode 100644 index 00000000..42d494f3 --- /dev/null +++ b/experiments/datasets/oqmd.yaml @@ -0,0 +1,12 @@ +dataset: OQMDDataset +debug: + batch_size: 4 + num_workers: 0 +experiment: + test_split: '' + train_path: '' + val_split: '' +target_keys: +- band_gap +- energy +- stability diff --git a/experiments/datasets/s2ef.yaml b/experiments/datasets/s2ef.yaml new file mode 100644 index 00000000..f6d8cf95 --- /dev/null +++ b/experiments/datasets/s2ef.yaml @@ -0,0 +1,10 @@ +dataset: S2EFDataset +debug: + batch_size: 4 + num_workers: 0 +experiment: + train_path: '' + val_split: '' +target_keys: +- energy +- force diff --git a/experiments/datasets/tests/test_data_module_creation.py b/experiments/datasets/tests/test_data_module_creation.py new file mode 100644 index 00000000..7bd94be2 --- /dev/null +++ b/experiments/datasets/tests/test_data_module_creation.py @@ -0,0 +1,41 @@ +from __future__ import annotations + + +import pytest + +import matsciml +import matsciml.datasets.transforms # noqa: F401 +from 
experiments.datasets.data_module_config import setup_datamodule + + +single_task = { + "model": "egnn_dgl", + "dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["band_gap"]}]}, +} +multi_task = { + "dataset": { + "s2ef": [ + {"task": "ScalarRegressionTask", "targets": ["energy"]}, + {"task": "ForceRegressionTask", "targets": ["force"]}, + ] + } +} +multi_data = { + "model": "faenet_pyg", + "dataset": { + "oqmd": [{"task": "ScalarRegressionTask", "targets": ["energy"]}], + "is2re": [ + { + "task": "ScalarRegressionTask", + "targets": ["energy_init", "energy_relaxed"], + } + ], + }, +} + + +@pytest.mark.parametrize("task_dict", [single_task, multi_task, multi_data]) +def test_task_setup(task_dict): + other_args = {"run_type": "debug", "model": "m3gnet_dgl", "cli_args": None} + task_dict.update(other_args) + setup_datamodule(config=task_dict) diff --git a/experiments/experiment_config.yaml b/experiments/experiment_config.yaml new file mode 100644 index 00000000..9b390a52 --- /dev/null +++ b/experiments/experiment_config.yaml @@ -0,0 +1,6 @@ +model: faenet_pyg +dataset: + oqmd: + - task: ScalarRegressionTask + targets: + - energy diff --git a/experiments/models/__init__.py b/experiments/models/__init__.py new file mode 100644 index 00000000..3631dce6 --- /dev/null +++ b/experiments/models/__init__.py @@ -0,0 +1,27 @@ +import yaml + +from torch.nn import LayerNorm + + +from pathlib import Path + +yaml_dir = Path(__file__).parent +available_models = { + "generic": { + "output_kwargs": { + "norm": LayerNorm(128), + "hidden_dim": 128, + "activation": "SiLU", + "lazy": False, + "input_dim": 128, + }, + "lr": 0.0001, + }, +} + +for filename in yaml_dir.rglob("*.yaml"): + file_path = yaml_dir.joinpath(filename) + with open(file_path, "r") as file: + content = yaml.safe_load(file) + file_key = file_path.stem + available_models[file_key] = content diff --git a/experiments/models/egnn_dgl.yaml b/experiments/models/egnn_dgl.yaml new file mode 100644 index 00000000..51247700 --- /dev/null +++ b/experiments/models/egnn_dgl.yaml @@ -0,0 +1,56 @@ +encoder_class: + class_path: matsciml.models.PLEGNNBackbone +encoder_kwargs: + embed_activate_last: false + embed_activation: relu + embed_attention_norm: sigmoid + embed_depth: 5 + embed_edge_attributes_dim: 0 + embed_feat_dims: + - 128 + - 128 + - 128 + embed_hidden_dim: 32 + embed_in_dim: 1 + embed_k_linears: 1 + embed_message_dims: + - 128 + - 128 + - 128 + embed_normalize: true + embed_out_dim: 128 + embed_position_dims: + - 64 + - 64 + embed_residual: true + embed_tanh: true + embed_use_attention: false + node_projection_activation: relu + node_projection_depth: 3 + node_projection_hidden_dim: 128 + prediction_activation: relu + prediction_depth: 3 + prediction_hidden_dim: 128 + prediction_out_dim: 1 + readout: sum +output_kwargs: + lazy: False + norm: + class_path: torch.nn.LayerNorm + init_args: + normalized_shape: 128 + activation: torch.nn.SiLU + input_dim: 128 + hidden_dim: 128 +transforms: + - class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform + init_args: + cutoff_radius: 6.5 + adaptive_cutoff: True + - class_path: matsciml.datasets.transforms.PointCloudToGraphTransform + init_args: + backend: dgl + cutoff_dist: 20.0 + node_keys: + - "pos" + - "atomic_numbers" diff --git a/experiments/models/faenet_pyg.yaml b/experiments/models/faenet_pyg.yaml new file mode 100644 index 00000000..5950f146 --- /dev/null +++ b/experiments/models/faenet_pyg.yaml @@ -0,0 +1,30 @@ +encoder_class: + class_path: matsciml.models.FAENet 
+encoder_kwargs: + act: silu + cutoff: 6.0 + average_frame_embeddings: False + pred_as_dict: False + hidden_dim: 128 + out_dim: 128 + tag_hidden_channels: 0 +output_kwargs: + lazy: False + input_dim: 128 + hidden_dim: 128 +transforms: + - class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform + init_args: + cutoff_radius: 6.5 + adaptive_cutoff: True + - class_path: matsciml.datasets.transforms.PointCloudToGraphTransform + init_args: + backend: pyg + cutoff_dist: 20.0 + node_keys: + - "pos" + - "atomic_numbers" + - class_path: matsciml.datasets.transforms.FrameAveraging + init_args: + frame_averaging: 3D + fa_method: stochastic diff --git a/experiments/models/gala_pc.yaml b/experiments/models/gala_pc.yaml new file mode 100644 index 00000000..4efdf956 --- /dev/null +++ b/experiments/models/gala_pc.yaml @@ -0,0 +1,25 @@ +encoder_class: + class_path: matsciml.models.GalaPotential +encoder_kwargs: + D_in: 100 + depth: 2 + hidden_dim: 64 + merge_fun: concat + join_fun: concat + invariant_mode: full + covariant_mode: full + include_normalized_products: True + invar_value_normalization: momentum + eqvar_value_normalization: momentum_layer + value_normalization: layer + score_normalization: layer + block_normalization: layer + equivariant_attention: False + tied_attention: True + encoder_only: True +output_kwargs: + lazy: False + input_dim: 64 + hidden_dim: 64 +transforms: + - class_path: matsciml.datasets.transforms.COMShift diff --git a/experiments/models/m3gnet_dgl.yaml b/experiments/models/m3gnet_dgl.yaml new file mode 100644 index 00000000..daacfa51 --- /dev/null +++ b/experiments/models/m3gnet_dgl.yaml @@ -0,0 +1,23 @@ +encoder_class: + class_path: matsciml.models.M3GNet +encoder_kwargs: + element_types: + class_path: matsciml.datasets.utils.element_types + return_all_layer_output: True +output_kwargs: + lazy: False + input_dim: 64 + hidden_dim: 64 +transforms: + - class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform + init_args: + cutoff_radius: 6.5 + adaptive_cutoff: True + - class_path: matsciml.datasets.transforms.PointCloudToGraphTransform + init_args: + backend: dgl + cutoff_dist: 20.0 + node_keys: + - "pos" + - "atomic_numbers" + - class_path: matsciml.datasets.transforms.MGLDataTransform diff --git a/experiments/models/mace_pyg.yaml b/experiments/models/mace_pyg.yaml new file mode 100644 index 00000000..74cf462f --- /dev/null +++ b/experiments/models/mace_pyg.yaml @@ -0,0 +1,136 @@ +encoder_class: + class_path: matsciml.models.pyg.mace.MACEWrapper +encoder_kwargs: + mace_module: + class_instance: mace.modules.ScaleShiftMACE + MLP_irreps: + class_path: e3nn.o3.Irreps + init_args: + irreps: "16x0e" + atom_embedding_dim: 128 + atomic_inter_scale: 0.8042 + atomic_inter_shift: 0.1641 + avg_num_neighbors: 61.96 + correlation: 3 + gate: + class_path: torch.nn.SiLU + interaction_cls: + class_instance: mace.modules.blocks.RealAgnosticResidualInteractionBlock + interaction_cls_first: + class_instance: mace.modules.blocks.RealAgnosticResidualInteractionBlock + max_ell: 3 + num_atom_embedding: 89 + num_bessel: 10 + num_interactions: 2 + num_polynomial_cutoff: 5.0 + r_max: 6.0 + radial_type: bessel + atomic_energies: + class_path: torch.Tensor + init_args: + data: + - -3.6672 + - -1.3321 + - -3.4821 + - -4.7367 + - -7.7249 + - -8.4056 + - -7.3601 + - -7.2846 + - -4.8965 + - 0.0 + - -2.7594 + - -2.814 + - -4.8469 + - -7.6948 + - -6.9633 + - -4.6726 + - -2.8117 + - -0.0626 + - -2.6176 + - -5.3905 + - -7.8858 + - -10.2684 + - -8.6651 + - -9.2331 + - -8.305 + - -7.049 + - 
-5.5774 + - -5.1727 + - -3.2521 + - -1.2902 + - -3.5271 + - -4.7085 + - -3.9765 + - -3.8862 + - -2.5185 + - 6.7669 + - -2.5635 + - -4.938 + - -10.1498 + - -11.8469 + - -12.1389 + - -8.7917 + - -8.7869 + - -7.7809 + - -6.85 + - -4.891 + - -2.0634 + - -0.6396 + - -2.7887 + - -3.8186 + - -3.5871 + - -2.8804 + - -1.6356 + - 9.8467 + - -2.7653 + - -4.991 + - -8.9337 + - -8.7356 + - -8.019 + - -8.2515 + - -7.5917 + - -8.1697 + - -13.5927 + - -18.5175 + - -7.6474 + - -8.123 + - -7.6078 + - -6.8503 + - -7.8269 + - -3.5848 + - -7.4554 + - -12.7963 + - -14.1081 + - -9.3549 + - -11.3875 + - -9.6219 + - -7.3244 + - -5.3047 + - -2.3801 + - 0.2495 + - -2.324 + - -3.73 + - -3.4388 + - -5.0629 + - -11.0246 + - -12.2656 + - -13.8556 + - -14.9331 + - -15.2828 +output_kwargs: + lazy: False + input_dim: 256 + hidden_dim: 256 +transforms: + - class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform + init_args: + cutoff_radius: 6.5 + adaptive_cutoff: True + - class_path: matsciml.datasets.transforms.PointCloudToGraphTransform + init_args: + backend: pyg + cutoff_dist: 20.0 + node_keys: + - "pos" + - "atomic_numbers" diff --git a/experiments/models/megnet_dgl.yaml b/experiments/models/megnet_dgl.yaml new file mode 100644 index 00000000..eff02897 --- /dev/null +++ b/experiments/models/megnet_dgl.yaml @@ -0,0 +1,40 @@ +encoder_class: + class_path: matsciml.models.MEGNet +encoder_kwargs: + conv_hiddens: + - 128 + - 128 + - 128 + edge_feat_dim: 2 + encoder_only: true + graph_feat_dim: 9 + hiddens: + - 256 + - 256 + - 128 + is_classification: false + node_feat_dim: 128 + num_blocks: 4 + output_hiddens: + - 64 + - 64 + s2s_num_iters: 4 + s2s_num_layers: 5 +output_kwargs: + lazy: False + input_dim: 640 + hidden_dim: 640 +transforms: + - class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform + init_args: + cutoff_radius: 6.5 + adaptive_cutoff: True + - class_path: matsciml.datasets.transforms.PointCloudToGraphTransform + init_args: + backend: dgl + cutoff_dist: 20.0 + node_keys: + - "pos" + - "atomic_numbers" + - class_path: matsciml.datasets.transforms.DistancesTransform + - class_path: matsciml.datasets.transforms.GraphVariablesTransform diff --git a/experiments/models/tensornet_dgl.yaml b/experiments/models/tensornet_dgl.yaml new file mode 100644 index 00000000..6d08e5ac --- /dev/null +++ b/experiments/models/tensornet_dgl.yaml @@ -0,0 +1,24 @@ +encoder_class: + class_path: matsciml.models.TensorNet +encoder_kwargs: + element_types: + class_path: matsciml.datasets.utils.element_types + num_rbf: 32 + max_n: 3 + max_l: 3 +output_kwargs: + lazy: False + input_dim: 64 + hidden_dim: 64 +transforms: + - class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform + init_args: + cutoff_radius: 6.5 + adaptive_cutoff: True + - class_path: matsciml.datasets.transforms.PointCloudToGraphTransform + init_args: + backend: dgl + cutoff_dist: 20.0 + node_keys: + - "pos" + - "atomic_numbers" diff --git a/experiments/multi_data_multi_task.yaml b/experiments/multi_data_multi_task.yaml new file mode 100644 index 00000000..8cb7dd6a --- /dev/null +++ b/experiments/multi_data_multi_task.yaml @@ -0,0 +1,11 @@ +model: faenet_pyg +dataset: + oqmd: + - task: ScalarRegressionTask + targets: + - energy + is2re: + - task: ScalarRegressionTask + targets: + - energy_init + - energy_relaxed diff --git a/experiments/single_data_single_task.yaml b/experiments/single_data_single_task.yaml new file mode 100644 index 00000000..cb8a7760 --- /dev/null +++ b/experiments/single_data_single_task.yaml @@ -0,0 +1,6 @@ +model: 
egnn_dgl +dataset: + oqmd: + - task: ScalarRegressionTask + targets: + - band_gap diff --git a/experiments/task_config/__init__.py b/experiments/task_config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experiments/task_config/task_config.py b/experiments/task_config/task_config.py new file mode 100644 index 00000000..26e4b7a8 --- /dev/null +++ b/experiments/task_config/task_config.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any + +import pytorch_lightning as pl + +from matsciml.common.registry import registry +from matsciml.models.base import MultiTaskLitModule + +from experiments.datasets import available_data +from experiments.models import available_models +from experiments.utils.utils import instantiate_arg_dict, update_arg_dict + + +def setup_task(config: dict[str, Any]) -> pl.LightningModule: + model = config["model"] + data_task_dict = config["dataset"] + model = instantiate_arg_dict(deepcopy(available_models[model])) + model = update_arg_dict("model", model, config["cli_args"]) + configured_tasks = [] + data_task_list = [] + for dataset_name, tasks in data_task_dict.items(): + dset_args = deepcopy(available_data[dataset_name]) + dset_args = update_arg_dict("dataset", dset_args, config["cli_args"]) + for task in tasks: + task_class = registry.get_task_class(task["task"]) + task_args = deepcopy(available_models["generic"]) + task_args.update(model) + task_args.update({"task_keys": task["targets"]}) + additonal_task_args = dset_args.get("task_args", None) + if additonal_task_args is not None: + task_args.update(additonal_task_args) + configured_task = task_class(**task_args) + configured_tasks.append(configured_task) + data_task_list.append( + [available_data[dataset_name]["dataset"], configured_task] + ) + + if len(configured_tasks) > 1: + task = MultiTaskLitModule(*data_task_list) + else: + task = configured_tasks[0] + return task diff --git a/experiments/task_config/tests/test_task_creation.py b/experiments/task_config/tests/test_task_creation.py new file mode 100644 index 00000000..217814fa --- /dev/null +++ b/experiments/task_config/tests/test_task_creation.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Any + +import pytest +from experiments.task_config.task_config import ( + setup_task, +) +from experiments.utils.utils import instantiate_arg_dict + +import matsciml +import matsciml.datasets.transforms # noqa: F401 + + +single_task = { + "model": "egnn_dgl", + "dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["band_gap"]}]}, +} +multi_task = { + "dataset": { + "s2ef": [ + {"task": "ScalarRegressionTask", "targets": ["energy"]}, + {"task": "ForceRegressionTask", "targets": ["force"]}, + ] + } +} +multi_data = { + "model": "faenet_pyg", + "dataset": { + "oqmd": [{"task": "ScalarRegressionTask", "targets": ["energy"]}], + "is2re": [ + { + "task": "ScalarRegressionTask", + "targets": ["energy_init", "energy_relaxed"], + } + ], + }, +} + + +@pytest.fixture +def test_build_model() -> dict[str, Any]: + input_dict = { + "encoder_class": {"class_path": "matsciml.models.M3GNet"}, + "encoder_kwargs": { + "element_types": {"class_path": "matsciml.datasets.utils.element_types"}, + "return_all_layer_output": True, + }, + "output_kwargs": {"lazy": False, "input_dim": 64, "hidden_dim": 64}, + "transforms": [ + { + "class_path": "matsciml.datasets.transforms.PeriodicPropertiesTransform", + "init_args": [{"cutoff_radius": 6.5}, {"adaptive_cutoff": True}], + }, + { + 
"class_path": "matsciml.datasets.transforms.PointCloudToGraphTransform", + "init_args": [ + {"backend": "dgl"}, + {"cutoff_dist": 20.0}, + {"node_keys": ["pos", "atomic_numbers"]}, + ], + }, + {"class_path": "matsciml.datasets.transforms.MGLDataTransform"}, + ], + } + + output = instantiate_arg_dict(input_dict) + assert isinstance( + output["transforms"][0], + matsciml.datasets.transforms.PeriodicPropertiesTransform, + ) + assert isinstance( + output["transforms"][1], + matsciml.datasets.transforms.PointCloudToGraphTransform, + ) + return output + + +@pytest.mark.dependency(depends=["test_build_model"]) +@pytest.mark.parametrize("task_dict", [single_task, multi_task, multi_data]) +def test_task_setup(task_dict): + other_args = {"run_type": "debug", "model": "m3gnet_dgl", "cli_args": None} + task_dict.update(other_args) + setup_task(config=task_dict) diff --git a/experiments/trainer_config/__init__.py b/experiments/trainer_config/__init__.py new file mode 100644 index 00000000..2eb5d981 --- /dev/null +++ b/experiments/trainer_config/__init__.py @@ -0,0 +1,17 @@ +from experiments.trainer_config.trainer_config import setup_trainer # noqa: F401 + +import yaml +from pathlib import Path + + +yaml_dir = Path(__file__).parent +trainer_args = { + "generic": {"min_epochs": 15, "max_epochs": 100}, +} + +for filename in yaml_dir.rglob("*.yaml"): + file_path = yaml_dir.joinpath(filename) + with open(file_path, "r") as file: + content = yaml.safe_load(file) + file_key = file_path.stem + trainer_args.update(content) diff --git a/experiments/trainer_config/tests/test_trainer_setup.py b/experiments/trainer_config/tests/test_trainer_setup.py new file mode 100644 index 00000000..f6e98914 --- /dev/null +++ b/experiments/trainer_config/tests/test_trainer_setup.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import pytest +import tempfile + +from pytorch_lightning.loggers import CSVLogger +from pytorch_lightning.callbacks import EarlyStopping + +from experiments.trainer_config import setup_trainer + + +@pytest.fixture +def trainer_args() -> dict: + args = { + "debug": { + "accelerator": "cpu", + "limit_train_batches": 10, + "limit_val_batches": 10, + "log_every_n_steps": 1, + "max_epochs": 2, + }, + "experiment": { + "accelerator": "gpu", + "strategy": "ddp_find_unused_parameters_true", + }, + "generic": {"min_epochs": 15, "max_epochs": 100}, + "callbacks": [ + { + "class_path": "pytorch_lightning.callbacks.EarlyStopping", + "init_args": [ + {"patience": 5}, + {"mode": "min"}, + {"verbose": True}, + {"check_finite": False}, + {"monitor": "val_energy"}, + ], + }, + { + "class_path": "pytorch_lightning.callbacks.ModelCheckpoint", + "init_args": [{"save_top_k": 3}, {"monitor": "val_energy"}], + }, + ], + "loggers": [ + { + "class_path": "pytorch_lightning.loggers.CSVLogger", + "init_args": {"save_dir": "./temp"}, + }, + ], + } + return args + + +@pytest.mark.dependency(depends=["trainer_args"]) +def test_trainer_setup(trainer_args): + temp_dir = tempfile.TemporaryDirectory() + trainer = setup_trainer( + {"run_type": "debug", "log_path": f"{temp_dir}", "cli_args": None}, trainer_args + ) + assert any([CSVLogger == logger.__class__ for logger in trainer.loggers]) + assert any([EarlyStopping == logger.__class__ for logger in trainer.callbacks]) + assert trainer.max_epochs == 2 + temp_dir.cleanup() diff --git a/experiments/trainer_config/trainer.yaml b/experiments/trainer_config/trainer.yaml new file mode 100644 index 00000000..840744cf --- /dev/null +++ b/experiments/trainer_config/trainer.yaml @@ -0,0 
+1,28 @@ +generic: + min_epochs: 15 + max_epochs: 100 +debug: + accelerator: cpu + limit_train_batches: 10 + limit_val_batches: 10 + log_every_n_steps: 1 + max_epochs: 2 +experiment: + accelerator: gpu + strategy: ddp_find_unused_parameters_true +callbacks: + - class_path: pytorch_lightning.callbacks.EarlyStopping + init_args: + patience: 5 + monitor: val_energy + mode: min + verbose: True + check_finite: False + - class_path: pytorch_lightning.callbacks.ModelCheckpoint + init_args: + monitor: val_energy + save_top_k: 3 + - class_path: matsciml.lightning.callbacks.GradientCheckCallback + - class_path: matsciml.lightning.callbacks.SAM +loggers: + - class_path: pytorch_lightning.loggers.CSVLogger # can omit init_args['save_dir'] for auto directory diff --git a/experiments/trainer_config/trainer_config.py b/experiments/trainer_config/trainer_config.py new file mode 100644 index 00000000..cbe0e73c --- /dev/null +++ b/experiments/trainer_config/trainer_config.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any + +import pytorch_lightning as pl + +from experiments.utils.utils import instantiate_arg_dict, update_arg_dict + + +def setup_extra_trainer_args( + log_path: str, trainer_args: dict[str, Any] +) -> dict[str, Any]: + if "loggers" in trainer_args: + for logger in trainer_args["loggers"]: + if "CSVLogger" in logger["class_path"]: + logger.setdefault("init_args", {}) + if "save_dir" not in logger["init_args"]: + logger["init_args"].update({"save_dir": log_path}) + if "WandbLogger" in logger["class_path"]: + logger.setdefault("init_args", {}) + if "name" not in logger["init_args"]: + logger["init_args"].update({"name": log_path}) + return trainer_args + + +def setup_trainer( + config: dict[str, Any], trainer_args: dict[str, Any] +) -> pl.LightningModule: + run_type = config["run_type"] + trainer_args = setup_extra_trainer_args(config["log_path"], trainer_args) + trainer_args = instantiate_arg_dict(deepcopy(trainer_args)) + trainer_args = update_arg_dict("trainer", trainer_args, config["cli_args"]) + if "loggers" in trainer_args: + loggers = [] + for logger in trainer_args["loggers"]: + loggers.append(logger) + trainer_args.pop("loggers") + if "callbacks" in trainer_args: + callbacks = [] + for callback in trainer_args["callbacks"]: + callbacks.append(callback) + trainer_args.pop("callbacks") + + trainer_kwargs = trainer_args["generic"] + trainer_kwargs.update(trainer_args[run_type]) + trainer = pl.Trainer(logger=loggers, callbacks=callbacks, **trainer_kwargs) + return trainer diff --git a/experiments/training_script.py b/experiments/training_script.py new file mode 100644 index 00000000..af2be1a3 --- /dev/null +++ b/experiments/training_script.py @@ -0,0 +1,61 @@ +import os +import yaml +from typing import Any + +from experiments.datasets.data_module_config import setup_datamodule +from experiments.task_config.task_config import setup_task +from experiments.trainer_config.trainer_config import setup_trainer +from experiments.trainer_config import trainer_args + +from experiments.utils.utils import setup_log_dir, config_help + +from argparse import ArgumentParser + + +def main(config: dict[str, Any]) -> None: + os.makedirs(config["log_path"], exist_ok=True) + + dm = setup_datamodule(config) + task = setup_task(config) + trainer = setup_trainer(config, trainer_args=trainer_args) + trainer.fit(task, datamodule=dm) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument( + "-o", + "--options", + help="Show options 
for models, datasets, and targets", + action="store_true", + ) + parser.add_argument( + "-d", + "--debug", + help="Uses debug config with devsets and only a few batches per epoch.", + action="store_true", + ) + parser.add_argument( + "-e", + "--experiment_config", + help="Experiment config yaml file to use.", + ) + parser.add_argument( + "-c", + "--cli_args", + nargs="+", + help="Parameters to update via cli, such as: dataset.debug.batch_size.16", + default=None, + ) + args = parser.parse_args() + if args.options: + config_help() + os._exit(0) + config = yaml.safe_load(open(args.experiment_config)) + config["cli_args"] = ( + [arg.split(".") for arg in args.cli_args] if args.cli_args else None + ) + log_path = setup_log_dir(config) + config["log_path"] = log_path + config["run_type"] = run_type = "debug" if args.debug else "experiment" + main(config) diff --git a/experiments/utils/__init__.py b/experiments/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experiments/utils/utils.py b/experiments/utils/utils.py new file mode 100644 index 00000000..a3f3d7ac --- /dev/null +++ b/experiments/utils/utils.py @@ -0,0 +1,166 @@ +from typing import Any, Union +import os + +import matsciml # noqa: F401 +from matsciml.models.common import get_class_from_name +from matsciml.common.inspection import get_model_all_args + + +def verify_class_args(input_class, input_args): + all_args = get_model_all_args(input_class) + + for key in input_args: + assert ( + key in all_args + ), f"{key} was passed as a kwarg but does not match expected arguments." + + +def instantiate_arg_dict(input: Union[list, dict[str, Any]]) -> dict[str, Any]: + """Used to traverse through an config file and spin up any arguments that specify + a 'class_path' and optional 'init_args'. Replaces the string values with the + instantiated class. If the tag is a 'class_instance' this is simple a class which + has not been instantiated yet. + + Parameters + ---------- + input : dict[str, Any] + Input config dictionary. + + Returns + ------- + dict[str, Any] + Updated config dictionary with instantiated classes as necessary. 
+ """ + if isinstance(input, dict): + for key, value in list(input.items()): + if key == "class_instance": + return get_class_from_name(value) + if key == "class_path": + class_path = value + transform_args = {} + input_args = input.get("init_args", {}) + if isinstance(input_args, list): + for input in input_args: + transform_args.update(input) + else: + transform_args = input_args + class_path = get_class_from_name(class_path) + verify_class_args(class_path, transform_args) + return class_path(**transform_args) + if key == "encoder_class": + input[key] = get_class_from_name(value["class_path"]) + elif isinstance(value, dict) and "class_path" in value: + class_path = value["class_path"] + class_path = get_class_from_name(class_path) + input_args = value.get("init_args", {}) + verify_class_args(class_path, input_args) + input[key] = class_path(**input_args) + else: + input[key] = instantiate_arg_dict(value) + elif isinstance(input, list): + for i, item in enumerate(input): + input[i] = instantiate_arg_dict(item) + return input + + +def setup_log_dir(config: dict[str, Any]) -> str: + model = config["model"] + datasets = "_".join(list(config["dataset"].keys())) + experiment_name = "_".join([model, datasets]) + if "log_dir" in config: + log_dir = os.path.join(config["log_dir"], experiment_name) + else: + log_dir = os.path.join("experiment_logs", experiment_name) + next_version = _get_next_version(log_dir) + log_dir = os.path.join(log_dir, next_version) + return log_dir + + +def _get_next_version(root_dir: str) -> str: + if not os.path.isdir(root_dir): + os.makedirs(root_dir) + + existing_versions = [] + for d in os.listdir(root_dir): + if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): + existing_versions.append(int(d.split("_")[1])) + + if len(existing_versions) == 0: + return "version_0" + + return f"version_{max(existing_versions) + 1}" + + +def convert_string(input_str: str) -> Union[int, float, str]: + # not sure if there is a better way to do this + try: + return int(input_str) + except ValueError: + pass + try: + return float(input_str) + except ValueError: + pass + return input_str + + +def update_arg_dict( + dict_name: str, arg_dict: dict[str, Any], new_args: list[list[str]] +) -> dict[str, Any]: + """Update a config with arguments supplied from the cli. Can only update + to numeric or string values by dictionary keys. Lists such as callbacks, loggers, + or transforms are not updatable. + + Example: + + dict_name = "dataset" + arg_dict = {'debug': {'batch_size': 4, 'num_workers': 0}} + new_args = [['dataset', 'debug', 'batch_size', '20']] + + The input specifies that we are updating the arg_dict with new_args affecting the + 'dataset' config. + The dictionary keys to traverse through will be "debug" and "batch_size". + The target value to update to is '20', which will be converted to an int. + + + Parameters + ---------- + dict_name : str + Dictionary to be updated, (model, dataset, or trainer) + arg_dict : dict[str, Any] + Original dictionary + new_args : list[list[str]] + Lists of arguments used to specify dictionary to update, the arguments to + traverse through to update, and the value to update to. + + Returns + ------- + dict[str, Any] + New arg_dict with updated parameters. 
+ """ + if new_args is None: + return arg_dict + updated_arg_dict = arg_dict + new_args = [arg_list for arg_list in new_args if dict_name in arg_list] + for new_arg in new_args: + value = new_arg[-1] + for key in new_arg[1:-1]: + if key not in updated_arg_dict: + updated_arg_dict[key] = {} + if key != new_arg[-2]: + updated_arg_dict = updated_arg_dict[key] + updated_arg_dict[key] = convert_string(value) + return arg_dict + + +def config_help() -> None: + from experiments.datasets import available_data + from experiments.models import available_models + + print("Models:") + _ = [print("\t", m) for m in available_models.keys() if m != "generic"] + print() + print("Datasets and Target Keys:") + for k, v in available_data.items(): + if k != "generic": + print(f"\t{k}: {v['target_keys']}") diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 5c65baed..4a3e168c 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -337,12 +337,17 @@ def read_batch(self, batch: BatchDict) -> DataDict: sizes = [] # loop over each sample within a batch for index, sample in enumerate(temp_pos): - src_nodes, dst_nodes = batch["src_nodes"][index], batch["dst_nodes"][index] + src_nodes, dst_nodes = ( + batch["pc_src_nodes"][index], + batch["pc_dst_nodes"][index], + ) # use dst_nodes to gauge size because you will always have more # dst nodes than src nodes right now sizes.append(len(dst_nodes)) # carve out neighborhoods as dictated by the dataset/transform definition - sample_pc_pos = sample[src_nodes][None, :] - sample[dst_nodes][:, None] + sample_pc_pos = ( + sample[src_nodes.long()][None, :] - sample[dst_nodes.long()][:, None] + ) pc_pos.append(sample_pc_pos) # pad the position result pc_pos, mask = pad_point_cloud(pc_pos, max(sizes)) diff --git a/matsciml/models/dgl/gaanet/tests/test_gala.py b/matsciml/models/dgl/gaanet/tests/test_gala.py index 8f9fe095..b0550b26 100644 --- a/matsciml/models/dgl/gaanet/tests/test_gala.py +++ b/matsciml/models/dgl/gaanet/tests/test_gala.py @@ -42,8 +42,8 @@ def data(): ), # this should be 3 point clouds "pc_features": torch.rand(3, 4, 4, 200), - "src_nodes": [torch.arange(3), torch.arange(3), torch.arange(4)], - "dst_nodes": [torch.arange(3), torch.arange(3), torch.arange(4)], + "pc_src_nodes": [torch.arange(3), torch.arange(3), torch.arange(4)], + "pc_dst_nodes": [torch.arange(3), torch.arange(3), torch.arange(4)], "sizes": [3, 3, 4], } return data diff --git a/pyproject.toml b/pyproject.toml index 8cc6a83a..2cd71889 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ symmetry = [ ] [tool.setuptools.packages.find] -include = ["matsciml*"] +include = ["matsciml*", "experiments*"] where = ["."] [tool.setuptools.dynamic]