Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 14, 2024
1 parent c24cc41 commit 89f8587
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
graph.edge_index,
data_batch, # for GraphDropPath
graph.node_offset,
use_reentrant=False if self.training else True
use_reentrant=not self.training,
)
else:
x = self.blocks[i](
Expand Down Expand Up @@ -908,7 +908,7 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
emb["graph"].edge_distance,
emb["graph"].edge_index,
emb["graph"].node_offset,
use_reentrant=False if self.training else True,
use_reentrant=not self.training,
)
else:
forces = self.force_block(
Expand Down
11 changes: 5 additions & 6 deletions tests/core/models/test_equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

import pytest
import requests
import yaml
from fairchem.core.models.base import HydraModel
import torch
import yaml
from ase.io import read
from torch.nn.parallel.distributed import DistributedDataParallel

Expand Down Expand Up @@ -239,7 +238,7 @@ def _load_hydra_model():
torch.manual_seed(4)
with open(Path("tests/core/models/test_configs/test_equiformerv2_hydra.yml")) as yaml_file:
yaml_config = yaml.safe_load(yaml_file)
model = registry.get_model_class("hydra")(yaml_config['model']['backbone'],yaml_config['model']['heads'])
model = registry.get_model_class("hydra")(yaml_config["model"]["backbone"],yaml_config["model"]["heads"])
model.backbone.num_layers = 1
return model

Expand All @@ -264,15 +263,15 @@ def test_eqv2_hydra():
start_rng_state = torch.random.get_rng_state()
outputs_no_ac = no_ac_model(inputs)
print(outputs_no_ac)
torch.autograd.backward(outputs_no_ac['energy'].sum() + outputs_no_ac['forces'].sum())
torch.autograd.backward(outputs_no_ac["energy"].sum() + outputs_no_ac["forces"].sum())

# reset the rng state to the beginning
torch.random.set_rng_state(start_rng_state)
    outputs_ac = ac_model(inputs)
    print(outputs_ac)
    torch.autograd.backward(outputs_ac["energy"].sum() + outputs_ac["forces"].sum())

ac_model_grad_dict = {name:p.grad for name, p in ac_model.named_parameters() if p.grad is not None}
no_ac_model_grad_dict = {name:p.grad for name, p in no_ac_model.named_parameters() if p.grad is not None}
for name in no_ac_model_grad_dict:
assert torch.allclose(no_ac_model_grad_dict[name], ac_model_grad_dict[name], atol=1e-4)
assert torch.allclose(no_ac_model_grad_dict[name], ac_model_grad_dict[name], atol=1e-4)

0 comments on commit 89f8587

Please sign in to comment.