Merge pull request #318 from laserkelvin/317-pbc-graph-wiring-options
Specifiable options for periodic neighbors calculations
laserkelvin authored Nov 18, 2024
2 parents b522337 + f1b14ac commit 028f44e
Showing 4 changed files with 391 additions and 40 deletions.
27 changes: 16 additions & 11 deletions docs/source/best-practices.rst
@@ -20,17 +20,22 @@ The way this is implemented in ``matsciml`` is to include the transform,

.. autofunction:: matsciml.datasets.transforms.PeriodicPropertiesTransform

-This implementation is heavily based on
-the tutorial outlined in the `e3nn documentation`_ where we use ``pymatgen``
-to generate images, and for every atom in the graph,
-compute nearest neighbors with some specified radius cutoff. One additional
-detail we include in this approach is the ``adaptive_cutoff`` flag: if set to ``True``, it ensures
-that all nodes are connected by gradually increasing the radius cutoff up
-to a hard-coded limit of 100 angstroms. This is intended to facilitate
-a small nominal cutoff, even if some data samples contain (intentionally)
-significantly more distant atoms than the average sample. By doing so, we
-improve computational efficiency by not needing to consider many more edges
-than required.
+This implementation was originally based on
+the tutorial outlined in the `e3nn documentation`_. We initially provided
+an implementation that uses ``pymatgen`` for the neighborhood calculation,
+but have since extended it to use ``ase`` as well. We find that ``ase`` is
+slightly less ambiguous with coordinate representations, but results from
+the two can be mapped to yield the same behavior. In either case, the coordinates
+and lattice parameters are passed into their respective backend representations
+(i.e. ``ase.Atoms`` and ``pymatgen.Structure``), and subsequently used to
+perform the neighborhood calculation to obtain source/destination node indices
+for the edges, as well as their associated periodic image indices.
+
+Below are descriptions of the two algorithms, and links to their source code.
+
+.. autofunction:: matsciml.datasets.utils.calculate_periodic_shifts
+
+.. autofunction:: matsciml.datasets.utils.calculate_ase_periodic_shifts

Point clouds to graphs
^^^^^^^^^^^^^^^^^^^^^^
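
The new documentation text describes the outcome of the neighborhood calculation: source/destination node indices plus a periodic image index per edge. As a rough illustration only, the sketch below does this by brute force in plain Python. The function name ``periodic_neighbors`` and the ``max_image`` parameter are hypothetical and do not appear in ``matsciml``; the ``pymatgen`` and ``ase`` backends rely on their own neighbor-list routines rather than this loop, and ``max_image`` is assumed to be large enough to cover the cutoff.

```python
from itertools import product
from math import dist


def periodic_neighbors(frac_coords, cell, cutoff, max_image=1):
    """Brute-force periodic neighbor search (illustrative only).

    frac_coords: fractional coordinates, one [a, b, c] entry per atom
    cell: 3x3 lattice matrix with lattice vectors as rows
    Returns a list of (src, dst, image) tuples, where image is the
    integer cell offset applied to the destination atom.
    """

    def to_cartesian(frac):
        # fractional -> Cartesian: sum over lattice vectors
        return [sum(frac[k] * cell[k][axis] for k in range(3)) for axis in range(3)]

    edges = []
    for image in product(range(-max_image, max_image + 1), repeat=3):
        for src, frac_src in enumerate(frac_coords):
            pos_src = to_cartesian(frac_src)
            for dst, frac_dst in enumerate(frac_coords):
                # skip an atom paired with itself inside the original cell
                if src == dst and image == (0, 0, 0):
                    continue
                shifted = [frac_dst[k] + image[k] for k in range(3)]
                if dist(pos_src, to_cartesian(shifted)) <= cutoff:
                    edges.append((src, dst, image))
    return edges


# a single atom in a 4 Angstrom cubic cell sees its six face-sharing images
cubic_cell = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]
example_edges = periodic_neighbors([[0.0, 0.0, 0.0]], cubic_cell, cutoff=4.1)
```

Every edge in ``example_edges`` has the same source and destination index and differs only in its image, which is why the image index has to be carried alongside the node indices.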
94 changes: 76 additions & 18 deletions matsciml/datasets/transforms/pbc.py
@@ -5,6 +5,7 @@
import numpy as np
import torch
from pymatgen.core import Lattice, Structure
+from loguru import logger

from matsciml.common.types import DataDict
from matsciml.datasets.transforms.base import AbstractDataTransform
@@ -18,31 +19,76 @@


 class PeriodicPropertiesTransform(AbstractDataTransform):
-    """
-    Rewires an already present graph to include periodic boundary conditions.
-
-    Since graphs are normally bounded within a unit cell, they may not capture
-    the necessary dependencies for atoms connected to neighboring cells. This
-    transform will compute the unit cell, tile it, and then rewire the graph
-    edges such that it can capture connectivity within a radial cutoff given
-    in Angstroms.
-
-    Cut off radius is specified in Angstroms. An additional flag, ``adaptive_cutoff``,
-    allows the cut off value to grow up to 100 angstroms in order to find neighbors.
-    This allows larger (typically unstable) structures to be modeled without applying
-    a large cut off for the entire dataset.
-    """
-
     def __init__(
         self,
         cutoff_radius: float,
         adaptive_cutoff: bool = False,
         backend: Literal["pymatgen", "ase"] = "pymatgen",
+        max_neighbors: int = 1000,
+        allow_self_loops: bool = False,
+        convert_to_unit_cell: bool = False,
+        is_cartesian: bool | None = None,
     ) -> None:
+        """
+        Rewires an already present graph to include periodic boundary conditions.
+
+        Since graphs are normally bounded within a unit cell, they may not capture
+        the necessary dependencies for atoms connected to neighboring cells. This
+        transform will compute the unit cell, tile it, and then rewire the graph
+        edges such that it can capture connectivity within a radial cutoff given
+        in Angstroms.
+
+        Cut off radius is specified in Angstroms. An additional flag, ``adaptive_cutoff``,
+        allows the cut off value to grow up to 100 angstroms in order to find neighbors.
+        This allows larger (typically unstable) structures to be modeled without applying
+        a large cut off for the entire dataset.
+
+        Parameters
+        ----------
+        cutoff_radius : float
+            Cutoff radius to use to truncate the neighbor list calculation.
+        adaptive_cutoff : bool, default False
+            If set to ``True``, will allow ``cutoff_radius`` to grow up to
+            30 angstroms if there are any disconnected subgraphs present.
+            This is to allow distant nodes to be captured in some structures
+            only as needed, keeping the computational requirements low for
+            other samples within a dataset.
+        backend : Literal['pymatgen', 'ase'], default 'pymatgen'
+            Which algorithm to use for the neighbor list calculation. Nominally,
+            settings can be mapped so that the two produce equivalent results.
+            'pymatgen' is kept as the default, but at some point 'ase' will
+            become the default option. See the hosted documentation 'Best practices'
+            page for details.
+        max_neighbors : int, default 1000
+            Forcibly truncate the number of edges at any given node. Internally,
+            a counter is used to track the number of destination nodes when
+            looping over a node's neighbor list; when the counter exceeds this
+            value we immediately stop counting neighbors for the current node.
+        allow_self_loops : bool, default False
+            If ``True``, the edges will include self-interactions within the
+            original unit cell. If set to ``False``, these self-loops are
+            purged before returning edges.
+        convert_to_unit_cell : bool, default False
+            This argument is specific to ``pymatgen``, and is passed to the
+            ``to_unit_cell`` argument during the ``Structure`` construction step.
+        is_cartesian : bool | None, default None
+            If set to ``None``, we will try to determine whether the structure has
+            fractional coordinates as input or not. If a boolean is provided,
+            this is passed into the ``pymatgen.Structure`` construction step.
+            This is specific to ``pymatgen``, and is not used by ``ase``.
+        """
         super().__init__()
         self.cutoff_radius = cutoff_radius
         self.adaptive_cutoff = adaptive_cutoff
         self.backend = backend
+        self.max_neighbors = max_neighbors
+        self.allow_self_loops = allow_self_loops
+        if is_cartesian is not None and backend == "ase":
+            logger.warning(
+                "`is_cartesian` passed but using `ase` backend; option will not affect anything."
+            )
+        self.is_cartesian = is_cartesian
+        self.convert_to_unit_cell = convert_to_unit_cell

     def __call__(self, data: DataDict) -> DataDict:
         """
@@ -84,7 +130,10 @@ def __call__(self, data: DataDict) -> DataDict:
             structure = data["structure"]
             if isinstance(structure, Structure):
                 graph_props = calculate_periodic_shifts(
-                    structure, self.cutoff_radius, self.adaptive_cutoff
+                    structure,
+                    self.cutoff_radius,
+                    self.adaptive_cutoff,
+                    max_neighbors=self.max_neighbors,
                 )
                 data.update(graph_props)
                 return data
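
The ``max_neighbors`` docstring describes a per-node counter that stops collecting neighbors once the limit is reached. A minimal sketch of that idea in plain Python follows; ``truncate_neighbors`` and its dictionary input are hypothetical stand-ins rather than the code in ``matsciml.datasets.utils``, and the sketch assumes the cap is exactly ``max_neighbors`` edges per source node, which the docstring wording ("exceeds") leaves slightly open.

```python
def truncate_neighbors(neighbor_lists, max_neighbors):
    """Cap the number of outgoing edges per source node.

    neighbor_lists: dict mapping a source node index to an ordered list of
    candidate destination node indices (hypothetical input format).
    """
    src_nodes, dst_nodes = [], []
    for src, candidates in neighbor_lists.items():
        count = 0
        for dst in candidates:
            # stop collecting for this node once the cap is reached
            if count >= max_neighbors:
                break
            src_nodes.append(src)
            dst_nodes.append(dst)
            count += 1
    return src_nodes, dst_nodes


# node 0 has four candidates but keeps only the first two
capped_src, capped_dst = truncate_neighbors({0: [1, 2, 3, 4], 1: [0]}, max_neighbors=2)
```

Which neighbors survive depends on the order of the candidate list, so a real implementation would likely need the candidates sorted (for example by distance) for the truncation to be meaningful.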
@@ -123,16 +172,25 @@ def __call__(self, data: DataDict) -> DataDict:
             data["atomic_numbers"],
             data["pos"],
             lattice=lattice,
+            convert_to_unit_cell=self.convert_to_unit_cell,
+            is_cartesian=self.is_cartesian,
         )
         if self.backend == "pymatgen":
             graph_props = calculate_periodic_shifts(
-                structure, self.cutoff_radius, self.adaptive_cutoff
+                structure, self.cutoff_radius, self.adaptive_cutoff, self.max_neighbors
             )
         elif self.backend == "ase":
             graph_props = calculate_ase_periodic_shifts(
-                data, self.cutoff_radius, self.adaptive_cutoff
+                data, self.cutoff_radius, self.adaptive_cutoff, self.max_neighbors
             )
         else:
             raise RuntimeError(f"Requested backend {self.backend} not available.")
         data.update(graph_props)
+        if not self.allow_self_loops:
+            # flag edges that connect a node to itself...
+            self_loops = data["src_nodes"] == data["dst_nodes"]
+            # ...but only within the same image; checking every offset component
+            # avoids misflagging images such as (1, -1, 0), which sum to zero
+            self_loops &= (data["unit_offsets"] == 0).all(dim=-1)
+            # keep all edges that are not self-loops, for each edge-dependent tensor
+            for key in ["src_nodes", "dst_nodes", "images", "unit_offsets", "offsets"]:
+                data[key] = data[key][~self_loops]
         return data
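
The self-loop purge at the end of ``__call__`` can be checked against a tiny hand-built edge list. The sketch below mirrors the intent in plain Python lists instead of tensors; ``drop_self_loops`` is a hypothetical helper, not part of ``matsciml``. An edge is dropped only if source equals destination and every component of its unit offset is zero.

```python
def drop_self_loops(src_nodes, dst_nodes, unit_offsets):
    """Remove edges from a node to itself within the original cell.

    Edges from a node to its own periodic image (non-zero offset) are kept.
    """
    kept = [
        (src, dst, offset)
        for src, dst, offset in zip(src_nodes, dst_nodes, unit_offsets)
        if not (src == dst and all(component == 0 for component in offset))
    ]
    if not kept:
        return [], [], []
    src_kept, dst_kept, offsets_kept = (list(column) for column in zip(*kept))
    return src_kept, dst_kept, offsets_kept


# edge 0 is a true self-loop; edges 1 and 3 reach periodic images of the same atom
filtered = drop_self_loops(
    [0, 0, 0, 1],
    [0, 0, 1, 1],
    [(0, 0, 0), (1, 0, 0), (0, 0, 0), (1, -1, 0)],
)
```

The last edge has an offset of ``(1, -1, 0)``, whose components sum to zero even though it points to a different image, so a sum-based test would remove it by mistake while the component-wise test keeps it.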
87 changes: 87 additions & 0 deletions matsciml/datasets/transforms/tests/test_pbc.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from collections import Counter
+
+import torch
+import pytest
+import numpy as np
+from pymatgen.core import Structure, Lattice
+
+from matsciml.datasets.transforms import PeriodicPropertiesTransform
+
+"""
+This module uses reference Materials Project structures and tests
+the edge calculation routines to ensure they at least work with
+various parameters.
+
+The key thing here is at least using feasible structures to perform
+this check, rather than using randomly generated coordinates and
+lattices, even if composing them isn't meaningful.
+"""
+
+hexa = Lattice.from_parameters(
+    4.81, 4.809999999999999, 13.12, 90.0, 90.0, 120.0, vesta=True
+)
+cubic = Lattice.from_parameters(6.79, 6.79, 12.63, 90.0, 90.0, 90.0, vesta=True)
+
+# mp-1143
+alumina = Structure(
+    hexa,
+    species=["Al", "O"],
+    coords=[[1 / 3, 2 / 3, 0.814571], [0.360521, 1 / 3, 0.583333]],
+    coords_are_cartesian=False,
+)
+# mp-1267
+nac = Structure(
+    cubic,
+    species=["Na", "C"],
+    coords=[[0.688819, 3 / 4, 3 / 8], [0.065833, 0.565833, 0.0]],
+    coords_are_cartesian=False,
+)
+
+
+@pytest.mark.parametrize(
+    "coords",
+    [
+        alumina.cart_coords,
+        nac.cart_coords,
+    ],
+)
+@pytest.mark.parametrize(
+    "cell",
+    [
+        hexa.matrix,
+        cubic.matrix,
+    ],
+)
+@pytest.mark.parametrize("self_loops", [True, False])
+@pytest.mark.parametrize("backend", ["pymatgen", "ase"])
+@pytest.mark.parametrize(
+    "cutoff_radius", [6.0, 9.0, 15.0]
+)  # TODO figure out why pmg fails on 3
+def test_periodic_generation(
+    coords: np.ndarray,
+    cell: np.ndarray,
+    self_loops: bool,
+    backend: str,
+    cutoff_radius: float,
+):
+    coords = torch.FloatTensor(coords)
+    cell = torch.FloatTensor(cell)
+    transform = PeriodicPropertiesTransform(
+        cutoff_radius=cutoff_radius,
+        adaptive_cutoff=False,
+        backend=backend,
+        max_neighbors=10,
+        allow_self_loops=self_loops,
+    )
+    num_atoms = coords.size(0)
+    atomic_numbers = torch.ones(num_atoms)
+    packed_data = {"pos": coords, "cell": cell, "atomic_numbers": atomic_numbers}
+    output = transform(packed_data)
+    # check to make sure no source node has more than 10 neighbors
+    src_nodes = output["src_nodes"].tolist()
+    counts = Counter(src_nodes)
+    for index, count in counts.items():
+        if not self_loops:
+            assert count < 10, f"Node {index} has too many counts. {src_nodes}"
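
Stacked ``pytest.mark.parametrize`` decorators generate the Cartesian product of their values, which is why the module docstring notes that some coordinate/lattice pairings are not physically meaningful. The sketch below only counts the combinations using string labels in place of the real arrays; the label names are placeholders chosen for illustration.

```python
from itertools import product

# string labels standing in for the parametrized values in the test module
coords_labels = ["alumina", "nac"]
cell_labels = ["hexa", "cubic"]
self_loops_values = [True, False]
backends = ["pymatgen", "ase"]
cutoffs = [6.0, 9.0, 15.0]

# one test case per combination, mirroring the stacked decorators
cases = list(product(coords_labels, cell_labels, self_loops_values, backends, cutoffs))

# combinations where the coordinates come from one structure and the cell from the other
mismatched_pairs = {("alumina", "cubic"), ("nac", "hexa")}
mismatched = [case for case in cases if (case[0], case[1]) in mismatched_pairs]

print(f"{len(cases)} cases in total, {len(mismatched)} with mismatched coords and cell")
```

Half of the generated cases pair coordinates with the other structure's lattice, so the test exercises the edge routines on feasible inputs without asserting anything about physical correctness.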