Skip to content

Commit

Permalink
Merge pull request IntelLabs#191 from laserkelvin/noisy-displacement-pretraining
Browse files Browse the repository at this point in the history

Noisy node positions pretraining task
  • Loading branch information
laserkelvin authored Apr 23, 2024
2 parents 6f0a9e1 + 01bc048 commit 79b8815
Show file tree
Hide file tree
Showing 8 changed files with 479 additions and 95 deletions.
57 changes: 57 additions & 0 deletions examples/tasks/pretraining/denoising_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import pytorch_lightning as pl

from matsciml.datasets.transforms import (
    PointCloudToGraphTransform,
    PeriodicPropertiesTransform,
    NoisyPositions,
)
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models.base import NodeDenoisingTask
from matsciml.models.pyg import EGNN

"""
This example script shows EGNN being used for a denoising
pretraining task, as described in:
Pre-training via denoising for molecular property prediction
by Zaidi _et al._, ICLR 2023; https://openreview.net/pdf?id=tYIMtogyee
The permutation of transforms is not fully invariant, based on how
it is currently implemented. This configuration is the recommended one,
where positions are noised _after_ generating the periodic properties;
this is to ensure that the periodic offsets are generated based on
the noise-free positions.
"""

# configure the node denoising pretraining task, using the PyG
# implementation of E(n)-GNN as the encoder
task = NodeDenoisingTask(
    encoder_class=EGNN,
    encoder_kwargs={"hidden_dim": 128, "output_dim": 64},
)
# set up the data module on the Alexandria devset; transform order
# matters here — noise is added after the periodic properties are
# computed, but before graph construction
dm = MatSciMLDataModule.from_devset(
    "AlexandriaDataset",
    dset_kwargs={
        "transforms": [
            PeriodicPropertiesTransform(6.0, True),
            NoisyPositions(
                scale=1e-3
            ),  # this sets the scale of the Gaussian noise added
            PointCloudToGraphTransform(
                "pyg",
                node_keys=[
                    "pos",
                    "noisy_pos",
                    "atomic_numbers",
                ],  # ensure noisy_pos is included for the task
            ),
        ],
    },
)

# run a quick training loop (10 fast-dev batches as a smoke test)
trainer = pl.Trainer(fast_dev_run=10)
trainer.fit(task, datamodule=dm)
1 change: 1 addition & 0 deletions matsciml/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from matsciml.datasets.transforms.representations import *
from matsciml.datasets.transforms.matgl_datasets import *
from matsciml.datasets.transforms.frame_averaging import *
from matsciml.datasets.transforms.pretraining import *
7 changes: 7 additions & 0 deletions matsciml/datasets/transforms/pretraining/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from __future__ import annotations


from matsciml.datasets.transforms.pretraining.noisy_positions import NoisyPositions


__all__ = ["NoisyPositions"]
66 changes: 66 additions & 0 deletions matsciml/datasets/transforms/pretraining/noisy_positions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

import torch
from dgl import DGLGraph

from matsciml.common.types import DataDict
from matsciml.datasets.transforms.base import AbstractDataTransform


__all__ = ["NoisyPositions"]


class NoisyPositions(AbstractDataTransform):
    def __init__(self, scale: float = 1e-3) -> None:
        """
        Initializes a NoisyPositions transform.

        This class generates i.i.d. Gaussian displacements to atom
        coordinates, and adds a new ``noisy_pos`` key to the data
        sample. While there is no prescribed ordering of transforms,
        if graphs are available in the sample the transform will
        act on the graph over the raw point cloud positions. If
        your pipeline involves graph creation, note that this _could_
        affect the resulting edges produced, depending on the scale of
        noise used.

        Implemented from the strategy described by Zaidi _et al._ 2023,
        https://openreview.net/pdf?id=tYIMtogyee

        Parameters
        ----------
        scale : float
            Scale used to multiply N~(0, I_3) Gaussian noise
        """
        super().__init__()
        self.scale = scale

    def __call__(self, data: DataDict) -> DataDict:
        """
        Add scaled Gaussian noise to node positions, and register the
        noise tensor as the ``denoise`` pretraining target.

        Parameters
        ----------
        data : DataDict
            Data sample; either a point cloud carrying a ``pos`` key,
            or one holding a DGL/PyG ``graph`` with node positions.

        Returns
        -------
        DataDict
            The same sample, with ``noisy_pos`` written next to
            wherever the positions were read from, and the
            ``targets``/``target_types`` entries updated.
        """
        # bind the graph once so the read and write branches below
        # cannot drift apart (the original duplicated this lookup)
        graph = data.get("graph")
        if graph is not None:
            if isinstance(graph, DGLGraph):
                pos = graph.ndata["pos"]
            else:
                # assume it's a PyG graph, grab positions as attribute
                pos = graph.pos
        else:
            # otherwise it's a point cloud
            pos = data["pos"]
        noise = torch.randn_like(pos) * self.scale
        noisy_pos = pos + noise
        # write the noisy node data back to the same container the
        # positions came from
        if graph is not None:
            if isinstance(graph, DGLGraph):
                graph.ndata["noisy_pos"] = noisy_pos
            else:
                graph.noisy_pos = noisy_pos
        else:
            data["noisy_pos"] = noisy_pos
        # set targets so that tasks know what to do; setdefault makes
        # this robust to samples that lack the keys entirely (the
        # previous indexing raised KeyError in that case)
        data.setdefault("targets", {})["denoise"] = noise
        data.setdefault("target_types", {}).setdefault("pretraining", []).append(
            "denoise"
        )
        return data
Empty file.
110 changes: 110 additions & 0 deletions matsciml/datasets/transforms/pretraining/tests/test_noisy_positions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import annotations

import pytest
import torch
from dgl import DGLGraph

from matsciml.common.registry import registry
from matsciml.lightning import MatSciMLDataModule
from matsciml.datasets.transforms.pretraining import NoisyPositions
from matsciml.datasets.transforms import (
PeriodicPropertiesTransform,
PointCloudToGraphTransform,
)


# enumerate registered datasets, skipping variants the noisy-positions
# transform is not exercised against
dset_names = registry.__entries__["datasets"].keys()
_skip_tokens = ("PyG", "Multi", "Cdvae", "PointGroup")
valid_dsets = [
    name for name in dset_names if not any(token in name for token in _skip_tokens)
]


@pytest.mark.parametrize("dset_name", valid_dsets)
def test_noisy_pointcloud(dset_name):
    """Raw point cloud samples should gain finite noisy positions and a denoise target."""
    dset_class = registry.get_dataset_class(dset_name)
    dset = dset_class.from_devset(transforms=[NoisyPositions()])
    for idx in range(10):
        sample = dset[idx]
        assert "noisy_pos" in sample
        assert torch.isfinite(sample["noisy_pos"]).all()
        assert "pretraining" in sample["target_types"]
        assert "denoise" in sample["target_types"]["pretraining"]


@pytest.mark.parametrize("dset_name", valid_dsets)
@pytest.mark.parametrize("graph_type", ["pyg", "dgl"])
def test_noisy_graph(dset_name, graph_type):
    """Graph samples should carry finite noisy node positions and a denoise target."""
    transforms = [
        PeriodicPropertiesTransform(6.0, adaptive_cutoff=True),
        NoisyPositions(),
        PointCloudToGraphTransform(
            graph_type, node_keys=["atomic_numbers", "pos", "noisy_pos"]
        ),
    ]
    dset = registry.get_dataset_class(dset_name).from_devset(transforms=transforms)
    for idx in range(10):
        sample = dset[idx]
        graph = sample["graph"]
        # DGL keeps node features in ndata; PyG exposes them on the graph itself
        node_data = graph.ndata if isinstance(graph, DGLGraph) else graph
        assert "noisy_pos" in node_data
        assert torch.isfinite(node_data["noisy_pos"]).all()
        assert "pretraining" in sample["target_types"]
        assert "denoise" in sample["target_types"]["pretraining"]


@pytest.mark.parametrize("dset_name", valid_dsets)
def test_noisy_pointcloud_datamodule(dset_name):
    """Batched point clouds should preserve the denoising keys."""
    dm = MatSciMLDataModule.from_devset(
        dset_name,
        batch_size=4,
        dset_kwargs={"transforms": [NoisyPositions()]},
    )
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    assert "noisy_pos" in batch
    assert torch.isfinite(batch["noisy_pos"]).all()
    assert "pretraining" in batch["target_types"]
    assert "denoise" in batch["target_types"]["pretraining"]


@pytest.mark.parametrize("dset_name", valid_dsets)
@pytest.mark.parametrize("graph_type", ["pyg", "dgl"])
def test_noisy_graph_datamodule(dset_name, graph_type):
    """Batched graphs should preserve the denoising keys."""
    transforms = [
        PeriodicPropertiesTransform(6.0, adaptive_cutoff=True),
        NoisyPositions(),
        PointCloudToGraphTransform(
            graph_type, node_keys=["atomic_numbers", "pos", "noisy_pos"]
        ),
    ]
    dm = MatSciMLDataModule.from_devset(
        dset_name,
        dset_kwargs=dict(transforms=transforms),
        batch_size=4,
    )
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    graph = batch["graph"]
    # DGL keeps node features in ndata; PyG exposes them on the graph itself
    node_data = graph.ndata if isinstance(graph, DGLGraph) else graph
    assert "noisy_pos" in node_data
    assert torch.isfinite(node_data["noisy_pos"]).all()
    assert "pretraining" in batch["target_types"]
    assert "denoise" in batch["target_types"]["pretraining"]
Loading

0 comments on commit 79b8815

Please sign in to comment.