Rename hydra heads (#903)
* refactor eqV2 heads

* refactor init_weights

* add output_name attribute to eqV2 heads

* fix deprecated registry names

* remove debug breakpoint

* add default name for rank2 head
lbluque authored Nov 8, 2024
1 parent 6329e92 commit 834fbd6
Showing 8 changed files with 231 additions and 138 deletions.
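As the commit message notes, the hydra heads are renamed while the old registry entries are kept as deprecated aliases. A minimal sketch of what that means at the registry level (assuming fairchem is installed and that the registry exposes get_model_class; only the head names below come from this commit):

from fairchem.core.common.registry import registry

# New names registered by this commit.
scalar_cls = registry.get_model_class("equiformerV2_scalar_head")  # EqV2ScalarHead
vector_cls = registry.get_model_class("equiformerV2_vector_head")  # EqV2VectorHead

# The old names still resolve, but to thin shim classes that log a
# deprecation warning when constructed and then delegate to the new heads.
legacy_energy_cls = registry.get_model_class("equiformer_v2_energy_head")
legacy_force_cls = registry.get_model_class("equiformer_v2_force_head")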
164 changes: 34 additions & 130 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -1,28 +1,30 @@
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

import contextlib
import logging
import math
import typing
from functools import partial

import torch
import torch.nn as nn
from typing_extensions import deprecated

from fairchem.core.common import gp_utils
from fairchem.core.common.registry import registry
from fairchem.core.common.utils import conditional_grad
from fairchem.core.models.base import (
GraphModelMixin,
HeadInterface,
)
from fairchem.core.models.equiformer_v2.heads import EqV2ScalarHead, EqV2VectorHead
from fairchem.core.models.scn.smearing import GaussianSmearing

with contextlib.suppress(ImportError):
pass


import typing

from .edge_rot_mat import init_edge_rot_mat
from .gaussian_rbf import GaussianRadialBasisLayer
from .input_block import EdgeDegreeEmbedding
@@ -34,7 +36,6 @@
get_normalization_layer,
)
from .module_list import ModuleListInfo
from .radial_function import RadialFunction
from .so3 import (
CoefficientMappingModule,
SO3_Embedding,
@@ -43,41 +44,43 @@
SO3_Rotation,
)
from .transformer_block import (
FeedForwardNetwork,
SO2EquivariantGraphAttention,
TransBlockV2,
)
from .weight_initialization import eqv2_init_weights

with contextlib.suppress(ImportError):
pass

if typing.TYPE_CHECKING:
from torch_geometric.data.batch import Batch

from fairchem.core.models.base import GraphData

# Statistics of IS2RE 100K
_AVG_NUM_NODES = 77.81317
_AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100


def eqv2_init_weights(m, weight_init):
if isinstance(m, (torch.nn.Linear, SO3_LinearV2)):
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
if weight_init == "normal":
std = 1 / math.sqrt(m.in_features)
torch.nn.init.normal_(m.weight, 0, std)
elif isinstance(m, torch.nn.LayerNorm):
torch.nn.init.constant_(m.bias, 0)
torch.nn.init.constant_(m.weight, 1.0)
elif isinstance(m, RadialFunction):
m.apply(eqv2_uniform_init_linear_weights)
@deprecated(
"equiformer_v2_force_head (EquiformerV2ForceHead) class is deprecated in favor of equiformerV2_rank1_head (EqV2Rank1Head)"
)
@registry.register_model("equiformer_v2_force_head")
class EquiformerV2ForceHead(EqV2VectorHead):
def __init__(self, backbone):
logging.warning(
"equiformerV2_force_head (EquiformerV2ForceHead) class is deprecated in favor of equiformerV2_rank1_head (EqV2Rank1Head)"
)
super().__init__(backbone)


def eqv2_uniform_init_linear_weights(m):
if isinstance(m, torch.nn.Linear):
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
std = 1 / math.sqrt(m.in_features)
torch.nn.init.uniform_(m.weight, -std, std)
@deprecated(
"equiformer_v2_energy_head (EquiformerV2EnergyHead) class is deprecated in favor of equiformerV2_scalar_head (EqV2ScalarHead)"
)
@registry.register_model("equiformer_v2_energy_head")
class EquiformerV2EnergyHead(EqV2ScalarHead):
def __init__(self, backbone, reduce: str = "sum"):
logging.warning(
"equiformerV2_energy_head (EquiformerV2EnergyHead) class is deprecated in favor of equiformerV2_scalar_head (EqV2ScalarHead)"
)
super().__init__(backbone, reduce=reduce)


@registry.register_model("equiformer_v2_backbone")
@@ -606,102 +609,3 @@ def no_weight_decay(self) -> set:
no_wd_list.append(global_parameter_name)

return set(no_wd_list)


@registry.register_model("equiformer_v2_energy_head")
class EquiformerV2EnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone, reduce: str = "sum"):
super().__init__()
self.reduce = reduce
self.avg_num_nodes = backbone.avg_num_nodes
self.energy_block = FeedForwardNetwork(
backbone.sphere_channels,
backbone.ffn_hidden_channels,
1,
backbone.lmax_list,
backbone.mmax_list,
backbone.SO3_grid,
backbone.ffn_activation,
backbone.use_gate_act,
backbone.use_grid_mlp,
backbone.use_sep_s2_act,
)
self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))

def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
node_energy = self.energy_block(emb["node_embedding"])
node_energy = node_energy.embedding.narrow(1, 0, 1)
if gp_utils.initialized():
node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0)
energy = torch.zeros(
len(data.natoms),
device=node_energy.device,
dtype=node_energy.dtype,
)

energy.index_add_(0, data.batch, node_energy.view(-1))
if self.reduce == "sum":
return {"energy": energy / self.avg_num_nodes}
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(
f"reduce can only be sum or mean, user provided: {self.reduce}"
)


@registry.register_model("equiformer_v2_force_head")
class EquiformerV2ForceHead(nn.Module, HeadInterface):
def __init__(self, backbone):
super().__init__()

self.activation_checkpoint = backbone.activation_checkpoint
self.force_block = SO2EquivariantGraphAttention(
backbone.sphere_channels,
backbone.attn_hidden_channels,
backbone.num_heads,
backbone.attn_alpha_channels,
backbone.attn_value_channels,
1,
backbone.lmax_list,
backbone.mmax_list,
backbone.SO3_rotation,
backbone.mappingReduced,
backbone.SO3_grid,
backbone.max_num_elements,
backbone.edge_channels_list,
backbone.block_use_atom_edge_embedding,
backbone.use_m_share_rad,
backbone.attn_activation,
backbone.use_s2_act_attn,
backbone.use_attn_renorm,
backbone.use_gate_act,
backbone.use_sep_s2_act,
alpha_drop=0.0,
)
self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))

def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
if self.activation_checkpoint:
forces = torch.utils.checkpoint.checkpoint(
self.force_block,
emb["node_embedding"],
emb["graph"].atomic_numbers_full,
emb["graph"].edge_distance,
emb["graph"].edge_index,
emb["graph"].node_offset,
use_reentrant=not self.training,
)
else:
forces = self.force_block(
emb["node_embedding"],
emb["graph"].atomic_numbers_full,
emb["graph"].edge_distance,
emb["graph"].edge_index,
node_offset=emb["graph"].node_offset,
)
forces = forces.embedding.narrow(1, 1, 3)
forces = forces.view(-1, 3).contiguous()
if gp_utils.initialized():
forces = gp_utils.gather_from_model_parallel_region(forces, dim=0)
return {"forces": forces}
7 changes: 7 additions & 0 deletions src/fairchem/core/models/equiformer_v2/heads/__init__.py
@@ -0,0 +1,7 @@
from __future__ import annotations

from .rank2 import Rank2SymmetricTensorHead
from .scalar import EqV2ScalarHead
from .vector import EqV2VectorHead

__all__ = ["EqV2ScalarHead", "EqV2VectorHead", "Rank2SymmetricTensorHead"]
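A short usage sketch of the new subpackage layout (paths exactly as added above; assumes fairchem is importable):

from fairchem.core.models.equiformer_v2.heads import (
    EqV2ScalarHead,
    EqV2VectorHead,
    Rank2SymmetricTensorHead,
)

The next hunk touches the rank-2 head module that backs the Rank2SymmetricTensorHead import above.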
@@ -16,8 +16,8 @@

from fairchem.core.common.registry import registry
from fairchem.core.models.base import BackboneInterface, HeadInterface
from fairchem.core.models.equiformer_v2.equiformer_v2 import eqv2_init_weights
from fairchem.core.models.equiformer_v2.layer_norm import get_normalization_layer
from fairchem.core.models.equiformer_v2.weight_initialization import eqv2_init_weights


class Rank2Block(nn.Module):
@@ -238,7 +238,7 @@ class Rank2SymmetricTensorHead(nn.Module, HeadInterface):
def __init__(
self,
backbone: BackboneInterface,
output_name: str,
output_name: str = "stress",
decompose: bool = False,
edge_level_mlp: bool = False,
num_mlp_layers: int = 2,
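Besides the relocated eqv2_init_weights import, the change shown for the rank-2 head is the new default output_name="stress". A hypothetical minimal head (not fairchem code) illustrating the output_name convention these heads now share, where the key of the returned dict is configurable per head:

import torch
from torch import nn


class NamedOutputHead(nn.Module):
    def __init__(self, output_name: str = "stress"):
        super().__init__()
        self.output_name = output_name

    def forward(self, per_structure_value: torch.Tensor) -> dict:
        # The dict key is whatever the head was configured to emit.
        return {self.output_name: per_structure_value}


value = torch.zeros(2)
NamedOutputHead()(value)                      # {"stress": tensor([0., 0.])}
NamedOutputHead(output_name="virial")(value)  # {"virial": tensor([0., 0.])}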
66 changes: 66 additions & 0 deletions src/fairchem/core/models/equiformer_v2/heads/scalar.py
@@ -0,0 +1,66 @@
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import torch
from torch import nn

from fairchem.core.common import gp_utils
from fairchem.core.common.registry import registry
from fairchem.core.models.base import GraphData, HeadInterface
from fairchem.core.models.equiformer_v2.transformer_block import FeedForwardNetwork
from fairchem.core.models.equiformer_v2.weight_initialization import eqv2_init_weights

if TYPE_CHECKING:
from torch_geometric.data import Batch


@registry.register_model("equiformerV2_scalar_head")
class EqV2ScalarHead(nn.Module, HeadInterface):
def __init__(self, backbone, output_name: str = "energy", reduce: str = "sum"):
super().__init__()
self.output_name = output_name
self.reduce = reduce
self.avg_num_nodes = backbone.avg_num_nodes
self.energy_block = FeedForwardNetwork(
backbone.sphere_channels,
backbone.ffn_hidden_channels,
1,
backbone.lmax_list,
backbone.mmax_list,
backbone.SO3_grid,
backbone.ffn_activation,
backbone.use_gate_act,
backbone.use_grid_mlp,
backbone.use_sep_s2_act,
)
self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))

def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
node_output = self.energy_block(emb["node_embedding"])
node_output = node_output.embedding.narrow(1, 0, 1)
if gp_utils.initialized():
node_output = gp_utils.gather_from_model_parallel_region(node_output, dim=0)
output = torch.zeros(
len(data.natoms),
device=node_output.device,
dtype=node_output.dtype,
)

output.index_add_(0, data.batch, node_output.view(-1))
if self.reduce == "sum":
return {self.output_name: output / self.avg_num_nodes}
elif self.reduce == "mean":
return {self.output_name: output / data.natoms}
else:
raise ValueError(
f"reduce can only be sum or mean, user provided: {self.reduce}"
)
84 changes: 84 additions & 0 deletions src/fairchem/core/models/equiformer_v2/heads/vector.py
@@ -0,0 +1,84 @@
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import torch
from torch import nn

from fairchem.core.common import gp_utils
from fairchem.core.common.registry import registry
from fairchem.core.models.base import HeadInterface
from fairchem.core.models.equiformer_v2.transformer_block import (
SO2EquivariantGraphAttention,
)
from fairchem.core.models.equiformer_v2.weight_initialization import eqv2_init_weights

if TYPE_CHECKING:
from torch_geometric.data import Batch

from fairchem.core.models.base import BackboneInterface


@registry.register_model("equiformerV2_vector_head")
class EqV2VectorHead(nn.Module, HeadInterface):
def __init__(self, backbone: BackboneInterface, output_name: str = "forces"):
super().__init__()
self.output_name = output_name
self.activation_checkpoint = backbone.activation_checkpoint
self.force_block = SO2EquivariantGraphAttention(
backbone.sphere_channels,
backbone.attn_hidden_channels,
backbone.num_heads,
backbone.attn_alpha_channels,
backbone.attn_value_channels,
1,
backbone.lmax_list,
backbone.mmax_list,
backbone.SO3_rotation,
backbone.mappingReduced,
backbone.SO3_grid,
backbone.max_num_elements,
backbone.edge_channels_list,
backbone.block_use_atom_edge_embedding,
backbone.use_m_share_rad,
backbone.attn_activation,
backbone.use_s2_act_attn,
backbone.use_attn_renorm,
backbone.use_gate_act,
backbone.use_sep_s2_act,
alpha_drop=0.0,
)
self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))

def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
if self.activation_checkpoint:
output = torch.utils.checkpoint.checkpoint(
self.force_block,
emb["node_embedding"],
emb["graph"].atomic_numbers_full,
emb["graph"].edge_distance,
emb["graph"].edge_index,
emb["graph"].node_offset,
use_reentrant=not self.training,
)
else:
output = self.force_block(
emb["node_embedding"],
emb["graph"].atomic_numbers_full,
emb["graph"].edge_distance,
emb["graph"].edge_index,
node_offset=emb["graph"].node_offset,
)
output = output.embedding.narrow(1, 1, 3)
output = output.view(-1, 3).contiguous()
if gp_utils.initialized():
output = gp_utils.gather_from_model_parallel_region(output, dim=0)
return {self.output_name: output}
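The slicing at the end of forward can be illustrated with a toy tensor shaped like the SO(3) embedding returned by the attention block, (num_nodes, (lmax + 1)**2, channels). Interpreting indices 1..3 along dim 1 as the l=1 coefficients, i.e. a per-node Cartesian 3-vector, is an assumption about the coefficient layout, not something this diff states:

import torch

num_nodes, lmax, channels = 4, 2, 1
embedding = torch.randn(num_nodes, (lmax + 1) ** 2, channels)

vec = embedding.narrow(1, 1, 3)      # keep coefficients 1..3 -> (4, 3, 1)
vec = vec.view(-1, 3).contiguous()   # drop the channel dim -> (4, 3), one vector per node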