Skip to content

Commit

Permalink
Merge pull request IntelLabs#212 from laserkelvin/lightning-cli-revival
Browse files Browse the repository at this point in the history
Lightning CLI revival
  • Loading branch information
laserkelvin authored May 13, 2024
2 parents 65661cd + 608fd1f commit e7f9577
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 56 deletions.
15 changes: 15 additions & 0 deletions examples/lightning_cli/data.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
data:
class_path: matsciml.lightning.data_utils.MatSciMLDataModule
init_args:
dataset: "MaterialsProjectDataset"
train_path: null # this needs to be changed to match your dataset path!
dset_kwargs:
transforms:
- class_path: matsciml.datasets.transforms.PeriodicPropertiesTransform
init_args:
cutoff_radius: 6.0
adaptive_cutoff: true
- class_path: matsciml.datasets.transforms.PointCloudToGraphTransform
init_args:
backend: "pyg"
node_keys: ["pos", "atomic_numbers"]
9 changes: 9 additions & 0 deletions examples/lightning_cli/egnn.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
model:
class_path: matsciml.models.base.ScalarRegressionTask
init_args:
encoder_class: matsciml.models.pyg.EGNN
encoder_kwargs:
hidden_dim: 128
output_dim: 64
task_keys: # this matches the dataset
- "band_gap"
10 changes: 10 additions & 0 deletions examples/lightning_cli/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
# This uses the matsciml Lightning CLI wrapper to configure
# the training workflow. The advantage of doing so is the
# ability to modularize experiments, i.e. not have to redefine
# datasets, models, and/or trainer control.
python -m matsciml.lightning.cli fit \
--config data.yml \
--config egnn.yml \
--config trainer.yml \
--trainer.fast_dev_run 20 # override some config if we need to
16 changes: 16 additions & 0 deletions examples/lightning_cli/trainer.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# pytorch_lightning==2.2.4
seed_everything: true
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: null
logger: null
callbacks: null
fast_dev_run: 10
max_epochs: null
min_epochs: null
gradient_clip_val: null
gradient_clip_algorithm: null
inference_mode: false
41 changes: 4 additions & 37 deletions matsciml/lightning/cli.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,11 @@
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022-4 Intel Corporation
# SPDX-License-Identifier: MIT License
from __future__ import annotations

import pytorch_lightning as pl
from pytorch_lightning.utilities.cli import (
DATAMODULE_REGISTRY,
MODEL_REGISTRY,
LightningCLI,
)
from pytorch_lightning.cli import LightningCLI

from matsciml import models
from matsciml.lightning import data_utils

"""
This module interfaces with the PyTorch Lightning CLI, and when called, allows
the user to define YAML configuration files for reproducible and modular
training, testing, and development.
All that is really done in this module is inform the PyTorch Lightning registry
where to look for models (`MODEL_REGISTRY`) and data (`DATAMODULE_REGISTRY`).
The former is set up to look for children of `LightningModule`, which comprises
the task `LightningModule`s, and models like DimeNetPP that are implemented with
`AbstractTask`, which in turn also inherits from `LightningModule` (this might change later).
To check what tasks and data modules have been successfully registered, import this
module, and print `MODEL_REGISTRY` and/or `DATAMODULE_REGISTRY`: if your model was
included in the namespace correctly, it should appear there.
To use the CLI, all one needs to do is write a YAML configuration file, and then
run `python -m matsciml.lightning.cli fit --config <CONFIG>.yml`, substituting
fit with any other appropriate task, and <CONFIG> with the name of your configuration
file.
"""


# this registers the task classes implemented as LightningModules
MODEL_REGISTRY.register_classes(models, pl.LightningModule)

# this registers the data
DATAMODULE_REGISTRY.register_classes(data_utils, pl.LightningDataModule)
from matsciml.models import *
from matsciml.lightning import data_utils # noqa: F401


if __name__ == "__main__":
Expand Down
18 changes: 0 additions & 18 deletions matsciml/lightning/registry.py

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"lmdb==1.4.1",
"geometric-algebra-attention>=0.3.0",
"sympy>=1.10.1",
"jsonargparse[signatures]>=4.13.1",
"jsonargparse[signatures]>=4.27.7",
"mp_api==0.41.2",
"emmet-core==0.83.6",
"pydantic==2.7.1",
Expand Down

0 comments on commit e7f9577

Please sign in to comment.