diff --git a/matsciml/datasets/tests/test_transforms.py b/matsciml/datasets/tests/test_transforms.py index b17b8faa..6d9724f8 100644 --- a/matsciml/datasets/tests/test_transforms.py +++ b/matsciml/datasets/tests/test_transforms.py @@ -177,5 +177,97 @@ def test_graph_sorting(): ) dm.setup() loader = dm.train_dataloader() - graph = next(iter(loader))["graph"] + _ = next(iter(loader))["graph"] # not really anything to test, but just make sure it runs :D + + +@pytest.mark.parametrize("backend", ["ase", "pymatgen"]) +def test_ase_periodic(backend): + trans = [ + transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend=backend + ) + ] + dm = MatSciMLDataModule.from_devset( + "S2EFDataset", + dset_kwargs={"transforms": trans}, + ) + dm.setup() + loader = dm.train_dataloader() + batch = next(iter(loader)) + # check if periodic properties transform was applied + assert "unit_offsets" in batch + + +def test_pbc_backend_equivalence_easy(): + from ase.build import molecule + from pymatgen.io.ase import AseAtomsAdaptor + + atoms = molecule( + "H2O", cell=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], pbc=(True, True, True) + ) + structure = AseAtomsAdaptor.get_structure(atoms) + + data = {} + coords = torch.from_numpy(structure.cart_coords).float() + data["pos"] = coords + atom_numbers = torch.LongTensor(structure.atomic_numbers) + data["atomic_numbers"] = atom_numbers + data["natoms"] = len(atom_numbers) + lattice_params = torch.FloatTensor( + structure.lattice.abc + + tuple(a * (torch.pi / 180.0) for a in structure.lattice.angles), + ) + lattice_features = { + "lattice_params": lattice_params, + } + data["lattice_features"] = lattice_features + + ase_trans = transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend="ase" + ) + + pymatgen_trans = transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend="pymatgen" + ) + + ase_result = ase_trans(data) + pymatgen_result = pymatgen_trans(data) + + ase_wiring = torch.vstack([ase_result["src_nodes"], ase_result["dst_nodes"]]) + pymatgen_wiring = torch.vstack( + [pymatgen_result["src_nodes"], pymatgen_result["dst_nodes"]] + ) + equivalence = ase_wiring == pymatgen_wiring + # basically checking if src -> dst node wiring is equivalent between the two approaches + assert torch.all(equivalence) + + +def test_pbc_backend_equivalence_hard(): + ase_trans = transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend="ase" + ) + + pymatgen_trans = transforms.PeriodicPropertiesTransform( + cutoff_radius=6.0, adaptive_cutoff=True, backend="pymatgen" + ) + + dm = MatSciMLDataModule.from_devset( + "S2EFDataset", + batch_size=1, + ) + + dm.setup() + loader = dm.train_dataloader() + batch = next(iter(loader)) + batch["atomic_numbers"] = batch["atomic_numbers"].squeeze(0) + + ase_result = ase_trans(batch) + pymatgen_result = pymatgen_trans(batch) + ase_wiring = torch.vstack([ase_result["src_nodes"], ase_result["dst_nodes"]]) + pymatgen_wiring = torch.vstack( + [pymatgen_result["src_nodes"], pymatgen_result["dst_nodes"]] + ) + equivalence = ase_wiring == pymatgen_wiring + # basically checking if src -> dst node wiring is equivalent between the two approaches + assert torch.all(equivalence) diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index b8f9de0b..3f6e24dd 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -1,13 +1,16 @@ from __future__ import annotations -import torch +from typing import Literal + import numpy as np +import torch from pymatgen.core import Lattice, Structure from matsciml.common.types import DataDict from matsciml.datasets.transforms.base import AbstractDataTransform from matsciml.datasets.utils import ( calculate_periodic_shifts, + calculate_ase_periodic_shifts, make_pymatgen_periodic_structure, ) @@ -30,10 +33,16 @@ class PeriodicPropertiesTransform(AbstractDataTransform): a large cut off for the entire dataset. """ - def __init__(self, cutoff_radius: float, adaptive_cutoff: bool = False) -> None: + def __init__( + self, + cutoff_radius: float, + adaptive_cutoff: bool = False, + backend: Literal["pymatgen", "ase"] = "pymatgen", + ) -> None: super().__init__() self.cutoff_radius = cutoff_radius self.adaptive_cutoff = adaptive_cutoff + self.backend = backend def __call__(self, data: DataDict) -> DataDict: """ @@ -107,13 +116,23 @@ def __call__(self, data: DataDict) -> DataDict: tuple(angle * (180.0 / torch.pi) for angle in angles), ) lattice = Lattice.from_parameters(*abc, *angles, vesta=True) + # We need cell in data for ase backend. + data["cell"] = torch.tensor(lattice.matrix).unsqueeze(0).float() + structure = make_pymatgen_periodic_structure( data["atomic_numbers"], data["pos"], lattice=lattice, ) - graph_props = calculate_periodic_shifts( - structure, self.cutoff_radius, self.adaptive_cutoff - ) + if self.backend == "pymatgen": + graph_props = calculate_periodic_shifts( + structure, self.cutoff_radius, self.adaptive_cutoff + ) + elif self.backend == "ase": + graph_props = calculate_ase_periodic_shifts( + data, self.cutoff_radius, self.adaptive_cutoff + ) + else: + raise RuntimeError(f"Requested backend f{self.backend} not available.") data.update(graph_props) return data diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index ba9986e3..56184072 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -1,8 +1,10 @@ from __future__ import annotations import pickle +import ase from collections.abc import Generator from functools import lru_cache, partial +from ase.neighborlist import NeighborList from os import makedirs from pathlib import Path from typing import Any, Callable @@ -719,6 +721,7 @@ def _all_sites_have_neighbors(neighbors): f"No neighbors detected for structure with cutoff {cutoff}; {structure}" ) # process the neighbors now + all_src, all_dst, all_images = [], [], [] for src_idx, dst_sites in enumerate(neighbors): for site in dst_sites: @@ -751,3 +754,70 @@ def _all_sites_have_neighbors(neighbors): frac_coords[dst] - frac_coords[src] + return_dict["offsets"] ) return return_dict + + +def calculate_ase_periodic_shifts(data, cutoff_radius, adaptive_cutoff): + cell = data["cell"] + + atoms = ase.Atoms( + positions=data["pos"], + numbers=data["atomic_numbers"], + cell=cell.squeeze(0), + # Hard coding in the PBC direction for x, y, z. + pbc=(True, True, True), + ) + cutoff = [cutoff_radius] * atoms.positions.shape[0] + # Create a neighbor list + nl = NeighborList(cutoff, skin=0.0, self_interaction=False, bothways=True) + nl.update(atoms) + + neighbors = nl.nl.neighbors + + def _all_sites_have_neighbors(neighbors): + return all([len(n) for n in neighbors]) + + # if there are sites without neighbors and user requested adaptive + # cut off, we'll keep trying + if not _all_sites_have_neighbors(neighbors) and adaptive_cutoff: + while not _all_sites_have_neighbors(neighbors) and cutoff < 30.0: + # increment radial cutoff progressively + cutoff_radius += 0.5 + cutoff = [cutoff_radius] * atoms.positions.shape[0] + nl = NeighborList(cutoff, skin=0.0, self_interaction=False, bothways=True) + nl.update(atoms) + + # and we still don't find a neighbor, we have a problem with the structure + if not _all_sites_have_neighbors(neighbors): + raise ValueError(f"No neighbors detected for structure with cutoff {cutoff}") + + all_src, all_dst, all_images = [], [], [] + for src_idx in range(len(atoms)): + dst_index, image = nl.get_neighbors(src_idx) + for index in range(len(dst_index)): + all_src.append(src_idx) + all_dst.append(dst_index[index]) + all_images.append(image[index]) + + if any([len(obj) == 0 for obj in [all_images, all_dst, all_images]]): + raise ValueError( + f"No images or edges to work off for cutoff {cutoff}." + f" Please inspect your atoms object and neighbors: {atoms}." + ) + + frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float() + coords = torch.from_numpy(atoms.positions).float() + + return_dict = { + "src_nodes": torch.LongTensor(all_src), + "dst_nodes": torch.LongTensor(all_dst), + "images": torch.FloatTensor(all_images), + "cell": cell, + "pos": coords, + } + + return_dict["offsets"] = einsum(return_dict["images"], cell, "v i, n i j -> v j") + src, dst = return_dict["src_nodes"], return_dict["dst_nodes"] + return_dict["unit_offsets"] = ( + frac_coords[dst] - frac_coords[src] + return_dict["offsets"] + ) + return return_dict