Skip to content

Commit

Permalink
Merge pull request #31 from alchem0x2A/master
Browse files Browse the repository at this point in the history
Fix multiple calls to `SPARC.calculate`
  • Loading branch information
alchem0x2A authored Aug 19, 2023
2 parents 51bd171 + 0e50c95 commit e146aee
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 44 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -770,3 +770,4 @@ Untitled*
ex0-*/
al-eos-sparc.traj
*/ex1-sparc/
examples/ex1-ase/
5 changes: 3 additions & 2 deletions examples/ex1-ase-optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ase.optimize.bfgs import BFGS
from ase.constraints import FixAtoms

nh3 = molecule("NH3", cell=(6, 6, 6))
nh3 = molecule("NH3", cell=(8, 8, 8), pbc=True)
# Fix the N center
nh3.constraints = [FixAtoms([0])]
nh3.rattle()
Expand Down Expand Up @@ -50,6 +50,7 @@ def optimize_ase_lbfgs():
)
atoms.calc = calc
opt = BFGS(atoms)
#breakpoint()
opt.run(fmax=0.02)
e_fin = atoms.get_potential_energy()
f_fin = atoms.get_forces()
Expand All @@ -61,5 +62,5 @@ def optimize_ase_lbfgs():


if __name__ == "__main__":
# optimize_sparc_internal()
optimize_sparc_internal()
optimize_ase_lbfgs()
4 changes: 2 additions & 2 deletions sparc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def convert_value_to_string(self, parameter, value):
# Be aware of bool values!
string = str(int(value))
elif dtype == "double":
string = "{:g}".format(float(value))
string = "{:.14f}".format(float(value))
elif dtype in ("integer array", "double array"):
string = _array_to_string(value, dtype)
else:
Expand All @@ -205,7 +205,7 @@ def _array_to_string(arr, format):
if format in ("integer array", "integer"):
fmt = "%d"
elif format in ("double array", "double"):
fmt = "%g"
fmt = "%.14f"
np.savetxt(
buf, arr, delimiter=" ", fmt=fmt, header="", footer="", newline="\n"
)
Expand Down
92 changes: 74 additions & 18 deletions sparc/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,35 @@ def resort(self):
"""
return self.sparc_bundle.sorting["resort"]

def check_state(self, atoms, tol=1e-9):
"""Updated check_state method.
By default self.atoms (cached from output files) contains the initial_magmoms,
so we add a zero magmoms to the atoms for comparison if it does not exist.
reading a result from the .out file has only precision up to 10 digits
"""
atoms_copy = atoms.copy()
if "initial_magmoms" not in atoms_copy.arrays:
atoms_copy.set_initial_magnetic_moments(
[
0,
]
* len(atoms_copy)
)
# First we check for default changes
system_changes = FileIOCalculator.check_state(self, atoms_copy, tol=tol)
# A few hard-written rules. Wrapping should only affect the position
if "positions" in system_changes:
atoms_copy.wrap()
new_system_changes = FileIOCalculator.check_state(
self, atoms_copy, tol=tol
)
# TODO: make sure such check only happens for PBC
# the position is wrapped, accept as the same structure
if "positions" not in new_system_changes:
system_changes.remove("positions")
return system_changes

def _make_command(self, extras=""):
"""Use $ASE_SPARC_COMMAND or self.command to determine the command
as a last resort, if `sparc` exists in the PATH, use that information
Expand Down Expand Up @@ -195,6 +224,12 @@ def calculate(
self, atoms=None, properties=["energy"], system_changes=all_changes
):
"""Perform a calculation step"""
# For v1.0.0, we'll only allow pbc=True to make ourselves easier
# TODO: need to have more flexible support for pbc types and check_state
if not all(atoms.pbc):
raise NotImplementedError(
"Non-pbc atoms input has not been tested in the api. Please use pbc=True for now."
)
Calculator.calculate(self, atoms, properties, system_changes)
self.write_input(self.atoms, properties, system_changes)
self.execute()
Expand Down Expand Up @@ -244,21 +279,33 @@ def _check_input_exclusion(self, input_parameters, atoms=None):
count += 1
if count > 1:
# TODO: change to ExclusionParameterError
raise ValueError("ECUT, MESH_SPACING, FD_GRID cannot be specified simultaneously!")
raise ValueError(
"ECUT, MESH_SPACING, FD_GRID cannot be specified simultaneously!"
)

