Skip to content

Commit

Permalink
test lbfgs traj writes
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Aug 15, 2024
1 parent 23067c7 commit 8768dbb
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 69 deletions.
4 changes: 2 additions & 2 deletions src/fairchem/core/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def __init__(
r_energy=False,
r_forces=False,
r_distances=False,
r_pbc=trainer.model.use_pbc,
r_edges=trainer.model.otf_graph, # otf graph should not be a property of the model...
r_pbc=self.trainer.model.use_pbc,
r_edges=self.trainer.model.otf_graph, # otf graph should not be a property of the model...
)
self.implemented_properties = list(self.config["outputs"].keys())

Expand Down
15 changes: 6 additions & 9 deletions src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
"""
Args:
optimizable_batch: an optimizable batch which includes a model and a batch of data
maxstep: maximum number of steps to run optimization
maxstep: largest step that any atom is allowed to move
memory: Number of steps to be stored in memory
damping: The calculated step is multiplied with this number before added to the positions.
alpha: Initial guess for the Hessian (curvature of energy surface)
Expand Down Expand Up @@ -71,7 +71,6 @@ def __init__(
assert not self.traj_dir or (
traj_dir and len(traj_names)
), "Trajectory names should be specified to save trajectories"
logging.info("Step Fmax(eV/A)")

def run(self, fmax, steps):
self.fmax = fmax
Expand All @@ -89,18 +88,18 @@ def run(self, fmax, steps):
ase.io.Trajectory(self.traj_dir / f"{name}.traj_tmp", mode="w")
for name in self.traj_names
]
self.write()

iteration = 0
max_forces = self.optimizable.get_max_forces(apply_constraint=True)
while iteration < steps - 1 and not self.optimizable.converged(
logging.info("Step Fmax(eV/A)")
while iteration < steps and not self.optimizable.converged(
forces=None, fmax=self.fmax, max_forces=max_forces
):
logging.info(
f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces.tolist())
)

if self.trajectories is not None and self.save_full:
if self.trajectories is not None and self.save_full is True:
self.write()

self.step(iteration)
Expand All @@ -111,7 +110,7 @@ def run(self, fmax, steps):
f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces.tolist())
)

# save after converged on all iterations ran
# save after converged or all iterations ran
if iteration > 0 and self.trajectories is not None:
self.write()

Expand Down Expand Up @@ -203,7 +202,5 @@ def write(self) -> None:
for atm, traj, mask in zip(
atoms_objects, self.trajectories, self.optimizable.update_mask
):
if (
mask or not self.save_full
): # should this be "if mask or self.save_full"?
if mask:
traj.write(atm)
2 changes: 1 addition & 1 deletion src/fairchem/core/trainers/ocp_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def predict(
return predictions

@torch.no_grad
def run_relaxations(self, split="val"):
def run_relaxations(self):
ensure_fitted(self._unwrapped_model)

# When set to true, uses deterministic CUDA scatter ops, if available.
Expand Down
37 changes: 37 additions & 0 deletions tests/core/common/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import pytest
from ase import build

from fairchem.core.common.relaxation.ase_utils import OCPCalculator
from fairchem.core.datasets import data_list_collater
from fairchem.core.preprocessing.atoms_to_graphs import AtomsToGraphs


@pytest.fixture(scope="session")
def calculator(tmp_path_factory):
dir = tmp_path_factory.mktemp("checkpoints")
calc = OCPCalculator(
model_name="EquiformerV2-31M-S2EF-OC20-All+MD", local_cache=dir, seed=0
)
# TODO debug this
# removing amp so that we always get float32 predictions
calc.trainer.scaler = None
return calc


@pytest.fixture()
def atoms_list():
atoms_list = [
build.bulk("Cu", "fcc", a=3.8, cubic=True),
build.bulk("NaCl", crystalstructure="rocksalt", a=5.8),
]
for atoms in atoms_list:
atoms.rattle(stdev=0.05, seed=0)
return atoms_list


@pytest.fixture()
def batch(atoms_list):
a2g = AtomsToGraphs(r_edges=False, r_pbc=True)
return data_list_collater([a2g.convert(atoms) for atoms in atoms_list])
59 changes: 59 additions & 0 deletions tests/core/common/test_lbfgs_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from itertools import product

import numpy.testing as npt
import pytest
from ase.io import read
from ase.optimize import LBFGS as LBFGS_ASE

from fairchem.core.common.relaxation import OptimizableBatch
from fairchem.core.common.relaxation.optimizers import LBFGS
from fairchem.core.modules.evaluator import min_diff


def test_lbfgs_relaxation(atoms_list, batch, calculator):
"""Tests batch relaxation using fairchem LBFGS optimizer."""
obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=False)

# optimize atoms one-by-one
for atoms in atoms_list:
atoms.calc = calculator
opt = LBFGS_ASE(atoms, damping=0.8, alpha=70.0)
opt.run(0.01, 20)

# optimize atoms in batch using ASE
batch_optimizer = LBFGS(obatch, damping=0.8, alpha=70.0)
batch_optimizer.run(0.01, 20)

# compare energy and atom positions, this needs pretty slack tols but that should be ok
for a1, a2 in zip(atoms_list, obatch.get_atoms_list()):
assert a1.get_potential_energy() / len(a1) == pytest.approx(
a2.get_potential_energy() / len(a2), abs=0.05
)
diff = min_diff(a1.positions, a2.positions, a1.get_cell(), pbc=a1.pbc)
npt.assert_allclose(diff, 0, atol=0.01)


@pytest.mark.parametrize(
("save_full_traj", "steps"), list(product((True, False), (0, 1, 5)))
)
def test_lbfgs_write_trajectory(save_full_traj, steps, batch, calculator, tmp_path):
obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=False)
batch_optimizer = LBFGS(
obatch,
save_full_traj=save_full_traj,
traj_dir=tmp_path,
traj_names=[f"system-{i}" for i in range(len(batch))],
)

batch_optimizer.run(0.01, steps=steps)

# check that trajectory files where written
traj_files = list(tmp_path.glob("*.traj"))
assert len(traj_files) == len(batch)

traj_length = 0 if steps == 0 else steps + 1 if save_full_traj else 1
for file in traj_files:
traj = read(file, ":")
assert len(traj) == traj_length
58 changes: 1 addition & 57 deletions tests/core/common/test_optimizable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import numpy.testing as npt
import pytest
from ase import build
from ase.optimize import BFGS, FIRE, LBFGS

try:
Expand All @@ -13,70 +12,15 @@
from ase.constraints import UnitCellFilter

from fairchem.core.common.relaxation import OptimizableBatch, OptimizableUnitCellBatch
from fairchem.core.common.relaxation.ase_utils import OCPCalculator
from fairchem.core.common.relaxation.optimizers import LBFGS as LBFGS_torch
from fairchem.core.datasets import data_list_collater
from fairchem.core.modules.evaluator import min_diff
from fairchem.core.preprocessing.atoms_to_graphs import AtomsToGraphs


@pytest.fixture(scope="session")
def calculator(tmp_path_factory):
dir = tmp_path_factory.mktemp("checkpoints")
calc = OCPCalculator(
model_name="EquiformerV2-31M-S2EF-OC20-All+MD", local_cache=dir, seed=0
)
# TODO debug this
# removing amp so that we always get float32 predictions
calc.trainer.scaler = None
return calc


@pytest.fixture()
def atoms_list():
atoms_list = [
build.bulk("Cu", "fcc", a=3.8, cubic=True),
build.bulk("NaCl", crystalstructure="rocksalt", a=5.8),
]
for atoms in atoms_list:
atoms.rattle(stdev=0.05, seed=0)
return atoms_list


@pytest.fixture()
def batch(atoms_list):
a2g = AtomsToGraphs(r_edges=False, r_pbc=True)
return data_list_collater([a2g.convert(atoms) for atoms in atoms_list])


@pytest.fixture(params=[FIRE, BFGS])
@pytest.fixture(params=[FIRE, BFGS, LBFGS])
def optimizer_cls(request):
return request.param


def test_lbfgs_relaxation(atoms_list, batch, calculator):
"""Tests batch relaxation using ASE optimizers."""
obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=False)

# optimize atoms one-by-one
for atoms in atoms_list:
atoms.calc = calculator
opt = LBFGS(atoms, damping=0.8, alpha=70.0)
opt.run(0.01, 20)

# optimize atoms in batch using ASE
batch_optimizer = LBFGS_torch(obatch, damping=0.8, alpha=70.0)
batch_optimizer.run(0.01, 20)

# compare energy and atom positions, this needs pretty slack tols but that should be ok
for a1, a2 in zip(atoms_list, obatch.get_atoms_list()):
assert a1.get_potential_energy() / len(a1) == pytest.approx(
a2.get_potential_energy() / len(a2), abs=0.05
)
diff = min_diff(a1.positions, a2.positions, a1.get_cell(), pbc=a1.pbc)
npt.assert_allclose(diff, 0, atol=0.01)


def test_ase_relaxation(atoms_list, batch, calculator, optimizer_cls):
"""Tests batch relaxation using ASE optimizers."""
obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=True)
Expand Down

0 comments on commit 8768dbb

Please sign in to comment.