From 00554c7aff793afa2fa9c6b0ef23d7c572462b99 Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 15 Aug 2024 16:09:26 -0700 Subject: [PATCH] use model generate graph --- .../common/relaxation/optimizers/optimizable.py | 9 ++++----- src/fairchem/core/models/base.py | 15 +++++++-------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/fairchem/core/common/relaxation/optimizers/optimizable.py b/src/fairchem/core/common/relaxation/optimizers/optimizable.py index 31c7b9aae..8aed07a6c 100644 --- a/src/fairchem/core/common/relaxation/optimizers/optimizable.py +++ b/src/fairchem/core/common/relaxation/optimizers/optimizable.py @@ -20,7 +20,6 @@ from torch_scatter import scatter from fairchem.core.common.relaxation.ase_utils import batch_to_atoms -from fairchem.core.common.utils import radius_graph_pbc # unreleased ASE has Optimizable, last released version 3.22.1 does not # thankfully we can get away with backwards compatibility by creating a dummy @@ -315,10 +314,10 @@ def get_atoms_list(self) -> list[Atoms]: def update_graph(self): """Update the graph if model does not use otf_graph.""" - edge_index, cell_offsets, num_neighbors = radius_graph_pbc(self.batch, 6, 50) - self.batch.edge_index = edge_index - self.batch.cell_offsets = cell_offsets - self.batch.neighbors = num_neighbors + graph = self.trainer.model.generate_graph(self.batch) + self.batch.edge_index = graph.edge_index + self.batch.cell_offsets = graph.cell_offsets + self.batch.neighbors = graph.neighbors if self.transform is not None: self.batch = self.transform(self.batch) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 8ce8f3fcb..e6e4a4960 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -61,14 +61,13 @@ def generate_graph( use_pbc_single = use_pbc_single or self.use_pbc_single otf_graph = otf_graph or self.otf_graph - if enforce_max_neighbors_strictly is not None: - pass - elif hasattr(self, "enforce_max_neighbors_strictly"): - # Not all models will have this attribute - enforce_max_neighbors_strictly = self.enforce_max_neighbors_strictly - else: - # Default to old behavior - enforce_max_neighbors_strictly = True + if enforce_max_neighbors_strictly is None: + if hasattr(self, "enforce_max_neighbors_strictly"): + # Not all models will have this attribute + enforce_max_neighbors_strictly = self.enforce_max_neighbors_strictly + else: + # Default to old behavior + enforce_max_neighbors_strictly = True if not otf_graph: try: