Rank2 tensor head (#792)
* rank 2 tensor head

* fix rank2 head and add to e2e test

* small fixes

* keep hydra graphmixin

* add rank2 head tests

* test fixes

* fix tests; move init_weight out of equiformer; add amp property to heads+hydra

* add amp to heads and hydra

* fix import

* update snapshot; fix test seed and change tolerance

---------

Co-authored-by: Misko <[email protected]>
lbluque and misko authored Aug 17, 2024
1 parent 2078e48 commit e380c66
Showing 15 changed files with 693 additions and 121 deletions.
15 changes: 14 additions & 1 deletion src/fairchem/core/models/base.py
@@ -189,6 +189,10 @@ def no_weight_decay(self) -> list:


class HeadInterface(metaclass=ABCMeta):
    @property
    def use_amp(self):
        return False

    @abstractmethod
    def forward(
        self, data: Batch, emb: dict[str, torch.Tensor]
@@ -249,6 +253,7 @@ def __init__(
    ):
        super().__init__()
        self.otf_graph = otf_graph
        self.device = "cpu"
        # make a copy so we don't modify the original config
        backbone = copy.deepcopy(backbone)
        heads = copy.deepcopy(heads)
@@ -279,12 +284,20 @@ def __init__(

        self.output_heads = torch.nn.ModuleDict(self.output_heads)

    def to(self, *args, **kwargs):
        if "device" in kwargs:
            self.device = kwargs["device"]
        return super().to(*args, **kwargs)

    def forward(self, data: Batch):
        emb = self.backbone(data)
        # Predict all output properties for all structures in the batch for now.
        out = {}
        for k in self.output_heads:
            out.update(self.output_heads[k](data, emb))
            with torch.autocast(
                device_type=self.device, enabled=self.output_heads[k].use_amp
            ):
                out.update(self.output_heads[k](data, emb))

        return out

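For context, the new HeadInterface.use_amp property together with the autocast wrapper in HydraModel.forward means a head opts into mixed precision simply by overriding the property. Below is a minimal sketch (not part of this commit) of a custom head using that hook; the class name, channel width, and the "node_embedding" key are illustrative assumptions.

import torch
from torch_geometric.data import Batch

from fairchem.core.models.base import HeadInterface


class ExampleAmpHead(torch.nn.Module, HeadInterface):
    """Hypothetical head that asks HydraModel to run it under torch.autocast."""

    def __init__(self, channels: int = 128):
        super().__init__()
        self.out = torch.nn.Linear(channels, 1)

    @property
    def use_amp(self):
        # HydraModel.forward reads this flag when building the autocast context.
        return True

    def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
        # Assumes the backbone exposed a per-node feature tensor under this key.
        return {"example_scalar": self.out(emb["node_embedding"])}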
57 changes: 26 additions & 31 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -3,6 +3,7 @@
import contextlib
import logging
import math
from functools import partial

import torch
import torch.nn as nn
@@ -54,6 +55,28 @@
_AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100


def eqv2_init_weights(m, weight_init):
    if isinstance(m, (torch.nn.Linear, SO3_LinearV2)):
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
        if weight_init == "normal":
            std = 1 / math.sqrt(m.in_features)
            torch.nn.init.normal_(m.weight, 0, std)
    elif isinstance(m, torch.nn.LayerNorm):
        torch.nn.init.constant_(m.bias, 0)
        torch.nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, RadialFunction):
        m.apply(eqv2_uniform_init_linear_weights)


def eqv2_uniform_init_linear_weights(m):
    if isinstance(m, torch.nn.Linear):
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
        std = 1 / math.sqrt(m.in_features)
        torch.nn.init.uniform_(m.weight, -std, std)


@registry.register_model("equiformer_v2")
class EquiformerV2(nn.Module, GraphModelMixin):
"""
@@ -400,8 +423,7 @@ def __init__(
                requires_grad=False,
            )

        self.apply(self._init_weights)
        self.apply(self._uniform_init_rad_func_linear_weights)
        self.apply(partial(eqv2_init_weights, weight_init=self.weight_init))

    def _init_gp_partitions(
        self,
@@ -630,31 +652,6 @@ def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
    def num_params(self):
        return sum(p.numel() for p in self.parameters())

    def _init_weights(self, m):
        if isinstance(m, (torch.nn.Linear, SO3_LinearV2)):
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
            if self.weight_init == "normal":
                std = 1 / math.sqrt(m.in_features)
                torch.nn.init.normal_(m.weight, 0, std)
            elif self.weight_init == "uniform":
                self._uniform_init_linear_weights(m)

        elif isinstance(m, torch.nn.LayerNorm):
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1.0)

    def _uniform_init_rad_func_linear_weights(self, m):
        if isinstance(m, RadialFunction):
            m.apply(self._uniform_init_linear_weights)

    def _uniform_init_linear_weights(self, m):
        if isinstance(m, (torch.nn.Linear, SO3_LinearV2)):
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
            std = 1 / math.sqrt(m.in_features)
            torch.nn.init.uniform_(m.weight, -std, std)

    @torch.jit.ignore
    def no_weight_decay(self) -> set:
        no_wd_list = []
@@ -852,8 +849,7 @@ def __init__(self, backbone):
            backbone.use_grid_mlp,
            backbone.use_sep_s2_act,
        )
        self.apply(backbone._init_weights)
        self.apply(backbone._uniform_init_rad_func_linear_weights)
        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))

    def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
        node_energy = self.energy_block(emb["node_embedding"])
@@ -898,8 +894,7 @@ def __init__(self, backbone):
            backbone.use_sep_s2_act,
            alpha_drop=0.0,
        )
        self.apply(backbone._init_weights)
        self.apply(backbone._uniform_init_rad_func_linear_weights)
        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))

    def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
        if self.activation_checkpoint:
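Because the initialization logic now lives in module-level functions instead of EquiformerV2 methods, other modules (including the new heads) can reuse it via functools.partial, mirroring the self.apply(...) calls above. A small sketch under the assumption that fairchem is installed; the layer sizes are arbitrary and only the "normal" branch shown in this diff is exercised.

from functools import partial

import torch

from fairchem.core.models.equiformer_v2.equiformer_v2 import eqv2_init_weights

# Any nn.Module tree works; Linear and LayerNorm submodules are re-initialized,
# everything else is left untouched.
mlp = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.LayerNorm(64),
    torch.nn.Linear(64, 1),
)
mlp.apply(partial(eqv2_init_weights, weight_init="normal"))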
@@ -0,0 +1,5 @@
from __future__ import annotations

from .rank2 import Rank2SymmetricTensorHead

__all__ = ["Rank2SymmetricTensorHead"]
