Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ASE Backend For PBC Computation #310

Merged
merged 4 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion matsciml/datasets/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,23 @@ def test_graph_sorting():
)
dm.setup()
loader = dm.train_dataloader()
graph = next(iter(loader))["graph"]
_ = next(iter(loader))["graph"]
# not really anything to test, but just make sure it runs :D


@pytest.mark.parametrize("backend", ["ase", "pymatgen"])
def test_ase_periodic(backend):
    """Smoke test periodic-property computation with each neighbor backend.

    Runs the devset through ``PeriodicPropertiesTransform`` and checks that
    the periodic offset keys were attached to the emitted batch.
    """
    pbc = transforms.PeriodicPropertiesTransform(
        cutoff_radius=6.0, adaptive_cutoff=True, backend=backend
    )
    dm = MatSciMLDataModule.from_devset(
        "S2EFDataset",
        dset_kwargs={"transforms": [pbc]},
    )
    dm.setup()
    batch = next(iter(dm.train_dataloader()))
    # check if periodic properties transform was applied
    assert "unit_offsets" in batch
27 changes: 22 additions & 5 deletions matsciml/datasets/transforms/pbc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import torch
import numpy as np
import torch
from pymatgen.core import Lattice, Structure

from matsciml.common.types import DataDict
from matsciml.datasets.transforms.base import AbstractDataTransform
from matsciml.datasets.utils import (
calculate_periodic_shifts,
calculate_ase_periodic_shifts,
make_pymatgen_periodic_structure,
)

Expand All @@ -30,10 +31,16 @@ class PeriodicPropertiesTransform(AbstractDataTransform):
a large cut off for the entire dataset.
"""

def __init__(self, cutoff_radius: float, adaptive_cutoff: bool = False) -> None:
def __init__(
    self,
    cutoff_radius: float,
    adaptive_cutoff: bool = False,
    backend: str = "pymatgen",
) -> None:
    """Configure the periodic-properties transform.

    Parameters
    ----------
    cutoff_radius : float
        Radial cutoff (in the dataset's length units, presumably angstroms
        — TODO confirm) used for neighbor detection.
    adaptive_cutoff : bool, default False
        If True, progressively grow the cutoff for structures where some
        sites have no neighbors at ``cutoff_radius``.
    backend : str, default "pymatgen"
        Which neighbor-list implementation to use; ``__call__`` accepts
        "pymatgen" or "ase" and raises for anything else.
    """
    super().__init__()
    self.cutoff_radius = cutoff_radius
    self.adaptive_cutoff = adaptive_cutoff
    self.backend = backend

def __call__(self, data: DataDict) -> DataDict:
"""
Expand Down Expand Up @@ -107,13 +114,23 @@ def __call__(self, data: DataDict) -> DataDict:
tuple(angle * (180.0 / torch.pi) for angle in angles),
)
lattice = Lattice.from_parameters(*abc, *angles, vesta=True)
# We need cell in data for ase backend.
data["cell"] = torch.tensor(lattice.matrix).unsqueeze(0).float()

structure = make_pymatgen_periodic_structure(
data["atomic_numbers"],
data["pos"],
lattice=lattice,
)
graph_props = calculate_periodic_shifts(
structure, self.cutoff_radius, self.adaptive_cutoff
)
if self.backend == "pymatgen":
graph_props = calculate_periodic_shifts(
structure, self.cutoff_radius, self.adaptive_cutoff
)
elif self.backend == "ase":
graph_props = calculate_ase_periodic_shifts(
data, self.cutoff_radius, self.adaptive_cutoff
)
else:
raise RuntimeError(f"Requested backend f{self.backend} not available.")
data.update(graph_props)
return data
69 changes: 69 additions & 0 deletions matsciml/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import pickle
import ase
from collections.abc import Generator
from functools import lru_cache, partial
from ase.neighborlist import NeighborList
from os import makedirs
from pathlib import Path
from typing import Any, Callable
Expand Down Expand Up @@ -719,6 +721,7 @@ def _all_sites_have_neighbors(neighbors):
f"No neighbors detected for structure with cutoff {cutoff}; {structure}"
)
# process the neighbors now