# Rule 2: LATVEC_SCALE, CELL
if ("LATVEC_SCALE" in input_parameters) and ("CELL" in input_parameters):
if ("LATVEC_SCALE" in input_parameters) and (
"CELL" in input_parameters
):
# TODO: change to ExclusionParameterError
raise ValueError("LATVEC_SCALE and CELL cannot be specified simultaneously!")

raise ValueError(
"LATVEC_SCALE and CELL cannot be specified simultaneously!"
)

# When the cell is provided via ase object, we will forbid user to provide
# LATVEC, LATVEC_SCALE or CELL
# TODO: make sure the rule makes sense for molecules
if (atoms is not None):
if any([p in input_parameters for p in ["LATVEC", "LATVEC_SCALE", "CELL"]]):
raise ValueError("When passing an ase atoms object, LATVEC, LATVEC_SCALE or CELL cannot be set simultaneously!")
if atoms is not None:
if any(
[
p in input_parameters
for p in ["LATVEC", "LATVEC_SCALE", "CELL"]
]
):
raise ValueError(
"When passing an ase atoms object, LATVEC, LATVEC_SCALE or CELL cannot be set simultaneously!"
)


def _check_minimal_input(self, input_parameters):
"""Check if the minimal input set is satisfied
Expand All @@ -269,10 +316,15 @@ def _check_minimal_input(self, input_parameters):
# TODO: change to MissingParameterError
raise ValueError(f"Parameter {param} is not provided.")
# At least one from ECUT, MESH_SPACING and FD_GRID must be provided
if not any([param in input_parameters for param in ("ECUT", "MESH_SPACING", "FD_GRID")]):
raise ValueError("You should provide at least one of ECUT, MESH_SPACING or FD_GRID.")


if not any(
[
param in input_parameters
for param in ("ECUT", "MESH_SPACING", "FD_GRID")
]
):
raise ValueError(
"You should provide at least one of ECUT, MESH_SPACING or FD_GRID."
)

def write_input(self, atoms, properties=[], system_changes=[]):
"""Create input files via SparcBundle"""
Expand All @@ -283,7 +335,6 @@ def write_input(self, atoms, properties=[], system_changes=[]):
converted_params = self._convert_special_params(atoms=atoms)
input_parameters = converted_params.copy()
input_parameters.update(self.valid_params)


# Make sure desired properties are always ensured, but we don't modify the user inputs
if "forces" in properties:
Expand All @@ -300,7 +351,6 @@ def write_input(self, atoms, properties=[], system_changes=[]):

self._check_input_exclusion(input_parameters, atoms=atoms)
self._check_minimal_input(input_parameters)


self.sparc_bundle._write_ion_and_inpt(
atoms=atoms,
Expand Down Expand Up @@ -435,7 +485,6 @@ def _sanitize_kwargs(self, kwargs):
warn(f"Input parameter {key} does not have a valid value!")
return valid_params, special_params


def _convert_special_params(self, atoms=None):
"""Convert ASE-compatible parameters to SPARC compatible ones
parameters like `h`, `nbands` may need atoms information
Expand Down Expand Up @@ -467,9 +516,16 @@ def _convert_special_params(self, atoms=None):
raise ValueError(
"Must have an active atoms object to convert h --> gpts!"
)
if any([p in self.valid_params for p in ("FD_GRID", "ECUT", "MESH_SPACING")]):
warn("You have specified one of FD_GRID, ECUT or MESH_SPACING, "
"conversion of h to mesh grid is ignored.")
if any(
[
p in self.valid_params
for p in ("FD_GRID", "ECUT", "MESH_SPACING")
]
):
warn(
"You have specified one of FD_GRID, ECUT or MESH_SPACING, "
"conversion of h to mesh grid is ignored."
)
else:
# TODO: is there any limitation for parallelization?
gpts = h2gpts(h, atoms.cell)
Expand Down
28 changes: 15 additions & 13 deletions sparc/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@ class SparcBundle:
psp_env = ["SPARC_PSP_PATH", "SPARC_PP_PATH"]

def __init__(
self, directory, mode="r", atoms=None, label=None, psp_dir=None,
self,
directory,
mode="r",
atoms=None,
label=None,
psp_dir=None,
):
self.directory = Path(directory)
self.mode = mode.lower()
Expand Down Expand Up @@ -378,8 +383,7 @@ def convert_to_ase(self, index=-1, include_all_files=False, **kwargs):

if images is not None:
if calc_results is not None:
images = self._make_singlepoint(
calc_results, images, entry)
images = self._make_singlepoint(calc_results, images, entry)
res_images.extend(images)

if isinstance(index, int):
Expand All @@ -403,9 +407,7 @@ def _make_singlepoint(self, calc_results, images, raw_results):
sp.results.update(res)
sp.name = "sparc"
sp.kpts = (
raw_results["inpt"]
.get("params", {})
.get("KPOINT_GRID", None)
raw_results["inpt"].get("params", {}).get("KPOINT_GRID", None)
)
# There may be a better way handling the parameters...
sp.parameters = raw_results["inpt"].get("params", {})
Expand Down Expand Up @@ -502,19 +504,19 @@ def _extract_geopt_results(self, raw_results, index=":"):
if "ase_cell" in result:
atoms.set_cell(result["ase_cell"])
else:
# For geopt and RELAX=2 (cell relaxation),
# For geopt and RELAX=2 (cell relaxation),
# the positions may not be written in .geopt file
relax_flag = raw_results["inpt"]["params"].get("RELAX_FLAG", 0)
if relax_flag != 2:
raise ValueError(
".geopt file missing positions while RELAX!=2. "
"Please check your setup ad output files."
)
".geopt file missing positions while RELAX!=2. "
"Please check your setup ad output files."
)
if "ase_cell" not in result:
raise ValueError(
"Cannot recover positions from .geopt file due to missing cell information. "
"Please check your setup ad output files."
)
"Cannot recover positions from .geopt file due to missing cell information. "
"Please check your setup ad output files."
)
atoms.set_cell(result["ase_cell"], scale_atoms=True)
calc_results.append(partial_result)
ase_images.append(atoms)
Expand Down
12 changes: 9 additions & 3 deletions sparc/sparc_parsers/out.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,14 @@ def _read_sparc_version(header):
date_str = match[0].strip().replace(",", " ")
# Accept both abbreviate and full month name
try:
date_version = datetime.strptime(date_str, "%B %d %Y").strftime("%Y.%m.%d")
date_version = datetime.strptime(date_str, "%B %d %Y").strftime(
"%Y.%m.%d"
)
except ValueError:
try:
date_version = datetime.strptime(date_str, "%b %d %Y").strftime("%Y.%m.%d")
date_version = datetime.strptime(date_str, "%b %d %Y").strftime(
"%Y.%m.%d"
)
except ValueError:
warn("Cannot fetch SPARC version information!")
date_version = None
Expand Down Expand Up @@ -145,7 +149,9 @@ def _read_scfs(contents):
conv_header = re.split(r"\s{3,}", conv_lines[0])
# In some cases the ionic step ends with a warning message
# To be flexible, we only extract lines starting with a number
conv_array = np.genfromtxt([l for l in conv_lines if l.split()[0].isdigit()], dtype=float)
conv_array = np.genfromtxt(
[l for l in conv_lines if l.split()[0].isdigit()], dtype=float
)
# TODO: the meaning of the header should me split to the width

conv_dict = {}
Expand Down
41 changes: 37 additions & 4 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@


