From 94ab60aac8813dd0430cd7d1256c021d7acad1da Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Tue, 29 Oct 2024 14:38:06 -0700
Subject: [PATCH 1/6] refactor: duplicating general use scatter from mace modules

Signed-off-by: Lee, Kin Long Kelvin

---
 matsciml/models/pyg/scatter.py | 133 ++++++++++++++++++++++++++++++++++
 1 file changed, 133 insertions(+)
 create mode 100644 matsciml/models/pyg/scatter.py

diff --git a/matsciml/models/pyg/scatter.py b/matsciml/models/pyg/scatter.py
new file mode 100644
index 00000000..963e9f55
--- /dev/null
+++ b/matsciml/models/pyg/scatter.py
@@ -0,0 +1,133 @@
+###########################################################################################
+# Implementation of MACE models and other models based on E(3)-Equivariant MPNNs
+# (https://github.com/ACEsuit/mace)
+# Original Authors: Ilyes Batatia, Gregor Simm
+# Integrated into matsciml by Vaibhav Bihani, Sajid Mannan
+# This program is distributed under the MIT License
+###########################################################################################
+"""Basic scatter_sum operations from torch_scatter, taken from
+https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py
+Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency.
+PyTorch plans to move these features into the main repo, but until then,
+to make installation simpler, we need this pure Python set of wrappers
+that don't require installing PyTorch C++ extensions.
+See https://github.com/pytorch/pytorch/issues/63780.
+"""
+
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+
+__all__ = ["scatter_sum", "scatter_std", "scatter_mean"]
+
+
+def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
+    if dim < 0:
+        dim = other.dim() + dim
+    if src.dim() == 1:
+        for _ in range(0, dim):
+            src = src.unsqueeze(0)
+    for _ in range(src.dim(), other.dim()):
+        src = src.unsqueeze(-1)
+    src = src.expand_as(other)
+    return src
+
+
+@torch.jit.script
+def scatter_sum(
+    src: torch.Tensor,
+    index: torch.Tensor,
+    dim: int = -1,
+    out: Optional[torch.Tensor] = None,
+    dim_size: Optional[int] = None,
+    reduce: str = "sum",
+) -> torch.Tensor:
+    assert reduce == "sum"  # for now, TODO
+    index = _broadcast(index, src, dim)
+    if out is None:
+        size = list(src.size())
+        if dim_size is not None:
+            size[dim] = dim_size
+        elif index.numel() == 0:
+            size[dim] = 0
+        else:
+            size[dim] = int(index.max()) + 1
+        out = torch.zeros(size, dtype=src.dtype, device=src.device)
+        return out.scatter_add_(dim, index, src)
+    else:
+        return out.scatter_add_(dim, index, src)
+
+
+@torch.jit.script
+def scatter_std(
+    src: torch.Tensor,
+    index: torch.Tensor,
+    dim: int = -1,
+    out: Optional[torch.Tensor] = None,
+    dim_size: Optional[int] = None,
+    unbiased: bool = True,
+) -> torch.Tensor:
+    if out is not None:
+        dim_size = out.size(dim)
+
+    if dim < 0:
+        dim = src.dim() + dim
+
+    count_dim = dim
+    if index.dim() <= dim:
+        count_dim = index.dim() - 1
+
+    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
+
+    index = _broadcast(index, src, dim)
+    tmp = scatter_sum(src, index, dim, dim_size=dim_size)
+    count = _broadcast(count, tmp, dim).clamp(1)
+    mean = tmp.div(count)
+
+    var = src - mean.gather(dim, index)
+    var = var * var
+    out = scatter_sum(var, index, dim, out, dim_size)
+
+    if unbiased:
+        count = count.sub(1).clamp_(1)
+    out = out.div(count + 1e-6).sqrt()
+
+    return out
+
+
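+# A quick worked example of the scatter ops above: given
+# src = torch.tensor([1.0, 2.0, 3.0, 4.0]) and index = torch.tensor([0, 0, 1, 1]),
+# scatter_sum(src, index, dim=0) returns tensor([3., 7.]), because values that
+# share an index are accumulated into the same output slot; scatter_std with the
+# same arguments likewise returns the per-group (unbiased) standard deviations.
+
+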
+@torch.jit.script
+def scatter_mean(
+    src: torch.Tensor,
+    index: torch.Tensor,
+    dim: int = -1,
+    out: Optional[torch.Tensor] = None,
+    dim_size: Optional[int] = None,
+) -> torch.Tensor:
+    out = scatter_sum(src, index, dim, out, dim_size)
+    dim_size = out.size(dim)
+
+    index_dim = dim
+    if index_dim < 0:
+        index_dim = index_dim + src.dim()
+    if index.dim() <= index_dim:
+        index_dim = index.dim() - 1
+
+    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+    count = scatter_sum(ones, index, index_dim, None, dim_size)
+    count[count < 1] = 1
+    count = _broadcast(count, out, dim)
+    if out.is_floating_point():
+        out.true_divide_(count)
+    else:
+        out.div_(count, rounding_mode="floor")
+    return out

From c682bc6eded29e942c42bce2dfef141de4bae4ec Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Tue, 29 Oct 2024 14:40:58 -0700
Subject: [PATCH 2/6] refactor: using scatter sum and mean for SchNet

Signed-off-by: Lee, Kin Long Kelvin

---
 matsciml/models/pyg/schnet.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/matsciml/models/pyg/schnet.py b/matsciml/models/pyg/schnet.py
index b30474a2..e736fa1d 100644
--- a/matsciml/models/pyg/schnet.py
+++ b/matsciml/models/pyg/schnet.py
@@ -9,7 +9,7 @@
 import torch
 from torch_geometric.nn import SchNet
 
-from torch_scatter import scatter
+from matsciml.models.pyg.scatter import scatter_sum, scatter_mean
 
 from matsciml.common.utils import conditional_grad, get_pbc_distances, radius_graph_pbc
 
@@ -81,6 +81,11 @@ def __init__(
             cutoff=cutoff,
             readout=readout,
         )
+        # map the literal readout choice to the corresponding scatter function
+        if readout == "add":
+            self.readout = scatter_sum
+        else:
+            self.readout = scatter_mean
 
     @conditional_grad(torch.enable_grad())
     def _forward(self, data):
@@ -124,7 +129,7 @@ def _forward(self, data):
             h = self.lin2(h)
 
             batch = torch.zeros_like(z) if batch is None else batch
-            energy = scatter(h, batch, dim=0, reduce=self.readout)
+            energy = self.readout(h, batch, dim=0)
         else:
             energy = super().forward(z, pos, batch)
         return energy

From 7be6549c8ac13ed70cddaab7828c1410de5c8999 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Tue, 29 Oct 2024 14:42:36 -0700
Subject: [PATCH 3/6] refactor: replacing torch_scatter call in FAENet

---
 matsciml/models/pyg/faenet/layers.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/matsciml/models/pyg/faenet/layers.py b/matsciml/models/pyg/faenet/layers.py
index db290bec..5747d252 100644
--- a/matsciml/models/pyg/faenet/layers.py
+++ b/matsciml/models/pyg/faenet/layers.py
@@ -1,18 +1,15 @@
 from __future__ import annotations
 
-from typing import Tuple, Union
-
 import pandas as pd
 import torch
 import torch.nn as nn
 from mendeleev.fetch import fetch_ionization_energies, fetch_table
-from torch import nn
 from torch.nn import Embedding, Linear
 from torch_geometric.nn import MessagePassing
 from torch_geometric.nn.norm import GraphNorm
-from torch_scatter import scatter
 
 from matsciml.models.pyg.faenet.helper import *
+from matsciml.models.pyg.scatter import scatter_sum
 
 
 class PhysEmbedding(nn.Module):
@@ -508,7 +505,7 @@ def forward(
         h = h * alpha
 
         # Global pooling
-        out = scatter(h, batch, dim=0, reduce="add")
+        out = scatter_sum(h, batch, dim=0)
 
         return out
 

From 5e01f7ea9dca01a001b7892a68584f2d8cfd8119 Mon Sep 17 00:00:00 2001
From: "Lee, Kin Long Kelvin"
Date: Tue, 29 Oct 2024 15:01:23 -0700
Subject: [PATCH 4/6] refactor: putting CGCNN in the needs-torch_scatter category

---
 matsciml/models/pyg/__init__.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8
deletions(-) diff --git a/matsciml/models/pyg/__init__.py b/matsciml/models/pyg/__init__.py index 71ead469..53261f93 100644 --- a/matsciml/models/pyg/__init__.py +++ b/matsciml/models/pyg/__init__.py @@ -18,16 +18,15 @@ # load models if we have PyG installed if _has_pyg: - from matsciml.models.pyg.cgcnn import CGCNN from matsciml.models.pyg.egnn import EGNN from matsciml.models.pyg.mace import MACE, ScaleShiftMACE - __all__ = ["CGCNN", "EGNN", "FAENet", "MACE", "ScaleShiftMACE"] + __all__ = ["EGNN", "FAENet", "MACE", "ScaleShiftMACE"] # these packages need additional pyg dependencies if package_registry["torch_sparse"] and package_registry["torch_scatter"]: - from matsciml.models.pyg.dimenet import DimeNetWrap - from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap + from matsciml.models.pyg.dimenet import DimeNetWrap # noqa: F401 + from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap # noqa: F401 __all__.extend(["DimeNetWrap", "DimeNetPlusPlusWrap"]) else: @@ -35,11 +34,12 @@ "Missing torch_sparse and torch_scatter; DimeNet models will not be available." ) if package_registry["torch_scatter"]: - from matsciml.models.pyg.forcenet import ForceNet - from matsciml.models.pyg.schnet import SchNetWrap - from matsciml.models.pyg.faenet import FAENet + from matsciml.models.pyg.forcenet import ForceNet # noqa: F401 + from matsciml.models.pyg.schnet import SchNetWrap # noqa: F401 + from matsciml.models.pyg.faenet import FAENet # noqa: F401 + from matsciml.models.pyg.cgcnn import CGCNN # noqa: F401 - __all__.extend(["ForceNet", "SchNetWrap", "FAENet"]) + __all__.extend(["ForceNet", "SchNetWrap", "FAENet", "CGCNN"]) else: logger.warning( "Missing torch_scatter; ForceNet, SchNet, and FAENet models will not be available." From f57d8652588019f28c9bb1da426429ddf7e5db1a Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Tue, 29 Oct 2024 15:02:38 -0700 Subject: [PATCH 5/6] refactor: moving FAENet into base pyg category --- matsciml/models/pyg/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/models/pyg/__init__.py b/matsciml/models/pyg/__init__.py index 53261f93..a6b334a4 100644 --- a/matsciml/models/pyg/__init__.py +++ b/matsciml/models/pyg/__init__.py @@ -19,6 +19,7 @@ # load models if we have PyG installed if _has_pyg: from matsciml.models.pyg.egnn import EGNN + from matsciml.models.pyg.faenet import FAENet from matsciml.models.pyg.mace import MACE, ScaleShiftMACE __all__ = ["EGNN", "FAENet", "MACE", "ScaleShiftMACE"] @@ -36,7 +37,6 @@ if package_registry["torch_scatter"]: from matsciml.models.pyg.forcenet import ForceNet # noqa: F401 from matsciml.models.pyg.schnet import SchNetWrap # noqa: F401 - from matsciml.models.pyg.faenet import FAENet # noqa: F401 from matsciml.models.pyg.cgcnn import CGCNN # noqa: F401 __all__.extend(["ForceNet", "SchNetWrap", "FAENet", "CGCNN"]) From 6e6ccd121047a09b7844ebc5cf97a37997a37876 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 8 Nov 2024 13:39:37 -0800 Subject: [PATCH 6/6] docs: adding docstrings to scatter ops Signed-off-by: Lee, Kin Long Kelvin --- matsciml/models/pyg/scatter.py | 58 ++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/matsciml/models/pyg/scatter.py b/matsciml/models/pyg/scatter.py index 963e9f55..555fef98 100644 --- a/matsciml/models/pyg/scatter.py +++ b/matsciml/models/pyg/scatter.py @@ -3,6 +3,7 @@ # (https://github.com/ACEsuit/mace) # Original Authors: Ilyes Batatia, Gregor Simm # Integrated into matsciml 
by Vaibhav Bihani, Sajid Mannan
+# Refactors and improved docstrings by Kelvin Lee
 # This program is distributed under the MIT License
 ###########################################################################################
 """Basic scatter_sum operations from torch_scatter, taken from
@@ -23,7 +24,25 @@
 __all__ = ["scatter_sum", "scatter_std", "scatter_mean"]
 
 
-def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
+def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int) -> torch.Tensor:
+    """
+    Broadcasts ``src`` to yield a tensor with the same shape as ``other``
+    along dimension ``dim``.
+
+    Parameters
+    ----------
+    src : torch.Tensor
+        Tensor to broadcast into a new shape.
+    other : torch.Tensor
+        Tensor to match shape against.
+    dim : int
+        Dimension to broadcast values along.
+
+    Returns
+    -------
+    torch.Tensor
+        Broadcasted values of ``src``, with the same shape as ``other``.
+    """
     if dim < 0:
         dim = other.dim() + dim
     if src.dim() == 1:
@@ -40,10 +59,43 @@ def scatter_sum(
     src: torch.Tensor,
     index: torch.Tensor,
     dim: int = -1,
-    out: Optional[torch.Tensor] = None,
-    dim_size: Optional[int] = None,
+    out: torch.Tensor | None = None,
+    dim_size: int | None = None,
     reduce: str = "sum",
 ) -> torch.Tensor:
+    """
+    Apply a scatter operation with sum reduction, accumulating
+    values from ``src`` into ``out`` at indices ``index`` along
+    the specified dimension.
+
+    The function first applies ``_broadcast`` to ``index``
+    to reshape it to the same shape as ``src``, then allocates
+    a new tensor based on the expected final shape (depending
+    on ``dim``).
+
+    Parameters
+    ----------
+    src : torch.Tensor
+        Tensor containing source values to scatter add.
+    index : torch.Tensor
+        Indices for the scatter add operation.
+    dim : int, optional
+        Dimension to apply the scatter add operation along, by default -1.
+    out : torch.Tensor, optional
+        Output tensor to store the scatter sum result, by default None,
+        which will create a tensor with the correct shape within
+        this function.
+    dim_size : int, optional
+        Size of the output along ``dim``, by default None, in which
+        case it is inferred from the maximum value of ``index``.
+    reduce : str, optional
+        Must be ``"sum"``; kept for signature compatibility with torch_scatter.
+
+    Returns
+    -------
+    torch.Tensor
+        Resulting scatter sum output.
+    """
     assert reduce == "sum"  # for now, TODO
     index = _broadcast(index, src, dim)
    if out is None:
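
Taken together, these patches make the vendored ops drop-in replacements for the removed torch_scatter calls. As a minimal usage sketch (tensor shapes and values chosen purely for illustration):

    import torch

    from matsciml.models.pyg.scatter import scatter_mean, scatter_sum

    # five nodes split across two graphs, eight features per node
    h = torch.randn(5, 8)
    batch = torch.tensor([0, 0, 0, 1, 1])

    pooled_sum = scatter_sum(h, batch, dim=0)    # shape (2, 8), the "add" readout
    pooled_mean = scatter_mean(h, batch, dim=0)  # shape (2, 8), the "mean" readout

This mirrors how SchNetWrap now dispatches its readout and how FAENet's output head performs global pooling.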