forked from IntelLabs/matsciml
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request IntelLabs#191 from laserkelvin/noisy-displacement-…
…pretraining Noisy node positions pretraining task
- Loading branch information
Showing
8 changed files
with
479 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from __future__ import annotations | ||
|
||
import pytorch_lightning as pl | ||
|
||
from matsciml.datasets.transforms import ( | ||
PointCloudToGraphTransform, | ||
PeriodicPropertiesTransform, | ||
NoisyPositions, | ||
) | ||
from matsciml.lightning.data_utils import MatSciMLDataModule | ||
from matsciml.models.base import NodeDenoisingTask | ||
from matsciml.models.pyg import EGNN | ||
|
||
""" | ||
This example script shows EGNN being used for a denoising | ||
pretraining task, as described in: | ||
Pre-training via denoising for molecular property prediction | ||
by Zaidi _et al._, ICLR 2023; https://openreview.net/pdf?id=tYIMtogyee | ||
The permutation of transforms is not fully invariant, based on how | ||
it is currently implemented. This configuration is the recommended one, | ||
where positions are noised _after_ generating the periodic properties; | ||
this is to ensure that the periodic offsets are generated based on | ||
the noise-free positions. | ||
""" | ||
|
||
# construct IS2RE relaxed energy regression with PyG implementation of E(n)-GNN | ||
task = NodeDenoisingTask( | ||
encoder_class=EGNN, | ||
encoder_kwargs={"hidden_dim": 128, "output_dim": 64}, | ||
) | ||
# set up the data module | ||
dm = MatSciMLDataModule.from_devset( | ||
"AlexandriaDataset", | ||
dset_kwargs={ | ||
"transforms": [ | ||
PeriodicPropertiesTransform(6.0, True), | ||
NoisyPositions( | ||
scale=1e-3 | ||
), # this sets the scale of the Gaussian noise added | ||
PointCloudToGraphTransform( | ||
"pyg", | ||
node_keys=[ | ||
"pos", | ||
"noisy_pos", | ||
"atomic_numbers", | ||
], # ensure noisy_pos is included for the task | ||
), | ||
], | ||
}, | ||
) | ||
|
||
# run a quick training loop | ||
trainer = pl.Trainer(fast_dev_run=10) | ||
trainer.fit(task, datamodule=dm) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from __future__ import annotations | ||
|
||
|
||
from matsciml.datasets.transforms.pretraining.noisy_positions import NoisyPositions | ||
|
||
|
||
__all__ = ["NoisyPositions"] |
66 changes: 66 additions & 0 deletions
66
matsciml/datasets/transforms/pretraining/noisy_positions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from __future__ import annotations | ||
|
||
import torch | ||
from dgl import DGLGraph | ||
|
||
from matsciml.common.types import DataDict | ||
from matsciml.datasets.transforms.base import AbstractDataTransform | ||
|
||
|
||
__all__ = ["NoisyPositions"] | ||
|
||
|
||
class NoisyPositions(AbstractDataTransform):
    def __init__(self, scale: float = 1e-3) -> None:
        """
        Initializes a NoisyPositions transform.

        This class generates i.i.d. Gaussian displacements to atom
        coordinates, and adds a new ``noisy_pos`` key to the data
        sample. While there is no prescribed ordering of transforms,
        if graphs are available in the sample the transform will
        act on the graph over the raw point cloud positions. If
        your pipeline involves graph creation, note that this _could_
        affect the resulting edges produced, depending on the scale of
        noise used.

        Implemented from the strategy described by Zaidi _et al._ 2023,
        https://openreview.net/pdf?id=tYIMtogyee

        Parameters
        ----------
        scale : float
            Scale used to multiply N~(0, I_3) Gaussian noise
        """
        super().__init__()
        self.scale = scale

    def __call__(self, data: DataDict) -> DataDict:
        """
        Perturb atomic positions with scaled Gaussian noise.

        Reads positions from the graph when one is present (DGL node
        data or PyG attribute), otherwise from the raw point cloud,
        then stores the perturbed coordinates under ``noisy_pos`` in
        the same location. The noise itself is registered as the
        ``denoise`` target under the ``pretraining`` target type so
        downstream tasks know what to regress against.
        """
        graph = data["graph"] if "graph" in data else None
        # locate the clean coordinates: DGL keeps node features in
        # ``ndata``; otherwise assume a PyG graph exposing ``pos``
        if graph is None:
            clean_pos = data["pos"]
        elif isinstance(graph, DGLGraph):
            clean_pos = graph.ndata["pos"]
        else:
            clean_pos = graph.pos
        displacement = torch.randn_like(clean_pos) * self.scale
        perturbed = clean_pos + displacement
        # write the noisy coordinates back to wherever we read from
        if graph is None:
            data["noisy_pos"] = perturbed
        elif isinstance(graph, DGLGraph):
            graph.ndata["noisy_pos"] = perturbed
        else:
            graph.noisy_pos = perturbed
        # set targets so that tasks know what to do
        data["targets"]["denoise"] = displacement
        target_types = data["target_types"]
        if "pretraining" in target_types:
            target_types["pretraining"].append("denoise")
        else:
            target_types["pretraining"] = ["denoise"]
        return data
Empty file.
110 changes: 110 additions & 0 deletions
110
matsciml/datasets/transforms/pretraining/tests/test_noisy_positions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from __future__ import annotations | ||
|
||
import pytest | ||
import torch | ||
from dgl import DGLGraph | ||
|
||
from matsciml.common.registry import registry | ||
from matsciml.lightning import MatSciMLDataModule | ||
from matsciml.datasets.transforms.pretraining import NoisyPositions | ||
from matsciml.datasets.transforms import ( | ||
PeriodicPropertiesTransform, | ||
PointCloudToGraphTransform, | ||
) | ||
|
||
|
||
dset_names = registry.__entries__["datasets"].keys() | ||
valid_dsets = list( | ||
filter( | ||
lambda x: all( | ||
[match not in x for match in ["PyG", "Multi", "Cdvae", "PointGroup"]] | ||
), | ||
dset_names, | ||
) | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("dset_name", valid_dsets) | ||
def test_noisy_pointcloud(dset_name): | ||
"""Test the transform on the raw point clouds""" | ||
dset_class = registry.get_dataset_class(dset_name) | ||
dset = dset_class.from_devset(transforms=[NoisyPositions()]) | ||
for index in range(10): | ||
sample = dset.__getitem__(index) | ||
assert "noisy_pos" in sample | ||
assert torch.isfinite(sample["noisy_pos"]).all() | ||
assert "pretraining" in sample["target_types"] | ||
assert "denoise" in sample["target_types"]["pretraining"] | ||
|
||
|
||
@pytest.mark.parametrize("dset_name", valid_dsets) | ||
@pytest.mark.parametrize("graph_type", ["pyg", "dgl"]) | ||
def test_noisy_graph(dset_name, graph_type): | ||
"""Test the transform on graph types.""" | ||
dset_class = registry.get_dataset_class(dset_name) | ||
dset = dset_class.from_devset( | ||
transforms=[ | ||
PeriodicPropertiesTransform(6.0, adaptive_cutoff=True), | ||
NoisyPositions(), | ||
PointCloudToGraphTransform( | ||
graph_type, node_keys=["atomic_numbers", "pos", "noisy_pos"] | ||
), | ||
] | ||
) | ||
for index in range(10): | ||
sample = dset.__getitem__(index) | ||
graph = sample["graph"] | ||
if isinstance(graph, DGLGraph): | ||
target = graph.ndata | ||
else: | ||
target = graph | ||
assert "noisy_pos" in target | ||
assert torch.isfinite(target["noisy_pos"]).all() | ||
assert "pretraining" in sample["target_types"] | ||
assert "denoise" in sample["target_types"]["pretraining"] | ||
|
||
|
||
@pytest.mark.parametrize("dset_name", valid_dsets) | ||
def test_noisy_pointcloud_datamodule(dset_name): | ||
"""Test the transform on point cloud types with batching.""" | ||
dm = MatSciMLDataModule.from_devset( | ||
dset_name, batch_size=4, dset_kwargs={"transforms": [NoisyPositions()]} | ||
) | ||
dm.setup("fit") | ||
loader = dm.train_dataloader() | ||
batch = next(iter(loader)) | ||
assert "noisy_pos" in batch | ||
assert torch.isfinite(batch["noisy_pos"]).all() | ||
assert "pretraining" in batch["target_types"] | ||
assert "denoise" in batch["target_types"]["pretraining"] | ||
|
||
|
||
@pytest.mark.parametrize("dset_name", valid_dsets) | ||
@pytest.mark.parametrize("graph_type", ["pyg", "dgl"]) | ||
def test_noisy_graph_datamodule(dset_name, graph_type): | ||
"""Test the transform on graph types with batching.""" | ||
dm = MatSciMLDataModule.from_devset( | ||
dset_name, | ||
dset_kwargs=dict( | ||
transforms=[ | ||
PeriodicPropertiesTransform(6.0, adaptive_cutoff=True), | ||
NoisyPositions(), | ||
PointCloudToGraphTransform( | ||
graph_type, node_keys=["atomic_numbers", "pos", "noisy_pos"] | ||
), | ||
], | ||
), | ||
batch_size=4, | ||
) | ||
dm.setup("fit") | ||
loader = dm.train_dataloader() | ||
batch = next(iter(loader)) | ||
graph = batch["graph"] | ||
if isinstance(graph, DGLGraph): | ||
target = graph.ndata | ||
else: | ||
target = graph | ||
assert "noisy_pos" in target | ||
assert torch.isfinite(target["noisy_pos"]).all() | ||
assert "pretraining" in batch["target_types"] | ||
assert "denoise" in batch["target_types"]["pretraining"] |
Oops, something went wrong.