def test_h_parameter():
"""Parameter h will be overwritten by any of FD_GRID, MESH_SPACING, ECUT
"""
"""Parameter h will be overwritten by any of FD_GRID, MESH_SPACING, ECUT"""
from sparc.calculator import SPARC
from ase.build import bulk

atoms = bulk("Al", cubic=True)
with tempfile.TemporaryDirectory() as tmpdir:
calc = SPARC(h=0.2, directory=tmpdir)
Expand Down Expand Up @@ -39,6 +39,7 @@ def test_h_parameter():
def test_conflict_param():
from sparc.calculator import SPARC
from ase.build import bulk

atoms = bulk("Al", cubic=True)
with tempfile.TemporaryDirectory() as tmpdir:
calc = SPARC(h=0.2, directory=tmpdir, FD_GRID=[25, 25, 25], ECUT=25)
Expand All @@ -47,8 +48,9 @@ def test_conflict_param():
calc.write_input(atoms)

with tempfile.TemporaryDirectory() as tmpdir:
calc = SPARC(h=0.2, directory=tmpdir, FD_GRID=[
25, 25, 25], MESH_SPACING=0.4)
calc = SPARC(
h=0.2, directory=tmpdir, FD_GRID=[25, 25, 25], MESH_SPACING=0.4
)
# FD_GRID and ECUT are conflict, but only detected during the writing
with pytest.raises(Exception):
calc.write_input(atoms)
Expand All @@ -57,6 +59,7 @@ def test_conflict_param():
def test_cell_param():
from sparc.calculator import SPARC
from ase.build import bulk

atoms = bulk("Al", cubic=True)
with tempfile.TemporaryDirectory() as tmpdir:
calc = SPARC(h=0.2, directory=tmpdir, CELL=[10, 10, 10])
Expand All @@ -68,6 +71,7 @@ def test_cell_param():
def test_unknown_params():
from sparc.calculator import SPARC
from ase.build import bulk

atoms = bulk("Al", cubic=True)
with tempfile.TemporaryDirectory() as tmpdir:
with pytest.raises(Exception):
Expand All @@ -77,9 +81,38 @@ def test_unknown_params():
def test_label():
from sparc.calculator import SPARC
from ase.build import bulk

atoms = bulk("Al", cubic=True)
with tempfile.TemporaryDirectory() as tmpdir:
calc = SPARC(h=0.2, directory=tmpdir, label="test_label")
assert calc.label == "test_label"
calc.write_input(atoms)
assert (Path(tmpdir) / "test_label.inpt").exists()


def test_cache_results():
# Test if the calculation results are cached (same structure)
from sparc.calculator import SPARC
from ase.build import molecule
from pathlib import Path

nh3 = molecule("NH3", cell=(8, 8, 8), pbc=True)
nh3.rattle()

dummy_calc = SPARC()
try:
cmd = dummy_calc._make_command()
except EnvironmentError:
print("Skip test since no sparc command found")
return
with tempfile.TemporaryDirectory() as tmpdir:
calc = SPARC(
h=0.3, kpts=(1, 1, 1), xc="pbe", print_forces=True, directory=tmpdir
)
nh3.calc = calc
forces = nh3.get_forces()
# make sure no more calculations are needed
assert len(calc.check_state(nh3)) == 0
energy = nh3.get_potential_energy()
static_files = list(Path(tmpdir).glob("*.static*"))
assert len(static_files) == 1
2 changes: 1 addition & 1 deletion tests/test_geopt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def test_geopt_parser_relax2():
assert "positions" not in step
assert "stress" in step
assert "cell" in step
assert "latvec" in step
assert "latvec" in step
1 change: 1 addition & 0 deletions tests/test_output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

def test_output_date_parser():
from sparc.sparc_parsers.out import _read_sparc_version

header1 = """***************************************************************************
* SPARC (version Feb 03, 2023) *
* Copyright (c) 2020 Material Physics & Mechanics Group, Georgia Tech *
Expand Down
Loading

0 comments on commit e146aee

Please sign in to comment.