diff --git a/matsciml/models/pyg/__init__.py b/matsciml/models/pyg/__init__.py
index 71ead469..a6b334a4 100644
--- a/matsciml/models/pyg/__init__.py
+++ b/matsciml/models/pyg/__init__.py
@@ -18,16 +18,16 @@
 
 # load models if we have PyG installed
 if _has_pyg:
-    from matsciml.models.pyg.cgcnn import CGCNN
     from matsciml.models.pyg.egnn import EGNN
+    from matsciml.models.pyg.faenet import FAENet
     from matsciml.models.pyg.mace import MACE, ScaleShiftMACE
 
-    __all__ = ["CGCNN", "EGNN", "FAENet", "MACE", "ScaleShiftMACE"]
+    __all__ = ["EGNN", "FAENet", "MACE", "ScaleShiftMACE"]
 
     # these packages need additional pyg dependencies
     if package_registry["torch_sparse"] and package_registry["torch_scatter"]:
-        from matsciml.models.pyg.dimenet import DimeNetWrap
-        from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap
+        from matsciml.models.pyg.dimenet import DimeNetWrap  # noqa: F401
+        from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap  # noqa: F401
 
         __all__.extend(["DimeNetWrap", "DimeNetPlusPlusWrap"])
     else:
@@ -35,11 +35,11 @@
             "Missing torch_sparse and torch_scatter; DimeNet models will not be available."
         )
     if package_registry["torch_scatter"]:
-        from matsciml.models.pyg.forcenet import ForceNet
-        from matsciml.models.pyg.schnet import SchNetWrap
-        from matsciml.models.pyg.faenet import FAENet
+        from matsciml.models.pyg.forcenet import ForceNet  # noqa: F401
+        from matsciml.models.pyg.schnet import SchNetWrap  # noqa: F401
+        from matsciml.models.pyg.cgcnn import CGCNN  # noqa: F401
 
-        __all__.extend(["ForceNet", "SchNetWrap", "FAENet"])
+        __all__.extend(["ForceNet", "SchNetWrap", "CGCNN"])
     else:
         logger.warning(
-            "Missing torch_scatter; ForceNet, SchNet, and FAENet models will not be available."
+            "Missing torch_scatter; ForceNet, SchNet, and CGCNN models will not be available."
diff --git a/matsciml/models/pyg/faenet/layers.py b/matsciml/models/pyg/faenet/layers.py
index db290bec..5747d252 100644
--- a/matsciml/models/pyg/faenet/layers.py
+++ b/matsciml/models/pyg/faenet/layers.py
@@ -1,18 +1,15 @@
 from __future__ import annotations
 
-from typing import Tuple, Union
-
 import pandas as pd
 import torch
 import torch.nn as nn
 from mendeleev.fetch import fetch_ionization_energies, fetch_table
-from torch import nn
 from torch.nn import Embedding, Linear
 from torch_geometric.nn import MessagePassing
 from torch_geometric.nn.norm import GraphNorm
-from torch_scatter import scatter
 
 from matsciml.models.pyg.faenet.helper import *
+from matsciml.models.pyg.scatter import scatter_sum
 
 
 class PhysEmbedding(nn.Module):
@@ -508,7 +505,7 @@ def forward(
             h = h * alpha
 
         # Global pooling
-        out = scatter(h, batch, dim=0, reduce="add")
+        out = scatter_sum(h, batch, dim=0)
 
         return out
 
diff --git a/matsciml/models/pyg/scatter.py b/matsciml/models/pyg/scatter.py
new file mode 100644
index 00000000..555fef98
--- /dev/null
+++ b/matsciml/models/pyg/scatter.py
@@ -0,0 +1,238 @@
+###########################################################################################
+# Implementation of MACE models and other models based E(3)-Equivariant MPNNs
+# (https://github.com/ACEsuit/mace)
+# Original Authors: Ilyes Batatia, Gregor Simm
+# Integrated into matsciml by Vaibhav Bihani, Sajid Mannan
+# Refactors and improved docstrings by Kelvin Lee
+# This program is distributed under the MIT License
+###########################################################################################
+"""basic scatter_sum operations from torch_scatter from
+https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py
+Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency.
+PyTorch plans to move these features into the main repo, but until then,
+to make installation simpler, we need this pure python set of wrappers
+that don't require installing PyTorch C++ extensions.
+See https://github.com/pytorch/pytorch/issues/63780.
+"""
+
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+
+__all__ = ["scatter_sum", "scatter_std", "scatter_mean"]
+
+
+def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int) -> torch.Tensor:
+    """
+    Broadcasts ``src`` to yield a tensor with equivalent shape to ``other``
+    along dimension ``dim``.
+
+    Parameters
+    ----------
+    src : torch.Tensor
+        Tensor to broadcast into a new shape.
+    other : torch.Tensor
+        Tensor to match shape against.
+    dim : int
+        Dimension to broadcast values along.
+
+    Returns
+    -------
+    torch.Tensor
+        Broadcasted values of ``src``, with the same shape as ``other``.
+    """
+    if dim < 0:
+        dim = other.dim() + dim
+    if src.dim() == 1:
+        for _ in range(0, dim):
+            src = src.unsqueeze(0)
+    for _ in range(src.dim(), other.dim()):
+        src = src.unsqueeze(-1)
+    src = src.expand_as(other)
+    return src
+
+
+@torch.jit.script
+def scatter_sum(
+    src: torch.Tensor,
+    index: torch.Tensor,
+    dim: int = -1,
+    out: Optional[torch.Tensor] = None,
+    dim_size: Optional[int] = None,
+    reduce: str = "sum",
+) -> torch.Tensor:
+    """
+    Apply a scatter operation with sum reduction, from ``src``
+    to ``out`` at indices ``index`` along the specified
+    dimension.
+
+    The function will apply a ``_broadcast`` with ``index``
+    to reshape it to the same as ``src`` first, then allocate
+    a new tensor based on the expected final shape (depending
+    on ``dim``).
+
+    Parameters
+    ----------
+    src : torch.Tensor
+        Tensor containing source values to scatter add.
+    index : torch.Tensor
+        Indices for the scatter add operation.
+    dim : int, optional
+        Dimension to apply the scatter add operation, by default -1
+    out : torch.Tensor, optional
+        Output tensor to store the scatter sum result, by default None,
+        which will create a tensor with the correct shape within
+        this function.
+    dim_size : int, optional
+        Used to determine the output shape, by default None, which
+        will then infer the output shape from ``dim``.
+    reduce : str, optional
+        Must be "sum" or "add"; kept for torch_scatter API compatibility.
+
+    Returns
+    -------
+    torch.Tensor
+        Resulting scatter sum output.
+    """
+    assert reduce == "sum" or reduce == "add"  # only sum reduction is implemented
+    index = _broadcast(index, src, dim)
+    if out is None:
+        size = list(src.size())
+        if dim_size is not None:
+            size[dim] = dim_size
+        elif index.numel() == 0:
+            size[dim] = 0
+        else:
+            size[dim] = int(index.max()) + 1
+        out = torch.zeros(size, dtype=src.dtype, device=src.device)
+        return out.scatter_add_(dim, index, src)
+    else:
+        return out.scatter_add_(dim, index, src)
+
+
+@torch.jit.script
+def scatter_std(
+    src: torch.Tensor,
+    index: torch.Tensor,
+    dim: int = -1,
+    out: Optional[torch.Tensor] = None,
+    dim_size: Optional[int] = None,
+    unbiased: bool = True,
+) -> torch.Tensor:
+    """
+    Compute the per-group standard deviation of ``src``, with groups
+    given by ``index`` along dimension ``dim``.
+
+    Group sums and element counts are each obtained with ``scatter_sum``;
+    squared deviations from the gathered group means are then summed
+    per group, divided by the count (plus a small epsilon) and
+    square-rooted.
+
+    Parameters
+    ----------
+    src : torch.Tensor
+        Tensor containing source values.
+    index : torch.Tensor
+        Group indices, broadcast against ``src``.
+    dim : int, optional
+        Dimension to reduce along, by default -1
+    out : torch.Tensor, optional
+        Tensor the squared deviations are scatter-added into, by default
+        None. When given, its size along ``dim`` replaces ``dim_size``.
+    dim_size : int, optional
+        Size of the output along ``dim``, by default None, which infers
+        it from the largest value in ``index``.
+    unbiased : bool, optional
+        If True, divide by ``count - 1`` (clamped to at least 1)
+        rather than ``count``, by default True
+
+    Returns
+    -------
+    torch.Tensor
+        Standard deviation of each group.
+    """
+    if out is not None:
+        dim_size = out.size(dim)
+
+    if dim < 0:
+        dim = src.dim() + dim
+
+    count_dim = dim
+    if index.dim() <= dim:
+        count_dim = index.dim() - 1
+
+    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
+
+    index = _broadcast(index, src, dim)
+    tmp = scatter_sum(src, index, dim, dim_size=dim_size)
+    count = _broadcast(count, tmp, dim).clamp(1)
+    mean = tmp.div(count)
+
+    var = src - mean.gather(dim, index)
+    var = var * var
+    out = scatter_sum(var, index, dim, out, dim_size)
+
+    if unbiased:
+        count = count.sub(1).clamp_(1)
+    out = out.div(count + 1e-6).sqrt()
+
+    return out
+
+
+@torch.jit.script
+def scatter_mean(
+    src: torch.Tensor,
+    index: torch.Tensor,
+    dim: int = -1,
+    out: Optional[torch.Tensor] = None,
+    dim_size: Optional[int] = None,
+) -> torch.Tensor:
+    """
+    Compute the per-group mean of ``src``, with groups given by
+    ``index`` along dimension ``dim``.
+
+    Values are summed with ``scatter_sum`` and divided in place by the
+    number of elements in each group (clamped to at least 1, so empty
+    groups are not divided by zero). Non-floating point outputs use
+    floor division.
+
+    Parameters
+    ----------
+    src : torch.Tensor
+        Tensor containing source values.
+    index : torch.Tensor
+        Group indices, broadcast against ``src``.
+    dim : int, optional
+        Dimension to reduce along, by default -1
+    out : torch.Tensor, optional
+        Tensor the sums are scatter-added into and then divided in
+        place, by default None, which allocates a new tensor.
+    dim_size : int, optional
+        Size of the output along ``dim``, by default None, which infers
+        it from the largest value in ``index``.
+
+    Returns
+    -------
+    torch.Tensor
+        Mean of each group.
+    """
+    out = scatter_sum(src, index, dim, out, dim_size)
+    dim_size = out.size(dim)
+
+    index_dim = dim
+    if index_dim < 0:
+        index_dim = index_dim + src.dim()
+    if index.dim() <= index_dim:
+        index_dim = index.dim() - 1
+
+    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+    count = scatter_sum(ones, index, index_dim, None, dim_size)
+    count[count < 1] = 1
+    count = _broadcast(count, out, dim)
+    if out.is_floating_point():
+        out.true_divide_(count)
+    else:
+        out.div_(count, rounding_mode="floor")
+    return out
diff --git a/matsciml/models/pyg/schnet.py b/matsciml/models/pyg/schnet.py
index b30474a2..e736fa1d 100644
--- a/matsciml/models/pyg/schnet.py
+++ b/matsciml/models/pyg/schnet.py
@@ -9,7 +9,7 @@
 
 import torch
 from torch_geometric.nn import SchNet
-from torch_scatter import scatter
+from matsciml.models.pyg.scatter import scatter_sum, scatter_mean
 
 from matsciml.common.utils import conditional_grad, get_pbc_distances, radius_graph_pbc
 
@@ -81,6 +81,11 @@ def __init__(
             cutoff=cutoff,
             readout=readout,
         )
+        # map literal readout choice to functions; "sum" is an alias of "add"
+        if readout in ("add", "sum"):
+            self.readout = scatter_sum
+        else:
+            self.readout = scatter_mean
 
     @conditional_grad(torch.enable_grad())
     def _forward(self, data):
@@ -124,7 +129,7 @@ def _forward(self, data):
             h = self.lin2(h)
 
             batch = torch.zeros_like(z) if batch is None else batch
-            energy = scatter(h, batch, dim=0, reduce=self.readout)
+            energy = self.readout(h, batch, dim=0)
         else:
             energy = super().forward(z, pos, batch)
         return energy