all_src, all_dst, all_images = [], [], []
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
Expand Down Expand Up @@ -751,3 +754,69 @@ def _all_sites_have_neighbors(neighbors):
frac_coords[dst] - frac_coords[src] + return_dict["offsets"]
)
return return_dict


def calculate_ase_periodic_shifts(data, cutoff_radius, adaptive_cutoff):
    """Compute periodic edges, images, and offsets using the ASE backend.

    Builds an ``ase.Atoms`` object from ``data`` (expects "pos",
    "atomic_numbers", and a batched 3x3 "cell" tensor), runs an ASE
    ``NeighborList``, and returns graph connectivity plus per-edge
    periodic-image offsets.

    Parameters
    ----------
    data : DataDict
        Sample dictionary; must contain "pos", "atomic_numbers", and "cell".
    cutoff_radius : float
        Initial radial cutoff for neighbor detection.
    adaptive_cutoff : bool
        If True, grow the cutoff in 0.5 increments (up to 30.0) until every
        site has at least one neighbor.

    Returns
    -------
    dict
        Keys: src_nodes, dst_nodes, images, cell, pos, offsets, unit_offsets.

    Raises
    ------
    ValueError
        If any site ends up with no neighbors, or no edges are produced.
    """
    cell = data["cell"]

    atoms = ase.Atoms(
        positions=data["pos"],
        numbers=data["atomic_numbers"],
        cell=cell.squeeze(0),
        pbc=(True, True, True),
    )

    def _build_neighbor_list(radius):
        # ASE expects one radius per atom; skin=0 keeps the cutoff exact,
        # bothways=True yields directed edges in both directions.
        per_atom = [radius] * len(atoms)
        nl = NeighborList(per_atom, skin=0.0, self_interaction=False, bothways=True)
        nl.update(atoms)
        return nl

    def _all_sites_have_neighbors(neighbors):
        return all(len(n) for n in neighbors)

    nl = _build_neighbor_list(cutoff_radius)
    neighbors = nl.nl.neighbors

    # if there are sites without neighbors and user requested adaptive
    # cut off, we'll keep trying
    if not _all_sites_have_neighbors(neighbors) and adaptive_cutoff:
        # BUG FIX: compare the scalar radius (not the per-atom list) against
        # the 30.0 ceiling, and refresh `neighbors` after each rebuild so the
        # loop actually observes the enlarged neighbor list.
        while not _all_sites_have_neighbors(neighbors) and cutoff_radius < 30.0:
            # increment radial cutoff progressively
            cutoff_radius += 0.5
            nl = _build_neighbor_list(cutoff_radius)
            neighbors = nl.nl.neighbors

    # and we still don't find a neighbor, we have a problem with the structure
    if not _all_sites_have_neighbors(neighbors):
        raise ValueError(
            f"No neighbors detected for structure with cutoff {cutoff_radius}"
        )

    all_src, all_dst, all_images = [], [], []
    for src_idx in range(len(atoms)):
        dst_indices, images = nl.get_neighbors(src_idx)
        for dst_idx, image in zip(dst_indices, images):
            all_src.append(src_idx)
            all_dst.append(dst_idx)
            all_images.append(image)

    # BUG FIX: check all three edge lists (the original tested all_images
    # twice and never all_src)
    if any(len(obj) == 0 for obj in (all_src, all_dst, all_images)):
        raise ValueError(
            f"No images or edges to work off for cutoff {cutoff_radius}."
            f" Please inspect your atoms object and neighbors: {atoms}."
        )

    frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float()
    coords = torch.from_numpy(atoms.positions).float()

    return_dict = {
        "src_nodes": torch.LongTensor(all_src),
        "dst_nodes": torch.LongTensor(all_dst),
        "images": torch.FloatTensor(all_images),
        "cell": cell,
        "pos": coords,
    }

    # map per-edge image counts through the (1, 3, 3) cell to Cartesian shifts
    return_dict["offsets"] = einsum(return_dict["images"], cell, "v i, n i j -> v j")
    src, dst = return_dict["src_nodes"], return_dict["dst_nodes"]
    # NOTE(review): mixes fractional coordinate deltas with Cartesian offsets,
    # mirroring the pymatgen backend's formula — confirm units are intended.
    return_dict["unit_offsets"] = (
        frac_coords[dst] - frac_coords[src] + return_dict["offsets"]
    )
    return return_dict
Loading