diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
index 98e21a77f..b20f7e250 100644
--- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
+++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -1,28 +1,30 @@
+"""
+Copyright (c) Meta, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
 from __future__ import annotations
 
 import contextlib
 import logging
-import math
+import typing
 from functools import partial
 
 import torch
 import torch.nn as nn
+from typing_extensions import deprecated
 
 from fairchem.core.common import gp_utils
 from fairchem.core.common.registry import registry
 from fairchem.core.common.utils import conditional_grad
 from fairchem.core.models.base import (
     GraphModelMixin,
-    HeadInterface,
 )
+from fairchem.core.models.equiformer_v2.heads import EqV2ScalarHead, EqV2VectorHead
 from fairchem.core.models.scn.smearing import GaussianSmearing
 
-with contextlib.suppress(ImportError):
-    pass
-
-
-import typing
-
 from .edge_rot_mat import init_edge_rot_mat
 from .gaussian_rbf import GaussianRadialBasisLayer
 from .input_block import EdgeDegreeEmbedding
@@ -34,7 +36,6 @@
     get_normalization_layer,
 )
 from .module_list import ModuleListInfo
-from .radial_function import RadialFunction
 from .so3 import (
     CoefficientMappingModule,
     SO3_Embedding,
@@ -43,41 +44,43 @@
     SO3_Rotation,
 )
 from .transformer_block import (
-    FeedForwardNetwork,
-    SO2EquivariantGraphAttention,
     TransBlockV2,
 )
+from .weight_initialization import eqv2_init_weights
+
+with contextlib.suppress(ImportError):
+    pass
 
 if typing.TYPE_CHECKING:
     from torch_geometric.data.batch import Batch
 
-    from fairchem.core.models.base import GraphData
-
 # Statistics of IS2RE 100K
 _AVG_NUM_NODES = 77.81317
 _AVG_DEGREE = 23.395238876342773  # IS2RE: 100k, max_radius = 5, max_neighbors = 100
 
 
-def eqv2_init_weights(m, weight_init):
-    if isinstance(m, (torch.nn.Linear, SO3_LinearV2)):
-        if m.bias is not None:
-            torch.nn.init.constant_(m.bias, 0)
-        if weight_init == "normal":
-            std = 1 / math.sqrt(m.in_features)
-            torch.nn.init.normal_(m.weight, 0, std)
-    elif isinstance(m, torch.nn.LayerNorm):
-        torch.nn.init.constant_(m.bias, 0)
-        torch.nn.init.constant_(m.weight, 1.0)
-    elif isinstance(m, RadialFunction):
-        m.apply(eqv2_uniform_init_linear_weights)
+@deprecated(
+    "equiformer_v2_force_head (EquiformerV2ForceHead) class is deprecated in favor of equiformerV2_vector_head (EqV2VectorHead)"
+)
+@registry.register_model("equiformer_v2_force_head")
+class EquiformerV2ForceHead(EqV2VectorHead):
+    def __init__(self, backbone):
+        logging.warning(
+            "equiformer_v2_force_head (EquiformerV2ForceHead) class is deprecated in favor of equiformerV2_vector_head (EqV2VectorHead)"
+        )
+        super().__init__(backbone)
 
 
-def eqv2_uniform_init_linear_weights(m):
-    if isinstance(m, torch.nn.Linear):
-        if m.bias is not None:
-            torch.nn.init.constant_(m.bias, 0)
-        std = 1 / math.sqrt(m.in_features)
-        torch.nn.init.uniform_(m.weight, -std, std)
+@deprecated(
+    "equiformer_v2_energy_head (EquiformerV2EnergyHead) class is deprecated in favor of equiformerV2_scalar_head (EqV2ScalarHead)"
+)
+@registry.register_model("equiformer_v2_energy_head")
+class EquiformerV2EnergyHead(EqV2ScalarHead):
+    def __init__(self, backbone, reduce: str = "sum"):
+        logging.warning(
+            "equiformer_v2_energy_head (EquiformerV2EnergyHead) class is deprecated in favor of equiformerV2_scalar_head (EqV2ScalarHead)"
+        )
+        super().__init__(backbone, reduce=reduce)
 
 
 @registry.register_model("equiformer_v2_backbone")
@@ -606,102 +609,3 @@ def no_weight_decay(self) -> set:
                     no_wd_list.append(global_parameter_name)
 
         return set(no_wd_list)
-
-
-@registry.register_model("equiformer_v2_energy_head")
-class EquiformerV2EnergyHead(nn.Module, HeadInterface):
-    def __init__(self, backbone, reduce: str = "sum"):
-        super().__init__()
-        self.reduce = reduce
-        self.avg_num_nodes = backbone.avg_num_nodes
-        self.energy_block = FeedForwardNetwork(
-            backbone.sphere_channels,
-            backbone.ffn_hidden_channels,
-            1,
-            backbone.lmax_list,
-            backbone.mmax_list,
-            backbone.SO3_grid,
-            backbone.ffn_activation,
-            backbone.use_gate_act,
-            backbone.use_grid_mlp,
-            backbone.use_sep_s2_act,
-        )
-        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))
-
-    def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
-        node_energy = self.energy_block(emb["node_embedding"])
-        node_energy = node_energy.embedding.narrow(1, 0, 1)
-        if gp_utils.initialized():
-            node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0)
-        energy = torch.zeros(
-            len(data.natoms),
-            device=node_energy.device,
-            dtype=node_energy.dtype,
-        )
-
-        energy.index_add_(0, data.batch, node_energy.view(-1))
-        if self.reduce == "sum":
-            return {"energy": energy / self.avg_num_nodes}
-        elif self.reduce == "mean":
-            return {"energy": energy / data.natoms}
-        else:
-            raise ValueError(
-                f"reduce can only be sum or mean, user provided: {self.reduce}"
-            )
-
-
-@registry.register_model("equiformer_v2_force_head")
-class EquiformerV2ForceHead(nn.Module, HeadInterface):
-    def __init__(self, backbone):
-        super().__init__()
-
-        self.activation_checkpoint = backbone.activation_checkpoint
-        self.force_block = SO2EquivariantGraphAttention(
-            backbone.sphere_channels,
-            backbone.attn_hidden_channels,
-            backbone.num_heads,
-            backbone.attn_alpha_channels,
-            backbone.attn_value_channels,
-            1,
-            backbone.lmax_list,
-            backbone.mmax_list,
-            backbone.SO3_rotation,
-            backbone.mappingReduced,
-            backbone.SO3_grid,
-            backbone.max_num_elements,
-            backbone.edge_channels_list,
-            backbone.block_use_atom_edge_embedding,
-            backbone.use_m_share_rad,
-            backbone.attn_activation,
-            backbone.use_s2_act_attn,
-            backbone.use_attn_renorm,
-            backbone.use_gate_act,
-            backbone.use_sep_s2_act,
-            alpha_drop=0.0,
-        )
-        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))
-
-    def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
-        if self.activation_checkpoint:
-            forces = torch.utils.checkpoint.checkpoint(
-                self.force_block,
-                emb["node_embedding"],
-                emb["graph"].atomic_numbers_full,
-                emb["graph"].edge_distance,
-                emb["graph"].edge_index,
-                emb["graph"].node_offset,
-                use_reentrant=not self.training,
-            )
-        else:
-            forces = self.force_block(
-                emb["node_embedding"],
-                emb["graph"].atomic_numbers_full,
-                emb["graph"].edge_distance,
-                emb["graph"].edge_index,
-                node_offset=emb["graph"].node_offset,
-            )
-        forces = forces.embedding.narrow(1, 1, 3)
-        forces = forces.view(-1, 3).contiguous()
-        if gp_utils.initialized():
-            forces = gp_utils.gather_from_model_parallel_region(forces, dim=0)
-        return {"forces": forces}
diff --git a/src/fairchem/core/models/equiformer_v2/heads/__init__.py b/src/fairchem/core/models/equiformer_v2/heads/__init__.py
new file mode 100644
index 000000000..100fcea1a
--- /dev/null
+++ b/src/fairchem/core/models/equiformer_v2/heads/__init__.py
@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+from .rank2 import Rank2SymmetricTensorHead
+from .scalar import EqV2ScalarHead
+from .vector import EqV2VectorHead
+
+__all__ = ["EqV2ScalarHead", "EqV2VectorHead", "Rank2SymmetricTensorHead"]
diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/heads/rank2.py
similarity index 99%
rename from src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py
rename to src/fairchem/core/models/equiformer_v2/heads/rank2.py
index 2bbf42eaa..6a716cc7d 100644
--- a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py
+++ b/src/fairchem/core/models/equiformer_v2/heads/rank2.py
@@ -16,8 +16,8 @@
 from fairchem.core.common.registry import registry
 from fairchem.core.models.base import BackboneInterface, HeadInterface
-from fairchem.core.models.equiformer_v2.equiformer_v2 import eqv2_init_weights
 from fairchem.core.models.equiformer_v2.layer_norm import get_normalization_layer
+from fairchem.core.models.equiformer_v2.weight_initialization import eqv2_init_weights
 
 
 class Rank2Block(nn.Module):
@@ -238,7 +238,7 @@ class Rank2SymmetricTensorHead(nn.Module, HeadInterface):
     def __init__(
         self,
         backbone: BackboneInterface,
-        output_name: str,
+        output_name: str = "stress",
         decompose: bool = False,
         edge_level_mlp: bool = False,
         num_mlp_layers: int = 2,
diff --git a/src/fairchem/core/models/equiformer_v2/heads/scalar.py b/src/fairchem/core/models/equiformer_v2/heads/scalar.py
new file mode 100644
index 000000000..b3936ba15
--- /dev/null
+++ b/src/fairchem/core/models/equiformer_v2/heads/scalar.py
@@ -0,0 +1,66 @@
+"""
+Copyright (c) Meta, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import annotations
+
+from functools import partial
+from typing import TYPE_CHECKING
+
+import torch
+from torch import nn
+
+from fairchem.core.common import gp_utils
+from fairchem.core.common.registry import registry
+from fairchem.core.models.base import GraphData, HeadInterface
+from fairchem.core.models.equiformer_v2.transformer_block import FeedForwardNetwork
+from fairchem.core.models.equiformer_v2.weight_initialization import eqv2_init_weights
+
+if TYPE_CHECKING:
+    from torch_geometric.data import Batch
+
+
+@registry.register_model("equiformerV2_scalar_head")
+class EqV2ScalarHead(nn.Module, HeadInterface):
+    def __init__(self, backbone, output_name: str = "energy", reduce: str = "sum"):
+        super().__init__()
+        self.output_name = output_name
+        self.reduce = reduce
+        self.avg_num_nodes = backbone.avg_num_nodes
+        self.energy_block = FeedForwardNetwork(
+            backbone.sphere_channels,
+            backbone.ffn_hidden_channels,
+            1,
+            backbone.lmax_list,
+            backbone.mmax_list,
+            backbone.SO3_grid,
+            backbone.ffn_activation,
+            backbone.use_gate_act,
+            backbone.use_grid_mlp,
+            backbone.use_sep_s2_act,
+        )
+        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))
+
+    def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
+        node_output = self.energy_block(emb["node_embedding"])
+        node_output = node_output.embedding.narrow(1, 0, 1)
+        if gp_utils.initialized():
+            node_output = gp_utils.gather_from_model_parallel_region(node_output, dim=0)
+        output = torch.zeros(
+            len(data.natoms),
+            device=node_output.device,
+            dtype=node_output.dtype,
+        )
+
+        output.index_add_(0, data.batch, node_output.view(-1))
+        if self.reduce == "sum":
+            return {self.output_name: output / self.avg_num_nodes}
+        elif self.reduce == "mean":
+            return {self.output_name: output / data.natoms}
+        else:
+            raise ValueError(
+                f"reduce can only be sum or mean, user provided: {self.reduce}"
+            )
diff --git a/src/fairchem/core/models/equiformer_v2/heads/vector.py b/src/fairchem/core/models/equiformer_v2/heads/vector.py
new file mode 100644
index 000000000..49bc27bdc
--- /dev/null
+++ b/src/fairchem/core/models/equiformer_v2/heads/vector.py
@@ -0,0 +1,84 @@
+"""
+Copyright (c) Meta, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import annotations
+
+from functools import partial
+from typing import TYPE_CHECKING
+
+import torch
+from torch import nn
+
+from fairchem.core.common import gp_utils
+from fairchem.core.common.registry import registry
+from fairchem.core.models.base import HeadInterface
+from fairchem.core.models.equiformer_v2.transformer_block import (
+    SO2EquivariantGraphAttention,
+)
+from fairchem.core.models.equiformer_v2.weight_initialization import eqv2_init_weights
+
+if TYPE_CHECKING:
+    from torch_geometric.data import Batch
+
+    from fairchem.core.models.base import BackboneInterface
+
+
+@registry.register_model("equiformerV2_vector_head")
+class EqV2VectorHead(nn.Module, HeadInterface):
+    def __init__(self, backbone: BackboneInterface, output_name: str = "forces"):
+        super().__init__()
+        self.output_name = output_name
+        self.activation_checkpoint = backbone.activation_checkpoint
+        self.force_block = SO2EquivariantGraphAttention(
+            backbone.sphere_channels,
+            backbone.attn_hidden_channels,
+            backbone.num_heads,
+            backbone.attn_alpha_channels,
+            backbone.attn_value_channels,
+            1,
+            backbone.lmax_list,
+            backbone.mmax_list,
+            backbone.SO3_rotation,
+            backbone.mappingReduced,
+            backbone.SO3_grid,
+            backbone.max_num_elements,
+            backbone.edge_channels_list,
+            backbone.block_use_atom_edge_embedding,
+            backbone.use_m_share_rad,
+            backbone.attn_activation,
+            backbone.use_s2_act_attn,
+            backbone.use_attn_renorm,
+            backbone.use_gate_act,
+            backbone.use_sep_s2_act,
+            alpha_drop=0.0,
+        )
+        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))
+
+    def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
+        if self.activation_checkpoint:
+            output = torch.utils.checkpoint.checkpoint(
+                self.force_block,
+                emb["node_embedding"],
+                emb["graph"].atomic_numbers_full,
+                emb["graph"].edge_distance,
+                emb["graph"].edge_index,
+                emb["graph"].node_offset,
+                use_reentrant=not self.training,
+            )
+        else:
+            output = self.force_block(
+                emb["node_embedding"],
+                emb["graph"].atomic_numbers_full,
+                emb["graph"].edge_distance,
+                emb["graph"].edge_index,
+                node_offset=emb["graph"].node_offset,
+            )
+        output = output.embedding.narrow(1, 1, 3)
+        output = output.view(-1, 3).contiguous()
+        if gp_utils.initialized():
+            output = gp_utils.gather_from_model_parallel_region(output, dim=0)
+        return {self.output_name: output}
diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py
deleted file mode 100644
index 7542c0d13..000000000
--- a/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from __future__ import annotations
-
-from .rank2 import Rank2SymmetricTensorHead
-
-__all__ = ["Rank2SymmetricTensorHead"]
diff --git a/src/fairchem/core/models/equiformer_v2/weight_initialization.py b/src/fairchem/core/models/equiformer_v2/weight_initialization.py
new file mode 100644
index 000000000..42c282eab
--- /dev/null
+++ b/src/fairchem/core/models/equiformer_v2/weight_initialization.py
@@ -0,0 +1,37 @@
+"""
+Copyright (c) Meta, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import annotations
+
+import math
+
+import torch
+
+from fairchem.core.models.equiformer_v2.radial_function import RadialFunction
+from fairchem.core.models.equiformer_v2.so3 import SO3_LinearV2
+
+
+def eqv2_init_weights(m, weight_init):
+    if isinstance(m, (torch.nn.Linear, SO3_LinearV2)):
+        if m.bias is not None:
+            torch.nn.init.constant_(m.bias, 0)
+        if weight_init == "normal":
+            std = 1 / math.sqrt(m.in_features)
+            torch.nn.init.normal_(m.weight, 0, std)
+    elif isinstance(m, torch.nn.LayerNorm):
+        torch.nn.init.constant_(m.bias, 0)
+        torch.nn.init.constant_(m.weight, 1.0)
+    elif isinstance(m, RadialFunction):
+        m.apply(eqv2_uniform_init_linear_weights)
+
+
+def eqv2_uniform_init_linear_weights(m):
+    if isinstance(m, torch.nn.Linear):
+        if m.bias is not None:
+            torch.nn.init.constant_(m.bias, 0)
+        std = 1 / math.sqrt(m.in_features)
+        torch.nn.init.uniform_(m.weight, -std, std)
diff --git a/tests/core/models/test_rank2_head.py b/tests/core/models/test_rank2_head.py
index c00667806..5ff2166db 100644
--- a/tests/core/models/test_rank2_head.py
+++ b/tests/core/models/test_rank2_head.py
@@ -9,7 +9,7 @@
 from fairchem.core.common.utils import cg_change_mat, irreps_sum
 from fairchem.core.datasets import data_list_collater
 from fairchem.core.models.equiformer_v2.equiformer_v2 import EquiformerV2Backbone
-from fairchem.core.models.equiformer_v2.prediction_heads import Rank2SymmetricTensorHead
+from fairchem.core.models.equiformer_v2.heads import Rank2SymmetricTensorHead
 from fairchem.core.preprocessing import AtomsToGraphs
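
Reviewer note: the sketch below is not part of the diff; it shows how the renamed heads resolve through the model registry after this change. It assumes an already-constructed EquiformerV2Backbone bound to a variable named `backbone` (hypothetical here); in practice head wiring is driven by fairchem's trainer configs, and the keyword arguments shown are the defaults introduced in this PR.

from fairchem.core.common.registry import registry

# New registry names introduced in this PR:
scalar_cls = registry.get_model_class("equiformerV2_scalar_head")  # EqV2ScalarHead
vector_cls = registry.get_model_class("equiformerV2_vector_head")  # EqV2VectorHead

# `backbone` is assumed to be a constructed EquiformerV2Backbone instance.
energy_head = scalar_cls(backbone, output_name="energy", reduce="sum")
force_head = vector_cls(backbone, output_name="forces")

# The old names still resolve, but the classes are now thin deprecated
# subclasses of the new heads and log a warning on construction:
legacy_cls = registry.get_model_class("equiformer_v2_energy_head")  # EquiformerV2EnergyHead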