Quality of life and helper callback functions #237

Merged · 32 commits · Jul 1, 2024

Commits
7abff49
refactor: ruff fixes and adding fractional coordinate check
laserkelvin Jun 6, 2024
0e4e8bc
feat: added embedding forward hook check
laserkelvin Jun 6, 2024
a433685
feat: added encoder forward hook to helper
laserkelvin Jun 6, 2024
7759f09
feat: added encoder-outputhead compaison
laserkelvin Jun 6, 2024
7221eff
feat: working grad norm logging
laserkelvin Jun 6, 2024
17297f2
refactor: changing variance value to a much smaller value
laserkelvin Jun 6, 2024
4372d75
docs: adding docstrings throughout helper
laserkelvin Jun 6, 2024
1a84d16
feat: added function to log embeddings
laserkelvin Jun 6, 2024
552cf1b
feat: adding embedding logging call
laserkelvin Jun 6, 2024
9922415
refactor: using hparams for log embedding kwarg
laserkelvin Jun 6, 2024
2349305
fix: adding global step specification in tensorboard embedding log
laserkelvin Jun 6, 2024
1a749be
refactor: adding global step to add embedding
laserkelvin Jun 6, 2024
cac78fe
refactor: making forward generically stash embeddings
laserkelvin Jun 7, 2024
53d640b
refactor: putting embedding logging in train step
laserkelvin Jun 7, 2024
d7e1f07
refactor: adding embedding logging to independent steps
laserkelvin Jun 7, 2024
997a1be
chore: rebasing main to finalize PR
laserkelvin Jul 1, 2024
1640ef8
refactor: added log embedding frequency control
laserkelvin Jun 7, 2024
a9dd125
refactor: cleaned up log_embeddings function
laserkelvin Jun 7, 2024
7186df6
test: added simple unit test for forward hook
laserkelvin Jun 7, 2024
44150f5
feat: implemented working autocorrelation callback
laserkelvin Jun 7, 2024
8eb876f
docs: added a variety of docstrings for model autocorrelation
laserkelvin Jun 7, 2024
9c48a5c
docs: added docstring for helper callback
laserkelvin Jun 7, 2024
a8ccae8
scripts: added scripts to demonstrate callback usage
laserkelvin Jun 7, 2024
28e1a33
refactor: using cartesian coordinates as regular inputs
laserkelvin Jun 7, 2024
af745c0
refactor: taking absolute value of the median for comparison
laserkelvin Jun 12, 2024
515e4b8
refactor: now looping over multiple loggers, if any are supplied
laserkelvin Jun 12, 2024
7f726b9
refactor: making forward pass only set embeddings to batch, not reusi…
laserkelvin Jun 12, 2024
b888ee8
fix: nesting scheduler stepping mechanism only if something is passed
laserkelvin Jul 1, 2024
c182d90
refactor: allowing multi task subtasks to reuse shared embedding
laserkelvin Jul 1, 2024
2ca1367
refactor: adding log embeddings kwargs to multi task litmodule
laserkelvin Jul 1, 2024
43175b4
feat: added log embeddings method to multitask litmodule
laserkelvin Jul 1, 2024
d53cde4
refactor: logging embeddings for both training and validation steps i…
laserkelvin Jul 1, 2024
65 changes: 65 additions & 0 deletions examples/callbacks/autocorrelation.py
@@ -0,0 +1,65 @@
from __future__ import annotations

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from matsciml.datasets.transforms import DistancesTransform, PointCloudToGraphTransform
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.lightning.callbacks import ModelAutocorrelation
from matsciml.models import SchNet
from matsciml.models.base import ScalarRegressionTask

"""
This script demonstrates the use of the `ModelAutocorrelation` callback.

The main utility of this callback is to monitor the degree of correlation
in model parameters and optionally gradients over a time span. The idea
is that for optimization trajectories, steps are ideally as de-correlated
as possible (at least within reason), and indeed is actually a major
assumption of Adam-like optimizers.

There is no hard coded heuristic for identifying "too much correlation"
yet, however this callback can help do the data collection for you to
develop a sense for yourself. One method for trying this out is to
set varying learning rates, and seeing how the autocorrelation spectra
are different.
"""

# construct a scalar regression task with SchNet encoder
task = ScalarRegressionTask(
    encoder_class=SchNet,
    # kwargs to be passed into the creation of SchNet model
    encoder_kwargs={
        "encoder_only": True,
        "hidden_feats": [128, 128, 128],
        "atom_embedding_dim": 128,
    },
    # which keys to use as targets
    task_keys=["energy_relaxed"],
    log_embeddings=False,
)
# Use IS2RE devset to test workflow
# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances
dm = MatSciMLDataModule.from_devset(
"IS2REDataset",
dset_kwargs={
"transforms": [
PointCloudToGraphTransform(
"dgl",
cutoff_dist=20.0,
node_keys=["pos", "atomic_numbers"],
),
DistancesTransform(),
],
},
)

# tensorboard logging if working purely locally, otherwise wandb
logger = WandbLogger(
    name="helper-callback", offline=False, project="matsciml", log_model="all"
)
# this assignment overrides the wandb logger above; comment it out to log to wandb instead
logger = TensorBoardLogger("./")

# run a quick training loop
trainer = pl.Trainer(max_epochs=30, logger=logger, callbacks=[ModelAutocorrelation()])
trainer.fit(task, datamodule=dm)
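As a rough illustration of the quantity this callback tracks (the sketch below is not part of the PR and is not the callback's actual implementation; `snapshots` and `max_lag` are hypothetical names), one way to estimate an autocorrelation spectrum from a window of flattened parameter snapshots is:

```python
from __future__ import annotations

import torch


def autocorrelation_spectrum(snapshots: list[torch.Tensor], max_lag: int = 10) -> torch.Tensor:
    """Estimate lag-wise autocorrelation over a window of flattened parameter snapshots."""
    # stack into a (steps, num_params) matrix and center each parameter trace
    traces = torch.stack(snapshots)
    traces = traces - traces.mean(dim=0, keepdim=True)
    denom = (traces * traces).sum(dim=0) + 1e-12
    spectrum = []
    for lag in range(1, max_lag + 1):
        # correlation between each trace and a lagged copy of itself, averaged over parameters
        num = (traces[lag:] * traces[:-lag]).sum(dim=0)
        spectrum.append((num / denom).mean())
    return torch.stack(spectrum)


# snapshots would be collected once per optimizer step, e.g.
# snapshots.append(torch.cat([p.detach().flatten() for p in model.parameters()]))
```

Comparing spectra like this across runs with different learning rates is the kind of experiment the docstring above suggests.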
57 changes: 57 additions & 0 deletions examples/callbacks/helper.py
@@ -0,0 +1,57 @@
from __future__ import annotations

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from matsciml.datasets.transforms import DistancesTransform, PointCloudToGraphTransform
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.lightning.callbacks import TrainingHelperCallback
from matsciml.models import SchNet
from matsciml.models.base import ScalarRegressionTask

"""
This script demonstrates the use of the ``TrainingHelperCallback``
callback. The purpose of this callback is to provide some
helpful heuristics into the training process by identifying
some common issues like unused weights, small gradients,
and oversmoothed embeddings.
"""

# construct a scalar regression task with SchNet encoder
task = ScalarRegressionTask(
    encoder_class=SchNet,
    # kwargs to be passed into the creation of SchNet model
    encoder_kwargs={
        "encoder_only": True,
        "hidden_feats": [128, 128, 128],
        "atom_embedding_dim": 128,
    },
    # which keys to use as targets
    task_keys=["energy_relaxed"],
    log_embeddings=True,
)
# Use IS2RE devset to test workflow
# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances
dm = MatSciMLDataModule.from_devset(
"IS2REDataset",
dset_kwargs={
"transforms": [
PointCloudToGraphTransform(
"dgl",
cutoff_dist=20.0,
node_keys=["pos", "atomic_numbers"],
),
DistancesTransform(),
],
},
)

# tensorboard logging if working purely locally
# logger = TensorBoardLogger("./")
logger = WandbLogger(
name="helper-callback", offline=False, project="matsciml", log_model="all"
)

# run a quick training loop
trainer = pl.Trainer(max_epochs=10, logger=logger, callbacks=[TrainingHelperCallback()])
trainer.fit(task, datamodule=dm)
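The docstring above mentions heuristics for unused weights and small gradients. The following is a minimal sketch (not part of the PR, and not the callback's actual implementation) of how such post-backward checks might look:

```python
import torch
from torch import nn


def report_gradient_issues(model: nn.Module, small_grad_thresh: float = 1e-6) -> None:
    """Flag parameters that received no gradient, or only a vanishingly small one."""
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            # never touched by backward: likely an unused weight in the forward pass
            print(f"[helper] {name} appears unused (no gradient)")
        elif param.grad.norm() < small_grad_thresh:
            print(f"[helper] {name} has a very small gradient norm ({param.grad.norm().item():.2e})")
```

Hooked into Lightning's `on_after_backward`, a check like this would run once per training step and report suspicious parameter names.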
15 changes: 11 additions & 4 deletions matsciml/datasets/utils.py
@@ -9,6 +9,7 @@

import lmdb
import torch
+import numpy as np
from einops import einsum, rearrange
from joblib import Parallel, delayed
from pymatgen.core import Lattice, Structure
@@ -302,11 +303,11 @@ def get_lmdb_keys(
            "Both `ignore_keys` and `_lambda` were passed; arguments are mutually exclusive.",
        )
    if ignore_keys:
-        _lambda = lambda x: x not in ignore_keys
+        _lambda = lambda x: x not in ignore_keys  # noqa: E731
    else:
        if not _lambda:
            # escape case where we basically don't filter
-            _lambda = lambda x: x
+            _lambda = lambda x: x  # noqa: E731
    # convert to a sorted list of keys
    keys = sorted(list(filter(_lambda, keys)))
    return keys
@@ -529,7 +530,7 @@ def divide_data_chunks(
    assert all(
        [length != 0 for length in lengths],
    ), "Too many processes specified and not enough data to split over multiple LMDB files. Decrease `num_procs!`"
-    p = Parallel(num_procs)(
+    _ = Parallel(num_procs)(
        delayed(write_chunk)(chunk, target_dir, index, metadata)
        for chunk, index in zip(chunks, lmdb_indices)
    )
@@ -693,6 +694,11 @@ def calculate_periodic_shifts(
        include_index=True,
        include_image=True,
    )
+    # check to make sure the cell definition is valid
+    if np.any(structure.frac_coords > 1.0):
+        raise ValueError(
+            f"Structure has fractional coordinates greater than 1! Check structure:\n{structure}"
+        )

    def _all_sites_have_neighbors(neighbors):
        return all([len(n) for n in neighbors])
@@ -729,12 +735,13 @@ def _all_sites_have_neighbors(neighbors):
    cell = torch.from_numpy(cell.copy()).float()
    # get coordinates as well, for standardization
    frac_coords = torch.from_numpy(structure.frac_coords).float()
+    coords = torch.from_numpy(structure.cart_coords).float()
    return_dict = {
        "src_nodes": torch.LongTensor(all_src),
        "dst_nodes": torch.LongTensor(all_dst),
        "images": torch.FloatTensor(all_images),
        "cell": cell,
-        "pos": frac_coords,
+        "pos": coords,
    }
    # now calculate offsets based on each image for a lattice
    return_dict["offsets"] = einsum(return_dict["images"], cell, "v i, n i j -> v j")