diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 7380d036c..c070fea4e 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -189,6 +189,10 @@ def no_weight_decay(self) -> list: class HeadInterface(metaclass=ABCMeta): + @property + def use_amp(self): + return False + @abstractmethod def forward( self, data: Batch, emb: dict[str, torch.Tensor] @@ -249,6 +253,7 @@ def __init__( ): super().__init__() self.otf_graph = otf_graph + self.device = "cpu" # make a copy so we don't modify the original config backbone = copy.deepcopy(backbone) heads = copy.deepcopy(heads) @@ -279,12 +284,20 @@ def __init__( self.output_heads = torch.nn.ModuleDict(self.output_heads) + def to(self, *args, **kwargs): + if "device" in kwargs: + self.device = kwargs["device"] + return super().to(*args, **kwargs) + def forward(self, data: Batch): emb = self.backbone(data) # Predict all output properties for all structures in the batch for now. out = {} for k in self.output_heads: - out.update(self.output_heads[k](data, emb)) + with torch.autocast( + device_type=self.device, enabled=self.output_heads[k].use_amp + ): + out.update(self.output_heads[k](data, emb)) return out diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index b30394560..b78f43597 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -3,6 +3,7 @@ import contextlib import logging import math +from functools import partial import torch import torch.nn as nn @@ -54,6 +55,28 @@ _AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100 +def eqv2_init_weights(m, weight_init): + if isinstance(m, (torch.nn.Linear, SO3_LinearV2)): + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + if weight_init == "normal": + std = 1 / math.sqrt(m.in_features) + torch.nn.init.normal_(m.weight, 0, std) + elif isinstance(m, torch.nn.LayerNorm): + torch.nn.init.constant_(m.bias, 0) + torch.nn.init.constant_(m.weight, 1.0) + elif isinstance(m, RadialFunction): + m.apply(eqv2_uniform_init_linear_weights) + + +def eqv2_uniform_init_linear_weights(m): + if isinstance(m, torch.nn.Linear): + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + std = 1 / math.sqrt(m.in_features) + torch.nn.init.uniform_(m.weight, -std, std) + + @registry.register_model("equiformer_v2") class EquiformerV2(nn.Module, GraphModelMixin): """ @@ -400,8 +423,7 @@ def __init__( requires_grad=False, ) - self.apply(self._init_weights) - self.apply(self._uniform_init_rad_func_linear_weights) + self.apply(partial(eqv2_init_weights, weight_init=self.weight_init)) def _init_gp_partitions( self, @@ -630,31 +652,6 @@ def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): def num_params(self): return sum(p.numel() for p in self.parameters()) - def _init_weights(self, m): - if isinstance(m, (torch.nn.Linear, SO3_LinearV2)): - if m.bias is not None: - torch.nn.init.constant_(m.bias, 0) - if self.weight_init == "normal": - std = 1 / math.sqrt(m.in_features) - torch.nn.init.normal_(m.weight, 0, std) - elif self.weight_init == "uniform": - self._uniform_init_linear_weights(m) - - elif isinstance(m, torch.nn.LayerNorm): - torch.nn.init.constant_(m.bias, 0) - torch.nn.init.constant_(m.weight, 1.0) - - def _uniform_init_rad_func_linear_weights(self, m): - if isinstance(m, RadialFunction): - m.apply(self._uniform_init_linear_weights) - - def 
_uniform_init_linear_weights(self, m): - if isinstance(m, (torch.nn.Linear, SO3_LinearV2)): - if m.bias is not None: - torch.nn.init.constant_(m.bias, 0) - std = 1 / math.sqrt(m.in_features) - torch.nn.init.uniform_(m.weight, -std, std) - @torch.jit.ignore def no_weight_decay(self) -> set: no_wd_list = [] @@ -852,8 +849,7 @@ def __init__(self, backbone): backbone.use_grid_mlp, backbone.use_sep_s2_act, ) - self.apply(backbone._init_weights) - self.apply(backbone._uniform_init_rad_func_linear_weights) + self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): node_energy = self.energy_block(emb["node_embedding"]) @@ -898,8 +894,7 @@ def __init__(self, backbone): backbone.use_sep_s2_act, alpha_drop=0.0, ) - self.apply(backbone._init_weights) - self.apply(backbone._uniform_init_rad_func_linear_weights) + self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) def forward(self, data: Batch, emb: dict[str, torch.Tensor]): if self.activation_checkpoint: diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py new file mode 100644 index 000000000..7542c0d13 --- /dev/null +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .rank2 import Rank2SymmetricTensorHead + +__all__ = ["Rank2SymmetricTensorHead"] diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py new file mode 100644 index 000000000..2bbf42eaa --- /dev/null +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -0,0 +1,351 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from functools import partial + +import torch +from e3nn import o3 +from torch import nn +from torch_scatter import scatter + +from fairchem.core.common.registry import registry +from fairchem.core.models.base import BackboneInterface, HeadInterface +from fairchem.core.models.equiformer_v2.equiformer_v2 import eqv2_init_weights +from fairchem.core.models.equiformer_v2.layer_norm import get_normalization_layer + + +class Rank2Block(nn.Module): + """ + Output block for predicting rank-2 tensors (stress, dielectric tensor). + Applies outer product between edges and computes node-wise or edge-wise MLP. 
+ + Args: + emb_size (int): Size of edge embedding used to compute outer product + num_layers (int): Number of layers of the MLP + edge_level (bool): If true apply MLP at edge level before pooling, otherwise use MLP at nodes after pooling + extensive (bool): Whether to sum or average the outer products + """ + + def __init__( + self, + emb_size: int, + num_layers: int = 2, + edge_level: bool = False, + extensive: bool = False, + ): + super().__init__() + + self.edge_level = edge_level + self.emb_size = emb_size + self.extensive = extensive + self.scalar_nonlinearity = nn.SiLU() + self.r2tensor_MLP = nn.Sequential() + for i in range(num_layers): + if i < num_layers - 1: + self.r2tensor_MLP.append(nn.Linear(emb_size, emb_size)) + self.r2tensor_MLP.append(self.scalar_nonlinearity) + else: + self.r2tensor_MLP.append(nn.Linear(emb_size, 1)) + + def forward(self, edge_distance_vec, x_edge, edge_index, data): + """ + Args: + edge_distance_vec (torch.Tensor): Tensor of shape (..., 3) + x_edge (torch.Tensor): Tensor of shape (..., emb_size) + edge_index (torch.Tensor): Tensor of shape (2, nEdges) + data: LMDBDataset sample + """ + + outer_product_edge = torch.bmm( + edge_distance_vec.unsqueeze(2), edge_distance_vec.unsqueeze(1) + ) + + edge_outer = ( + x_edge[:, :, None] * outer_product_edge.view(-1, 9)[:, None, :] + ) # should end up as 2400 x 128 x 9 + + # edge_outer: (nEdges, emb_size_edge, 9) + if self.edge_level: + # MLP at edge level before pooling. + edge_outer = edge_outer.transpose(1, 2) # (nEdges, 9, emb_size_edge) + edge_outer = self.r2tensor_MLP(edge_outer) # (nEdges, 9, 1) + edge_outer = edge_outer.reshape(-1, 9) # (nEdges, 9) + + node_outer = scatter(edge_outer, edge_index, dim=0, reduce="mean") + else: + # operates at edge level before mixing / MLP => mixing / MLP happens at node level + node_outer = scatter(edge_outer, edge_index, dim=0, reduce="mean") + + node_outer = node_outer.transpose(1, 2) # (natoms, 9, emb_size_edge) + node_outer = self.r2tensor_MLP(node_outer) # (natoms, 9, 1) + node_outer = node_outer.reshape(-1, 9) # (natoms, 9) + + # node_outer: nAtoms, 9 => average across all atoms at the structure level + if self.extensive: + r2_tensor = scatter(node_outer, data.batch, dim=0, reduce="sum") + else: + r2_tensor = scatter(node_outer, data.batch, dim=0, reduce="mean") + return r2_tensor + + +class Rank2DecompositionEdgeBlock(nn.Module): + """ + Output block for predicting rank-2 tensors (stress, dielectric tensor, etc). + Decomposes a rank-2 symmetric tensor into irrep degree 0 and 2. 
+ + Args: + emb_size (int): Size of edge embedding used to compute outer product + num_layers (int): Number of layers of the MLP + edge_level (bool): If true apply MLP at edge level before pooling, otherwise use MLP at nodes after pooling + extensive (bool): Whether to sum or average the outer products + """ + + def __init__( + self, + emb_size: int, + num_layers: int = 2, + edge_level: bool = False, + extensive: bool = False, + ): + super().__init__() + self.emb_size = emb_size + self.edge_level = edge_level + self.extensive = extensive + self.scalar_nonlinearity = nn.SiLU() + self.scalar_MLP = nn.Sequential() + self.irrep2_MLP = nn.Sequential() + for i in range(num_layers): + if i < num_layers - 1: + self.scalar_MLP.append(nn.Linear(emb_size, emb_size)) + self.irrep2_MLP.append(nn.Linear(emb_size, emb_size)) + self.scalar_MLP.append(self.scalar_nonlinearity) + self.irrep2_MLP.append(self.scalar_nonlinearity) + else: + self.scalar_MLP.append(nn.Linear(emb_size, 1)) + self.irrep2_MLP.append(nn.Linear(emb_size, 1)) + + # Change of basis obtained by stacking the C-G coefficients + self.change_mat = torch.transpose( + torch.tensor( + [ + [3 ** (-0.5), 0, 0, 0, 3 ** (-0.5), 0, 0, 0, 3 ** (-0.5)], + [0, 0, 0, 0, 0, 2 ** (-0.5), 0, -(2 ** (-0.5)), 0], + [0, 0, -(2 ** (-0.5)), 0, 0, 0, 2 ** (-0.5), 0, 0], + [0, 2 ** (-0.5), 0, -(2 ** (-0.5)), 0, 0, 0, 0, 0], + [0, 0, 0.5**0.5, 0, 0, 0, 0.5**0.5, 0, 0], + [0, 2 ** (-0.5), 0, 2 ** (-0.5), 0, 0, 0, 0, 0], + [ + -(6 ** (-0.5)), + 0, + 0, + 0, + 2 * 6 ** (-0.5), + 0, + 0, + 0, + -(6 ** (-0.5)), + ], + [0, 0, 0, 0, 0, 2 ** (-0.5), 0, 2 ** (-0.5), 0], + [-(2 ** (-0.5)), 0, 0, 0, 0, 0, 0, 0, 2 ** (-0.5)], + ] + ).detach(), + 0, + 1, + ) + + def forward(self, edge_distance_vec, x_edge, edge_index, data): + """ + Args: + edge_distance_vec (torch.Tensor): Tensor of shape (..., 3) + x_edge (torch.Tensor): Tensor of shape (..., emb_size) + edge_index (torch.Tensor): Tensor of shape (2, nEdges) + data: LMDBDataset sample + """ + # Calculate spherical harmonics of degree 2 of the points sampled + sphere_irrep2 = o3.spherical_harmonics( + 2, edge_distance_vec, True + ).detach() # (nEdges, 5) + + if self.edge_level: + # MLP at edge level before pooling. + + # Irrep 0 prediction + edge_scalar = x_edge + edge_scalar = self.scalar_MLP(edge_scalar) + + # Irrep 2 prediction + edge_irrep2 = ( + sphere_irrep2[:, :, None] * x_edge[:, None, :] + ) # (nEdges, 5, emb_size) + edge_irrep2 = self.irrep2_MLP(edge_irrep2) + + node_scalar = scatter(edge_scalar, edge_index, dim=0, reduce="mean") + node_irrep2 = scatter(edge_irrep2, edge_index, dim=0, reduce="mean") + else: + edge_irrep2 = ( + sphere_irrep2[:, :, None] * x_edge[:, None, :] + ) # (nAtoms, 5, emb_size) + + node_scalar = scatter(x_edge, edge_index, dim=0, reduce="mean") + node_irrep2 = scatter(edge_irrep2, edge_index, dim=0, reduce="mean") + + # Irrep 0 prediction + for module in self.scalar_MLP: + node_scalar = module(node_scalar) + + # Irrep 2 prediction + for module in self.irrep2_MLP: + node_irrep2 = module(node_irrep2) + + scalar = scatter( + node_scalar.view(-1), + data.batch, + dim=0, + reduce="sum" if self.extensive else "mean", + ) + irrep2 = scatter( + node_irrep2.view(-1, 5), + data.batch, + dim=0, + reduce="sum" if self.extensive else "mean", + ) + + # Note (@abhshkdz): If we have separate normalizers on the isotropic and + # anisotropic components (implemented in the trainer), combining the + # scalar and irrep2 predictions here would lead to the incorrect result. 
+ # Instead, we should combine the predictions after the normalizers. + + return scalar.reshape(-1), irrep2 + + +@registry.register_model("rank2_symmetric_head") +class Rank2SymmetricTensorHead(nn.Module, HeadInterface): + """A rank 2 symmetric tensor prediction head. + + Attributes: + ouput_name: name of output prediction property (ie, stress) + sphharm_norm: layer normalization for spherical harmonic edge weights + xedge_layer_norm: embedding layer norm + block: rank 2 equivariant symmetric tensor block + """ + + def __init__( + self, + backbone: BackboneInterface, + output_name: str, + decompose: bool = False, + edge_level_mlp: bool = False, + num_mlp_layers: int = 2, + use_source_target_embedding: bool = False, + extensive: bool = False, + avg_num_nodes: int = 1.0, + default_norm_type: str = "layer_norm_sh", + ): + """ + Args: + backbone: Backbone model that the head is attached to + decompose: Whether to decompose the rank2 tensor into isotropic and anisotropic components + edge_level_mlp: If true apply MLP at edge level before pooling, otherwise use MLP at nodes after pooling + num_mlp_layers: number of MLP layers + use_source_target_embedding: Whether to use both source and target atom embeddings + extensive: Whether to do sum-pooling (extensive) vs mean pooling (intensive). + avg_num_nodes: Used only if extensive to divide prediction by avg num nodes. + """ + super().__init__() + self.output_name = output_name + self.decompose = decompose + self.use_source_target_embedding = use_source_target_embedding + self.avg_num_nodes = avg_num_nodes + + self.sphharm_norm = get_normalization_layer( + getattr(backbone, "norm_type", default_norm_type), + lmax=max(backbone.lmax_list), + num_channels=1, + ) + + if use_source_target_embedding: + r2_tensor_sphere_channels = backbone.sphere_channels * 2 + else: + r2_tensor_sphere_channels = backbone.sphere_channels + + self.xedge_layer_norm = nn.LayerNorm(r2_tensor_sphere_channels) + + if decompose: + self.block = Rank2DecompositionEdgeBlock( + emb_size=r2_tensor_sphere_channels, + num_layers=num_mlp_layers, + edge_level=edge_level_mlp, + extensive=extensive, + ) + else: + self.block = Rank2Block( + emb_size=r2_tensor_sphere_channels, + num_layers=num_mlp_layers, + edge_level=edge_level_mlp, + extensive=extensive, + ) + + # initialize weights + self.block.apply(partial(eqv2_init_weights, weight_init="uniform")) + + def forward( + self, data: dict[str, torch.Tensor] | torch.Tensor, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """ + Args: + data: data batch + emb: dictionary with embedding object and graph data + + Returns: dict of {output property name: predicted value} + """ + node_emb, graph = emb["node_embedding"], emb["graph"] + + sphharm_weights_edge = o3.spherical_harmonics( + torch.arange(0, node_emb.lmax_list[-1] + 1).tolist(), + graph.edge_distance_vec, + False, + ).detach() + + # layer norm because sphharm_weights_edge values become large and causes infs with amp + sphharm_weights_edge = self.sphharm_norm( + sphharm_weights_edge[:, :, None] + ).squeeze() + + if self.use_source_target_embedding: + x_source = node_emb.expand_edge(graph.edge_index[0]).embedding + x_target = node_emb.expand_edge(graph.edge_index[1]).embedding + x_edge = torch.cat((x_source, x_target), dim=2) + else: + x_edge = node_emb.expand_edge(graph.edge_index[1]).embedding + + x_edge = torch.einsum("abc, ab->ac", x_edge, sphharm_weights_edge) + + # layer norm because x_edge values become large and causes infs with amp + x_edge = 
self.xedge_layer_norm(x_edge) + + if self.decompose: + tensor_0, tensor_2 = self.block( + graph.edge_distance_vec, x_edge, graph.edge_index[1], data + ) + + if self.block.extensive: # legacy, may be interesting to try + tensor_0 = tensor_0 / self.avg_num_nodes + tensor_2 = tensor_2 / self.avg_num_nodes + + output = { + f"{self.output_name}_isotropic": tensor_0.unsqueeze(1), + f"{self.output_name}_anisotropic": tensor_2, + } + else: + out_tensor = self.block( + graph.edge_distance_vec, x_edge, graph.edge_index[1], data + ) + output = {self.output_name: out_tensor.reshape((-1, 3))} + + return output diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index c288f3f25..54b1992f4 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -524,7 +524,12 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: x_pt = x_pt.view(-1, self.sphere_channels_all) - return {"sphere_values": x_pt, "sphere_points": self.sphere_points} + return { + "sphere_values": x_pt, + "sphere_points": self.sphere_points, + "node_embedding": x, + "graph": graph, + } @registry.register_model("escn_energy_head") diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 26269c6da..662341bdc 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -655,7 +655,9 @@ def run_relaxations(self, split="val"): ) gather_results["chunk_idx"] = np.cumsum( [gather_results["chunk_idx"][i] for i in idx] - )[:-1] # np.split does not need last idx, assumes n-1:end + )[ + :-1 + ] # np.split does not need last idx, assumes n-1:end full_path = os.path.join( self.config["cmd"]["results_dir"], "relaxed_positions.npz" diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 46750f03b..0e3606b30 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -13,12 +13,19 @@ if TYPE_CHECKING: from pathlib import Path +from itertools import product +from random import choice + import numpy as np import pytest import requests import torch +from pymatgen.core import Structure +from pymatgen.core.periodic_table import Element from syrupy.extensions.amber import AmberSnapshotExtension +from fairchem.core.datasets import AseDBDataset, LMDBDatabase + if TYPE_CHECKING: from syrupy.types import SerializableData @@ -172,3 +179,49 @@ def tutorial_dataset_path(tmp_path_factory) -> Path: tarfile.open(fileobj=response.raw, mode="r|gz").extractall(path=tmpdir) return tmpdir + + +@pytest.fixture(scope="session") +def dummy_element_refs(): + # create some dummy elemental energies from ionic radii (ignore deuterium and tritium included in pmg) + return np.concatenate( + [[0], [e.average_ionic_radius for e in Element if e.name not in ("D", "T")]] + ) + + +@pytest.fixture(scope="session") +def dummy_binary_dataset_path(tmpdir_factory, dummy_element_refs): + # a dummy dataset with binaries with energy that depends on composition only plus noise + all_binaries = list(product(list(Element), repeat=2)) + rng = np.random.default_rng(seed=0) + + tmpdir = tmpdir_factory.mktemp("dataset") + with LMDBDatabase(tmpdir / "dummy.aselmdb") as db: + for _ in range(1000): + elements = choice(all_binaries) + structure = Structure.from_prototype("cscl", species=elements, a=2.0) + energy = ( + sum(e.average_ionic_radius for e in elements) + + 0.05 * rng.random() * dummy_element_refs.mean() + ) + atoms = structure.to_ase_atoms() + db.write( + atoms, + data={ + "energy": energy, + "forces": 
rng.random((2, 3)), + "stress": rng.random((3, 3)), + }, + ) + + return tmpdir / "dummy.aselmdb" + + +@pytest.fixture(scope="session") +def dummy_binary_dataset(dummy_binary_dataset_path): + return AseDBDataset( + config={ + "src": str(dummy_binary_dataset_path), + "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + } + ) diff --git a/tests/core/e2e/conftest.py b/tests/core/e2e/conftest.py new file mode 100644 index 000000000..fc1c6a432 --- /dev/null +++ b/tests/core/e2e/conftest.py @@ -0,0 +1,45 @@ +import pytest + +from pathlib import Path + + +@pytest.fixture() +def configs(): + return { + "scn": Path("tests/core/models/test_configs/test_scn.yml"), + "escn": Path("tests/core/models/test_configs/test_escn.yml"), + "escn_hydra": Path("tests/core/models/test_configs/test_escn_hydra.yml"), + "schnet": Path("tests/core/models/test_configs/test_schnet.yml"), + "gemnet_dt": Path("tests/core/models/test_configs/test_gemnet_dt.yml"), + "gemnet_dt_hydra": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra.yml" + ), + "gemnet_dt_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml" + ), + "gemnet_oc": Path("tests/core/models/test_configs/test_gemnet_oc.yml"), + "gemnet_oc_hydra": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra.yml" + ), + "gemnet_oc_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml" + ), + "dimenet++": Path("tests/core/models/test_configs/test_dpp.yml"), + "dimenet++_hydra": Path("tests/core/models/test_configs/test_dpp_hydra.yml"), + "painn": Path("tests/core/models/test_configs/test_painn.yml"), + "painn_hydra": Path("tests/core/models/test_configs/test_painn_hydra.yml"), + "equiformer_v2": Path("tests/core/models/test_configs/test_equiformerv2.yml"), + "equiformer_v2_hydra": Path( + "tests/core/models/test_configs/test_equiformerv2_hydra.yml" + ), + } + + +@pytest.fixture() +def tutorial_train_src(tutorial_dataset_path): + return tutorial_dataset_path / "s2ef/train_100" + + +@pytest.fixture() +def tutorial_val_src(tutorial_dataset_path): + return tutorial_dataset_path / "s2ef/val_20" diff --git a/tests/core/e2e/test_e2e_commons.py b/tests/core/e2e/test_e2e_commons.py index a171a9689..ff3ea3634 100644 --- a/tests/core/e2e/test_e2e_commons.py +++ b/tests/core/e2e/test_e2e_commons.py @@ -92,6 +92,7 @@ def merge_dictionary(d, u): d[k] = v return d + def _run_main( rundir, input_yaml, diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 695fb537d..6b83749c0 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -16,48 +16,6 @@ setup_logging() -@pytest.fixture() -def configs(): - return { - "scn": Path("tests/core/models/test_configs/test_scn.yml"), - "escn": Path("tests/core/models/test_configs/test_escn.yml"), - "escn_hydra": Path("tests/core/models/test_configs/test_escn_hydra.yml"), - "schnet": Path("tests/core/models/test_configs/test_schnet.yml"), - "gemnet_dt": Path("tests/core/models/test_configs/test_gemnet_dt.yml"), - "gemnet_dt_hydra": Path( - "tests/core/models/test_configs/test_gemnet_dt_hydra.yml" - ), - "gemnet_dt_hydra_grad": Path( - "tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml" - ), - "gemnet_oc": Path("tests/core/models/test_configs/test_gemnet_oc.yml"), - "gemnet_oc_hydra": Path( - "tests/core/models/test_configs/test_gemnet_oc_hydra.yml" - ), - "gemnet_oc_hydra_grad": Path( - "tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml" - ), - "dimenet++": 
Path("tests/core/models/test_configs/test_dpp.yml"), - "dimenet++_hydra": Path("tests/core/models/test_configs/test_dpp_hydra.yml"), - "painn": Path("tests/core/models/test_configs/test_painn.yml"), - "painn_hydra": Path("tests/core/models/test_configs/test_painn_hydra.yml"), - "equiformer_v2": Path("tests/core/models/test_configs/test_equiformerv2.yml"), - "equiformer_v2_hydra": Path( - "tests/core/models/test_configs/test_equiformerv2_hydra.yml" - ), - } - - -@pytest.fixture() -def tutorial_train_src(tutorial_dataset_path): - return tutorial_dataset_path / "s2ef/train_100" - - -@pytest.fixture() -def tutorial_val_src(tutorial_dataset_path): - return tutorial_dataset_path / "s2ef/val_20" - - """ These tests are intended to be as quick as possible and test only that the network is runnable and outputs training+validation to tensorboard output These should catch errors such as shape mismatches or otherways to code wise break a network diff --git a/tests/core/e2e/test_s2efs.py b/tests/core/e2e/test_s2efs.py new file mode 100644 index 000000000..94b0862ed --- /dev/null +++ b/tests/core/e2e/test_s2efs.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest +from test_e2e_commons import _run_main + + +# TODO add GemNet! +@pytest.mark.parametrize( + ("model_name", "ddp"), + [ + ("equiformer_v2_hydra", False), + ("escn_hydra", False), + ("equiformer_v2_hydra", True), + ("escn_hydra", True), + ], +) +def test_smoke_s2efs_predict( + model_name, ddp, configs, dummy_binary_dataset_path, tmpdir +): + # train an s2ef model just to have one + input_yaml = configs[model_name] + train_rundir = tmpdir / "train" + train_rundir.mkdir() + checkpoint_path = str(train_rundir / "checkpoint.pt") + training_predictions_filename = str(train_rundir / "train_predictions.npz") + + updates = { + "task": {"strict_load": False}, + "model": { + "backbone": {"max_num_elements": 118 + 1}, + "heads": { + "stress": { + "module": "rank2_symmetric_head", + "output_name": "stress", + "use_source_target_embedding": True, + } + }, + }, + "loss_functions": [ + {"energy": {"fn": "mae", "coefficient": 2}}, + {"forces": {"fn": "l2mae", "coefficient": 100}}, + {"stress": {"fn": "mae", "coefficient": 100}}, + ], + "outputs": {"stress": {"level": "system", "irrep_dim": 2}}, + "evaluation_metrics": {"metrics": {"stress": ["mae"]}}, + "dataset": { + "train": { + "src": str(dummy_binary_dataset_path), + "format": "ase_db", + "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + "sample_n": 20, + }, + "val": { + "src": str(dummy_binary_dataset_path), + "format": "ase_db", + "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + "sample_n": 5, + }, + "test": { + "src": str(dummy_binary_dataset_path), + "format": "ase_db", + "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + "sample_n": 5, + }, + }, + } + + acc = _run_main( + rundir=str(train_rundir), + input_yaml=input_yaml, + update_dict_with={ + "optim": { + "max_epochs": 2, + "eval_every": 4, + "batch_size": 5, + "num_workers": 0 if ddp else 2, + }, + **updates, + }, + save_checkpoint_to=checkpoint_path, + save_predictions_to=training_predictions_filename, + world_size=1 if ddp else 0, + ) + assert "train/energy_mae" in acc.Tags()["scalars"] + assert "val/energy_mae" in acc.Tags()["scalars"] + + # now load a checkpoint with an added stress head + # second load the checkpoint and predict + predictions_rundir = Path(tmpdir) / "predict" + predictions_rundir.mkdir() + predictions_filename = 
str(predictions_rundir / "predictions.npz") + _run_main( + rundir=str(predictions_rundir), + input_yaml=input_yaml, + update_dict_with={ + "task": {"strict_load": False}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, + **updates, + }, + update_run_args_with={ + "mode": "predict", + "checkpoint": checkpoint_path, + }, + save_predictions_to=predictions_filename, + ) + predictions = np.load(training_predictions_filename) + + for output in updates["outputs"]: + assert output in predictions + + assert predictions["energy"].shape == (5, 1) + assert predictions["forces"].shape == (10, 3) + assert predictions["stress"].shape == (5, 9) diff --git a/tests/core/models/__snapshots__/test_equiformer_v2.ambr b/tests/core/models/__snapshots__/test_equiformer_v2.ambr index 5ddf7f2be..03be8ebda 100644 --- a/tests/core/models/__snapshots__/test_equiformer_v2.ambr +++ b/tests/core/models/__snapshots__/test_equiformer_v2.ambr @@ -56,7 +56,7 @@ # --- # name: TestEquiformerV2.test_gp.1 Approx( - array([-0.03269595], dtype=float32), + array([0.12408739], dtype=float32), rtol=0.001, atol=0.001 ) @@ -69,7 +69,7 @@ # --- # name: TestEquiformerV2.test_gp.3 Approx( - array([ 0.00208857, -0.00017979, -0.0028318 ], dtype=float32), + array([ 1.4928661e-03, -7.4134863e-05, 2.9909245e-03], dtype=float32), rtol=0.001, atol=0.001 ) diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml index 0f72570fd..1852799f5 100644 --- a/tests/core/models/test_configs/test_equiformerv2_hydra.yml +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -29,7 +29,7 @@ evaluation_metrics: misc: - energy_forces_within_threshold primary_metric: forces_mae - + logger: name: tensorboard diff --git a/tests/core/models/test_rank2_head.py b/tests/core/models/test_rank2_head.py new file mode 100644 index 000000000..c00667806 --- /dev/null +++ b/tests/core/models/test_rank2_head.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from itertools import product + +import pytest +import torch +from ase.build import bulk + +from fairchem.core.common.utils import cg_change_mat, irreps_sum +from fairchem.core.datasets import data_list_collater +from fairchem.core.models.equiformer_v2.equiformer_v2 import EquiformerV2Backbone +from fairchem.core.models.equiformer_v2.prediction_heads import Rank2SymmetricTensorHead +from fairchem.core.preprocessing import AtomsToGraphs + + +def _reshape_tensor(out, batch_size=1): + tensor = torch.zeros((batch_size, irreps_sum(2)), requires_grad=False) + tensor[:, max(0, irreps_sum(1)) : irreps_sum(2)] = out.view(batch_size, -1) + tensor = torch.einsum("ba, cb->ca", cg_change_mat(2), tensor) + return tensor.view(3, 3) + + +@pytest.fixture(scope="session") +def batch(): + a2g = AtomsToGraphs(r_pbc=True) + return data_list_collater([a2g.convert(bulk("ZnFe", "wurtzite", a=2.0))]) + + +@pytest.mark.parametrize( + ("decompose", "edge_level_mlp", "use_source_target_embedding", "extensive"), + list(product((True, False), repeat=4)), +) +def test_rank2_head( + batch, decompose, edge_level_mlp, use_source_target_embedding, extensive +): + torch.manual_seed(100) # fix network initialization + backbone = EquiformerV2Backbone( + num_layers=2, + sphere_channels=8, + attn_hidden_channels=8, + num_sphere_samples=8, + edge_channels=8, + ) + head = Rank2SymmetricTensorHead( + backbone=backbone, + output_name="out", + decompose=decompose, + edge_level_mlp=edge_level_mlp, + use_source_target_embedding=use_source_target_embedding, 
+ extensive=extensive, + ) + + r2_out = head(batch, backbone(batch)) + + if decompose is True: + assert "out_isotropic" in r2_out + assert "out_anisotropic" in r2_out + # isotropic must be scalar + assert r2_out["out_isotropic"].shape[1] == 1 + tensor = _reshape_tensor(r2_out["out_isotropic"]) + # anisotropic must be traceless + assert torch.diagonal(tensor).sum().item() == pytest.approx(0.0, abs=2e-8) + else: + assert "out" in r2_out + tensor = r2_out["out"].view(3, 3) + + # all tensors must be symmetric + assert torch.allclose(tensor, tensor.transpose(0, 1)) diff --git a/tests/core/modules/conftest.py b/tests/core/modules/conftest.py index 1b1e4ab7e..0a210639d 100644 --- a/tests/core/modules/conftest.py +++ b/tests/core/modules/conftest.py @@ -1,48 +1,8 @@ -from itertools import product -from random import choice -import pytest -import numpy as np -from pymatgen.core.periodic_table import Element -from pymatgen.core import Structure - -from fairchem.core.datasets import LMDBDatabase, AseDBDataset - +from __future__ import annotations -@pytest.fixture(scope="session") -def dummy_element_refs(): - # create some dummy elemental energies from ionic radii (ignore deuterium and tritium included in pmg) - return np.concatenate( - [[0], [e.average_ionic_radius for e in Element if e.name not in ("D", "T")]] - ) +import pytest @pytest.fixture(scope="session") def max_num_elements(dummy_element_refs): return len(dummy_element_refs) - 1 - - -@pytest.fixture(scope="session") -def dummy_binary_dataset(tmpdir_factory, dummy_element_refs): - # a dummy dataset with binaries with energy that depends on composition only plus noise - all_binaries = list(product(list(Element), repeat=2)) - rng = np.random.default_rng(seed=0) - - tmpdir = tmpdir_factory.mktemp("dataset") - with LMDBDatabase(tmpdir / "dummy.aselmdb") as db: - for _ in range(1000): - elements = choice(all_binaries) - structure = Structure.from_prototype("cscl", species=elements, a=2.0) - energy = ( - sum(e.average_ionic_radius for e in elements) - + 0.05 * rng.random() * dummy_element_refs.mean() - ) - atoms = structure.to_ase_atoms() - db.write(atoms, data={"energy": energy, "forces": rng.random((2, 3))}) - - dataset = AseDBDataset( - config={ - "src": str(tmpdir / "dummy.aselmdb"), - "a2g_args": {"r_data_keys": ["energy", "forces"]}, - } - ) - return dataset
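
---

Reviewer note (not part of the patch): the `base.py` hunk above adds a `use_amp` property to `HeadInterface` (default `False`) and has `HydraModel.forward` wrap each head call in `torch.autocast(device_type=..., enabled=head.use_amp)`, which is what lets the new rank-2 stress head keep its layer norms in full precision while the rest of the model runs under AMP. Below is a minimal, self-contained sketch of that wiring for readers who want to try the pattern outside fairchem. The `DummyBackbone`, `FP32Head`, and `AmpHead` classes and `hydra_style_forward` helper are hypothetical stand-ins for illustration only; only the `use_amp` property and the autocast-per-head loop mirror what the diff actually adds.

```python
from __future__ import annotations

import torch
from torch import nn


class DummyBackbone(nn.Module):
    # stand-in for an embedding-producing backbone
    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        return {"node_embedding": x}


class FP32Head(nn.Module):
    # mirrors the HeadInterface default: autocast disabled, so numerically
    # sensitive ops (e.g. the rank-2 head's layer norms) stay in full precision
    # even when the trainer enables AMP
    @property
    def use_amp(self) -> bool:
        return False

    def forward(self, emb: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {"energy": emb["node_embedding"].sum(dim=-1)}


class AmpHead(FP32Head):
    # a head that opts back in to mixed precision
    @property
    def use_amp(self) -> bool:
        return True


def hydra_style_forward(backbone, heads, x, device_type: str = "cpu"):
    # same per-head autocast pattern as HydraModel.forward in the patch
    emb = backbone(x)
    out = {}
    for name, head in heads.items():
        with torch.autocast(device_type=device_type, enabled=head.use_amp):
            out[name] = head(emb)
    return out


if __name__ == "__main__":
    backbone = DummyBackbone()
    heads = {"energy_fp32": FP32Head(), "energy_amp": AmpHead()}
    preds = hydra_style_forward(backbone, heads, torch.randn(4, 8))
    print({k: v["energy"].dtype for k, v in preds.items()})
```

The same opt-in/opt-out idea is why `Rank2SymmetricTensorHead` normalizes `sphharm_weights_edge` and `x_edge` before the einsum: those values can overflow under autocast, so the head keeps `use_amp` at its `False` default and relies on the layer norms for stability.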