-
Notifications
You must be signed in to change notification settings - Fork 261
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
106 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from __future__ import annotations | ||
|
||
import pytest | ||
from ase import build | ||
|
||
from fairchem.core.common.relaxation.ase_utils import OCPCalculator | ||
from fairchem.core.datasets import data_list_collater | ||
from fairchem.core.preprocessing.atoms_to_graphs import AtomsToGraphs | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def calculator(tmp_path_factory): | ||
dir = tmp_path_factory.mktemp("checkpoints") | ||
calc = OCPCalculator( | ||
model_name="EquiformerV2-31M-S2EF-OC20-All+MD", local_cache=dir, seed=0 | ||
) | ||
# TODO debug this | ||
# removing amp so that we always get float32 predictions | ||
calc.trainer.scaler = None | ||
return calc | ||
|
||
|
||
@pytest.fixture() | ||
def atoms_list(): | ||
atoms_list = [ | ||
build.bulk("Cu", "fcc", a=3.8, cubic=True), | ||
build.bulk("NaCl", crystalstructure="rocksalt", a=5.8), | ||
] | ||
for atoms in atoms_list: | ||
atoms.rattle(stdev=0.05, seed=0) | ||
return atoms_list | ||
|
||
|
||
@pytest.fixture() | ||
def batch(atoms_list): | ||
a2g = AtomsToGraphs(r_edges=False, r_pbc=True) | ||
return data_list_collater([a2g.convert(atoms) for atoms in atoms_list]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import annotations | ||
|
||
from itertools import product | ||
|
||
import numpy.testing as npt | ||
import pytest | ||
from ase.io import read | ||
from ase.optimize import LBFGS as LBFGS_ASE | ||
|
||
from fairchem.core.common.relaxation import OptimizableBatch | ||
from fairchem.core.common.relaxation.optimizers import LBFGS | ||
from fairchem.core.modules.evaluator import min_diff | ||
|
||
|
||
def test_lbfgs_relaxation(atoms_list, batch, calculator): | ||
"""Tests batch relaxation using fairchem LBFGS optimizer.""" | ||
obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=False) | ||
|
||
# optimize atoms one-by-one | ||
for atoms in atoms_list: | ||
atoms.calc = calculator | ||
opt = LBFGS_ASE(atoms, damping=0.8, alpha=70.0) | ||
opt.run(0.01, 20) | ||
|
||
# optimize atoms in batch using ASE | ||
batch_optimizer = LBFGS(obatch, damping=0.8, alpha=70.0) | ||
batch_optimizer.run(0.01, 20) | ||
|
||
# compare energy and atom positions, this needs pretty slack tols but that should be ok | ||
for a1, a2 in zip(atoms_list, obatch.get_atoms_list()): | ||
assert a1.get_potential_energy() / len(a1) == pytest.approx( | ||
a2.get_potential_energy() / len(a2), abs=0.05 | ||
) | ||
diff = min_diff(a1.positions, a2.positions, a1.get_cell(), pbc=a1.pbc) | ||
npt.assert_allclose(diff, 0, atol=0.01) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("save_full_traj", "steps"), list(product((True, False), (0, 1, 5))) | ||
) | ||
def test_lbfgs_write_trajectory(save_full_traj, steps, batch, calculator, tmp_path): | ||
obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=False) | ||
batch_optimizer = LBFGS( | ||
obatch, | ||
save_full_traj=save_full_traj, | ||
traj_dir=tmp_path, | ||
traj_names=[f"system-{i}" for i in range(len(batch))], | ||
) | ||
|
||
batch_optimizer.run(0.01, steps=steps) | ||
|
||
# check that trajectory files where written | ||
traj_files = list(tmp_path.glob("*.traj")) | ||
assert len(traj_files) == len(batch) | ||
|
||
traj_length = 0 if steps == 0 else steps + 1 if save_full_traj else 1 | ||
for file in traj_files: | ||
traj = read(file, ":") | ||
assert len(traj) == traj_length |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters