From ab2cbcdb5b49efc8291137973349c900e77ed08e Mon Sep 17 00:00:00 2001 From: "Gonzales, Carmelo" Date: Mon, 21 Oct 2024 10:49:58 -0700 Subject: [PATCH 1/4] feat: adding ase backend for computing periodic boundary conditions --- matsciml/datasets/tests/test_transforms.py | 20 ++++++- matsciml/datasets/transforms/pbc.py | 27 +++++++-- matsciml/datasets/utils.py | 69 ++++++++++++++++++++++ 3 files changed, 110 insertions(+), 6 deletions(-) diff --git a/matsciml/datasets/tests/test_transforms.py b/matsciml/datasets/tests/test_transforms.py index b17b8faa..6f0c0302 100644 --- a/matsciml/datasets/tests/test_transforms.py +++ b/matsciml/datasets/tests/test_transforms.py @@ -177,5 +177,23 @@ def test_graph_sorting(): ) dm.setup() loader = dm.train_dataloader() - graph = next(iter(loader))["graph"] + _ = next(iter(loader))["graph"] # not really anything to test, but just make sure it runs :D + + +@pytest.mark.parametrize("backend", ["ase", "pymatgen"]) +def test_ase_periodic(backend): + trans = [ + transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend=backend + ) + ] + dm = MatSciMLDataModule.from_devset( + "S2EFDataset", + dset_kwargs={"transforms": trans}, + ) + dm.setup() + loader = dm.train_dataloader() + batch = next(iter(loader)) + # check if periodic properties transform was applied + assert "unit_offsets" in batch diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index b8f9de0b..5d9ed271 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -1,13 +1,14 @@ from __future__ import annotations -import torch import numpy as np +import torch from pymatgen.core import Lattice, Structure from matsciml.common.types import DataDict from matsciml.datasets.transforms.base import AbstractDataTransform from matsciml.datasets.utils import ( calculate_periodic_shifts, + calculate_ase_periodic_shifts, make_pymatgen_periodic_structure, ) @@ -30,10 +31,16 @@ class PeriodicPropertiesTransform(AbstractDataTransform): a large cut off for the entire dataset. """ - def __init__(self, cutoff_radius: float, adaptive_cutoff: bool = False) -> None: + def __init__( + self, + cutoff_radius: float, + adaptive_cutoff: bool = False, + backend: str = "pymatgen", + ) -> None: super().__init__() self.cutoff_radius = cutoff_radius self.adaptive_cutoff = adaptive_cutoff + self.backend = backend def __call__(self, data: DataDict) -> DataDict: """ @@ -107,13 +114,23 @@ def __call__(self, data: DataDict) -> DataDict: tuple(angle * (180.0 / torch.pi) for angle in angles), ) lattice = Lattice.from_parameters(*abc, *angles, vesta=True) + # We need cell in data for ase backend. + data["cell"] = torch.tensor(lattice.matrix).unsqueeze(0).float() + structure = make_pymatgen_periodic_structure( data["atomic_numbers"], data["pos"], lattice=lattice, ) - graph_props = calculate_periodic_shifts( - structure, self.cutoff_radius, self.adaptive_cutoff - ) + if self.backend == "pymatgen": + graph_props = calculate_periodic_shifts( + structure, self.cutoff_radius, self.adaptive_cutoff + ) + elif self.backend == "ase": + graph_props = calculate_ase_periodic_shifts( + data, self.cutoff_radius, self.adaptive_cutoff + ) + else: + raise RuntimeError(f"Requested backend f{self.backend} not available.") data.update(graph_props) return data diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index ba9986e3..c7090004 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -1,8 +1,10 @@ from __future__ import annotations import pickle +import ase from collections.abc import Generator from functools import lru_cache, partial +from ase.neighborlist import NeighborList from os import makedirs from pathlib import Path from typing import Any, Callable @@ -719,6 +721,7 @@ def _all_sites_have_neighbors(neighbors): f"No neighbors detected for structure with cutoff {cutoff}; {structure}" ) # process the neighbors now + all_src, all_dst, all_images = [], [], [] for src_idx, dst_sites in enumerate(neighbors): for site in dst_sites: @@ -751,3 +754,69 @@ def _all_sites_have_neighbors(neighbors): frac_coords[dst] - frac_coords[src] + return_dict["offsets"] ) return return_dict + + +def calculate_ase_periodic_shifts(data, cutoff_radius, adaptive_cutoff): + cell = data["cell"] + + atoms = ase.Atoms( + positions=data["pos"], + numbers=data["atomic_numbers"], + cell=cell.squeeze(0), + pbc=(True, True, True), + ) + cutoff = [cutoff_radius] * atoms.positions.shape[0] + # Create a neighbor list + nl = NeighborList(cutoff, skin=0.0, self_interaction=False, bothways=True) + nl.update(atoms) + + neighbors = nl.nl.neighbors + + def _all_sites_have_neighbors(neighbors): + return all([len(n) for n in neighbors]) + + # if there are sites without neighbors and user requested adaptive + # cut off, we'll keep trying + if not _all_sites_have_neighbors(neighbors) and adaptive_cutoff: + while not _all_sites_have_neighbors(neighbors) and cutoff < 30.0: + # increment radial cutoff progressively + cutoff_radius += 0.5 + cutoff = [cutoff_radius] * atoms.positions.shape[0] + nl = NeighborList(cutoff, skin=0.0, self_interaction=False, bothways=True) + nl.update(atoms) + + # and we still don't find a neighbor, we have a problem with the structure + if not _all_sites_have_neighbors(neighbors): + raise ValueError(f"No neighbors detected for structure with cutoff {cutoff}") + + all_src, all_dst, all_images = [], [], [] + for src_idx in range(len(atoms)): + dst_index, image = nl.get_neighbors(src_idx) + for index in range(len(dst_index)): + all_src.append(src_idx) + all_dst.append(dst_index[index]) + all_images.append(image[index]) + + if any([len(obj) == 0 for obj in [all_images, all_dst, all_images]]): + raise ValueError( + f"No images or edges to work off for cutoff {cutoff}." + f" Please inspect your atoms object and neighbors: {atoms}." + ) + + frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float() + coords = torch.from_numpy(atoms.positions).float() + + return_dict = { + "src_nodes": torch.LongTensor(all_src), + "dst_nodes": torch.LongTensor(all_dst), + "images": torch.FloatTensor(all_images), + "cell": cell, + "pos": coords, + } + + return_dict["offsets"] = einsum(return_dict["images"], cell, "v i, n i j -> v j") + src, dst = return_dict["src_nodes"], return_dict["dst_nodes"] + return_dict["unit_offsets"] = ( + frac_coords[dst] - frac_coords[src] + return_dict["offsets"] + ) + return return_dict From 9e6919bef2b6d2ac8652cc40130cfe87e18bed06 Mon Sep 17 00:00:00 2001 From: "Gonzales, Carmelo" Date: Tue, 22 Oct 2024 10:46:05 -0700 Subject: [PATCH 2/4] fix: updating type hint --- matsciml/datasets/transforms/pbc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index 5d9ed271..3f6e24dd 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Literal + import numpy as np import torch from pymatgen.core import Lattice, Structure @@ -35,7 +37,7 @@ def __init__( self, cutoff_radius: float, adaptive_cutoff: bool = False, - backend: str = "pymatgen", + backend: Literal["pymatgen", "ase"] = "pymatgen", ) -> None: super().__init__() self.cutoff_radius = cutoff_radius From 4d5eb4c6010d3b788d76f4bb5d33f513dd336c42 Mon Sep 17 00:00:00 2001 From: "Gonzales, Carmelo" Date: Tue, 22 Oct 2024 10:47:11 -0700 Subject: [PATCH 3/4] fix: adding comment to note 3d pbc is hard coded. --- matsciml/datasets/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index c7090004..56184072 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -763,6 +763,7 @@ def calculate_ase_periodic_shifts(data, cutoff_radius, adaptive_cutoff): positions=data["pos"], numbers=data["atomic_numbers"], cell=cell.squeeze(0), + # Hard coding in the PBC direction for x, y, z. pbc=(True, True, True), ) cutoff = [cutoff_radius] * atoms.positions.shape[0] From a0af875f2042caa0b9bb9b0e511d9069f6aca351 Mon Sep 17 00:00:00 2001 From: "Gonzales, Carmelo" Date: Tue, 22 Oct 2024 13:20:19 -0700 Subject: [PATCH 4/4] feat: adding test to compare pbc calculation from pymatgen and ase --- matsciml/datasets/tests/test_transforms.py | 74 ++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/matsciml/datasets/tests/test_transforms.py b/matsciml/datasets/tests/test_transforms.py index 6f0c0302..6d9724f8 100644 --- a/matsciml/datasets/tests/test_transforms.py +++ b/matsciml/datasets/tests/test_transforms.py @@ -197,3 +197,77 @@ def test_ase_periodic(backend): batch = next(iter(loader)) # check if periodic properties transform was applied assert "unit_offsets" in batch + + +def test_pbc_backend_equivalence_easy(): + from ase.build import molecule + from pymatgen.io.ase import AseAtomsAdaptor + + atoms = molecule( + "H2O", cell=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], pbc=(True, True, True) + ) + structure = AseAtomsAdaptor.get_structure(atoms) + + data = {} + coords = torch.from_numpy(structure.cart_coords).float() + data["pos"] = coords + atom_numbers = torch.LongTensor(structure.atomic_numbers) + data["atomic_numbers"] = atom_numbers + data["natoms"] = len(atom_numbers) + lattice_params = torch.FloatTensor( + structure.lattice.abc + + tuple(a * (torch.pi / 180.0) for a in structure.lattice.angles), + ) + lattice_features = { + "lattice_params": lattice_params, + } + data["lattice_features"] = lattice_features + + ase_trans = transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend="ase" + ) + + pymatgen_trans = transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend="pymatgen" + ) + + ase_result = ase_trans(data) + pymatgen_result = pymatgen_trans(data) + + ase_wiring = torch.vstack([ase_result["src_nodes"], ase_result["dst_nodes"]]) + pymatgen_wiring = torch.vstack( + [pymatgen_result["src_nodes"], pymatgen_result["dst_nodes"]] + ) + equivalence = ase_wiring == pymatgen_wiring + # basically checking if src -> dst node wiring is equivalent between the two approaches + assert torch.all(equivalence) + + +def test_pbc_backend_equivalence_hard(): + ase_trans = transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend="ase" + ) + + pymatgen_trans = transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend="pymatgen" + ) + + dm = MatSciMLDataModule.from_devset( + "S2EFDataset", + batch_size=1, + ) + + dm.setup() + loader = dm.train_dataloader() + batch = next(iter(loader)) + batch["atomic_numbers"] = batch["atomic_numbers"].squeeze(0) + + ase_result = ase_trans(batch) + pymatgen_result = pymatgen_trans(batch) + ase_wiring = torch.vstack([ase_result["src_nodes"], ase_result["dst_nodes"]]) + pymatgen_wiring = torch.vstack( + [pymatgen_result["src_nodes"], pymatgen_result["dst_nodes"]] + ) + equivalence = ase_wiring == pymatgen_wiring + # basically checking if src -> dst node wiring is equivalent between the two approaches + assert torch.all(equivalence)