Skip to content

Commit

Permalink
use model generate graph
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Aug 15, 2024
1 parent c7debdb commit 00554c7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
9 changes: 4 additions & 5 deletions src/fairchem/core/common/relaxation/optimizers/optimizable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from torch_scatter import scatter

from fairchem.core.common.relaxation.ase_utils import batch_to_atoms
from fairchem.core.common.utils import radius_graph_pbc

# unreleased ASE has Optimizable, last released version 3.22.1 does not
# thankfully we can get away with backwards compatibility by creating a dummy
Expand Down Expand Up @@ -315,10 +314,10 @@ def get_atoms_list(self) -> list[Atoms]:

def update_graph(self):
"""Update the graph if model does not use otf_graph."""
edge_index, cell_offsets, num_neighbors = radius_graph_pbc(self.batch, 6, 50)
self.batch.edge_index = edge_index
self.batch.cell_offsets = cell_offsets
self.batch.neighbors = num_neighbors
graph = self.trainer.model.generate_graph(self.batch)
self.batch.edge_index = graph.edge_index
self.batch.cell_offsets = graph.cell_offsets
self.batch.neighbors = graph.neighbors
if self.transform is not None:
self.batch = self.transform(self.batch)

Expand Down
15 changes: 7 additions & 8 deletions src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,13 @@ def generate_graph(
use_pbc_single = use_pbc_single or self.use_pbc_single
otf_graph = otf_graph or self.otf_graph

if enforce_max_neighbors_strictly is not None:
pass
elif hasattr(self, "enforce_max_neighbors_strictly"):
# Not all models will have this attribute
enforce_max_neighbors_strictly = self.enforce_max_neighbors_strictly
else:
# Default to old behavior
enforce_max_neighbors_strictly = True
if enforce_max_neighbors_strictly is None:
if hasattr(self, "enforce_max_neighbors_strictly"):
# Not all models will have this attribute
enforce_max_neighbors_strictly = self.enforce_max_neighbors_strictly
else:
# Default to old behavior
enforce_max_neighbors_strictly = True

if not otf_graph:
try:
Expand Down

0 comments on commit 00554c7

Please sign in to comment.