From b7732cb22c8e483187a89e71354341d7e4087ed6 Mon Sep 17 00:00:00 2001 From: "T.Tian" Date: Sat, 19 Aug 2023 01:38:16 -0700 Subject: [PATCH 1/3] fix for single calc calls --- sparc/api.py | 11 +++--- sparc/calculator.py | 88 +++++++++++++++++++++++++-------------------- 2 files changed, 54 insertions(+), 45 deletions(-) diff --git a/sparc/api.py b/sparc/api.py index a0416766..ad68d61f 100644 --- a/sparc/api.py +++ b/sparc/api.py @@ -30,8 +30,7 @@ def get_parameter_dict(self, parameter): parameter = parameter.upper() if parameter not in self.parameters.keys(): raise KeyError( - f"Parameter {parameter} is not known to " - f"SPARC {self.sparc_version}!" + f"Parameter {parameter} is not known to " f"SPARC {self.sparc_version}!" ) return self.parameters[parameter] @@ -187,7 +186,7 @@ def convert_value_to_string(self, parameter, value): # Be aware of bool values! string = str(int(value)) elif dtype == "double": - string = "{:g}".format(float(value)) + string = "{:.16f}".format(float(value)) elif dtype in ("integer array", "double array"): string = _array_to_string(value, dtype) else: @@ -205,10 +204,8 @@ def _array_to_string(arr, format): if format in ("integer array", "integer"): fmt = "%d" elif format in ("double array", "double"): - fmt = "%g" - np.savetxt( - buf, arr, delimiter=" ", fmt=fmt, header="", footer="", newline="\n" - ) + fmt = "%.16f" + np.savetxt(buf, arr, delimiter=" ", fmt=fmt, header="", footer="", newline="\n") # Return the string output of the buffer with # whitespaces removed return buf.getvalue().strip() diff --git a/sparc/calculator.py b/sparc/calculator.py index ecf11645..ddc0d53b 100644 --- a/sparc/calculator.py +++ b/sparc/calculator.py @@ -157,6 +157,25 @@ def resort(self): """ return self.sparc_bundle.sorting["resort"] + def check_state(self, atoms, tol=1e-14): + """Updated check_state method. + By default self.atoms (cached from output files) contains the initial_magmoms, + so we add a zero magmoms to the atoms for comparison if it does not exist. + + tol=1e-14 should be enough if the values are written as ":.16f" + """ + atoms_copy = atoms.copy() + if "initial_magmoms" not in atoms_copy.arrays: + atoms_copy.set_initial_magnetic_moments( + [ + 0, + ] + * len(atoms_copy) + ) + # First we check for default changes + system_changes = FileIOCalculator.check_state(self, atoms_copy, tol=tol) + return system_changes + def _make_command(self, extras=""): """Use $ASE_SPARC_COMMAND or self.command to determine the command as a last resort, if `sparc` exists in the PATH, use that information @@ -191,9 +210,7 @@ def _make_command(self, extras=""): self.command = command_env return f"{self.command} {extras}" - def calculate( - self, atoms=None, properties=["energy"], system_changes=all_changes - ): + def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes): """Perform a calculation step""" Calculator.calculate(self, atoms, properties, system_changes) self.write_input(self.atoms, properties, system_changes) @@ -244,21 +261,26 @@ def _check_input_exclusion(self, input_parameters, atoms=None): count += 1 if count > 1: # TODO: change to ExclusionParameterError - raise ValueError("ECUT, MESH_SPACING, FD_GRID cannot be specified simultaneously!") + raise ValueError( + "ECUT, MESH_SPACING, FD_GRID cannot be specified simultaneously!" + ) # Rule 2: LATVEC_SCALE, CELL if ("LATVEC_SCALE" in input_parameters) and ("CELL" in input_parameters): # TODO: change to ExclusionParameterError - raise ValueError("LATVEC_SCALE and CELL cannot be specified simultaneously!") - + raise ValueError( + "LATVEC_SCALE and CELL cannot be specified simultaneously!" + ) + # When the cell is provided via ase object, we will forbid user to provide # LATVEC, LATVEC_SCALE or CELL # TODO: make sure the rule makes sense for molecules - if (atoms is not None): + if atoms is not None: if any([p in input_parameters for p in ["LATVEC", "LATVEC_SCALE", "CELL"]]): - raise ValueError("When passing an ase atoms object, LATVEC, LATVEC_SCALE or CELL cannot be set simultaneously!") + raise ValueError( + "When passing an ase atoms object, LATVEC, LATVEC_SCALE or CELL cannot be set simultaneously!" + ) - def _check_minimal_input(self, input_parameters): """Check if the minimal input set is satisfied @@ -269,10 +291,12 @@ def _check_minimal_input(self, input_parameters): # TODO: change to MissingParameterError raise ValueError(f"Parameter {param} is not provided.") # At least one from ECUT, MESH_SPACING and FD_GRID must be provided - if not any([param in input_parameters for param in ("ECUT", "MESH_SPACING", "FD_GRID")]): - raise ValueError("You should provide at least one of ECUT, MESH_SPACING or FD_GRID.") - - + if not any( + [param in input_parameters for param in ("ECUT", "MESH_SPACING", "FD_GRID")] + ): + raise ValueError( + "You should provide at least one of ECUT, MESH_SPACING or FD_GRID." + ) def write_input(self, atoms, properties=[], system_changes=[]): """Create input files via SparcBundle""" @@ -283,7 +307,6 @@ def write_input(self, atoms, properties=[], system_changes=[]): converted_params = self._convert_special_params(atoms=atoms) input_parameters = converted_params.copy() input_parameters.update(self.valid_params) - # Make sure desired properties are always ensured, but we don't modify the user inputs if "forces" in properties: @@ -300,7 +323,6 @@ def write_input(self, atoms, properties=[], system_changes=[]): self._check_input_exclusion(input_parameters, atoms=atoms) self._check_minimal_input(input_parameters) - self.sparc_bundle._write_ion_and_inpt( atoms=atoms, @@ -345,10 +367,7 @@ def execute(self): errorcode = self.proc.returncode if errorcode > 0: - msg = ( - f"SPARC failed with command {command}" - f"with error code {errorcode}" - ) + msg = f"SPARC failed with command {command}" f"with error code {errorcode}" raise RuntimeError(msg) return @@ -366,9 +385,7 @@ def read_results(self): """Parse from the SparcBundle""" # TODO: try use cache? # self.sparc_bundle.read_raw_results() - last = self.sparc_bundle.convert_to_ase( - indices=-1, include_all_files=False - ) + last = self.sparc_bundle.convert_to_ase(indices=-1, include_all_files=False) self.atoms = last.copy() self.results.update(last.calc.results) @@ -435,7 +452,6 @@ def _sanitize_kwargs(self, kwargs): warn(f"Input parameter {key} does not have a valid value!") return valid_params, special_params - def _convert_special_params(self, atoms=None): """Convert ASE-compatible parameters to SPARC compatible ones parameters like `h`, `nbands` may need atoms information @@ -467,9 +483,13 @@ def _convert_special_params(self, atoms=None): raise ValueError( "Must have an active atoms object to convert h --> gpts!" ) - if any([p in self.valid_params for p in ("FD_GRID", "ECUT", "MESH_SPACING")]): - warn("You have specified one of FD_GRID, ECUT or MESH_SPACING, " - "conversion of h to mesh grid is ignored.") + if any( + [p in self.valid_params for p in ("FD_GRID", "ECUT", "MESH_SPACING")] + ): + warn( + "You have specified one of FD_GRID, ECUT or MESH_SPACING, " + "conversion of h to mesh grid is ignored." + ) else: # TODO: is there any limitation for parallelization? gpts = h2gpts(h, atoms.cell) @@ -482,9 +502,7 @@ def _convert_special_params(self, atoms=None): converted_sparc_params["FD_GRID"] = gpts else: # TODO: customize error - raise ValueError( - f"Input parameter gpts has invalid value {gpts}" - ) + raise ValueError(f"Input parameter gpts has invalid value {gpts}") # kpts if "kpts" in params: @@ -494,9 +512,7 @@ def _convert_special_params(self, atoms=None): converted_sparc_params["KPOINT_GRID"] = kpts else: # TODO: customize error - raise ValueError( - f"Input parameter kpts has invalid value {kpts}" - ) + raise ValueError(f"Input parameter kpts has invalid value {kpts}") # nbands if "nbands" in params: @@ -507,9 +523,7 @@ def _convert_special_params(self, atoms=None): converted_sparc_params["NSTATES"] = nbands else: # TODO: customize error - raise ValueError( - f"Input parameter nbands has invalid value {nbands}" - ) + raise ValueError(f"Input parameter nbands has invalid value {nbands}") # convergence is a dict if "convergence" in params: @@ -564,9 +578,7 @@ def interpret_grid_input(self, atoms, **kwargs): def interpret_kpoint_input(self, atoms, **kwargs): return None - @deprecated( - "Please use SPARC.set instead for setting downsampling parameter" - ) + @deprecated("Please use SPARC.set instead for setting downsampling parameter") def interpret_downsampling_input(self, atoms, **kwargs): return None From fe2cbffa1c0b0008863ea818fcd50d7b05dd675e Mon Sep 17 00:00:00 2001 From: "T.Tian" Date: Sat, 19 Aug 2023 08:58:03 -0700 Subject: [PATCH 2/3] add quick patch for un-cached results --- .gitignore | 1 + examples/ex1-ase-optimize.py | 5 +-- sparc/api.py | 11 +++--- sparc/calculator.py | 70 +++++++++++++++++++++++++++++------- sparc/io.py | 28 ++++++++------- sparc/sparc_parsers/out.py | 12 +++++-- 6 files changed, 92 insertions(+), 35 deletions(-) diff --git a/.gitignore b/.gitignore index 333743dc..1ad89ca5 100644 --- a/.gitignore +++ b/.gitignore @@ -770,3 +770,4 @@ Untitled* ex0-*/ al-eos-sparc.traj */ex1-sparc/ +examples/ex1-ase/ diff --git a/examples/ex1-ase-optimize.py b/examples/ex1-ase-optimize.py index 3b824f9c..58e3b9eb 100644 --- a/examples/ex1-ase-optimize.py +++ b/examples/ex1-ase-optimize.py @@ -13,7 +13,7 @@ from ase.optimize.bfgs import BFGS from ase.constraints import FixAtoms -nh3 = molecule("NH3", cell=(6, 6, 6)) +nh3 = molecule("NH3", cell=(8, 8, 8), pbc=True) # Fix the N center nh3.constraints = [FixAtoms([0])] nh3.rattle() @@ -50,6 +50,7 @@ def optimize_ase_lbfgs(): ) atoms.calc = calc opt = BFGS(atoms) + #breakpoint() opt.run(fmax=0.02) e_fin = atoms.get_potential_energy() f_fin = atoms.get_forces() @@ -61,5 +62,5 @@ def optimize_ase_lbfgs(): if __name__ == "__main__": - # optimize_sparc_internal() + optimize_sparc_internal() optimize_ase_lbfgs() diff --git a/sparc/api.py b/sparc/api.py index ad68d61f..5f3f9215 100644 --- a/sparc/api.py +++ b/sparc/api.py @@ -30,7 +30,8 @@ def get_parameter_dict(self, parameter): parameter = parameter.upper() if parameter not in self.parameters.keys(): raise KeyError( - f"Parameter {parameter} is not known to " f"SPARC {self.sparc_version}!" + f"Parameter {parameter} is not known to " + f"SPARC {self.sparc_version}!" ) return self.parameters[parameter] @@ -186,7 +187,7 @@ def convert_value_to_string(self, parameter, value): # Be aware of bool values! string = str(int(value)) elif dtype == "double": - string = "{:.16f}".format(float(value)) + string = "{:.14f}".format(float(value)) elif dtype in ("integer array", "double array"): string = _array_to_string(value, dtype) else: @@ -204,8 +205,10 @@ def _array_to_string(arr, format): if format in ("integer array", "integer"): fmt = "%d" elif format in ("double array", "double"): - fmt = "%.16f" - np.savetxt(buf, arr, delimiter=" ", fmt=fmt, header="", footer="", newline="\n") + fmt = "%.14f" + np.savetxt( + buf, arr, delimiter=" ", fmt=fmt, header="", footer="", newline="\n" + ) # Return the string output of the buffer with # whitespaces removed return buf.getvalue().strip() diff --git a/sparc/calculator.py b/sparc/calculator.py index ddc0d53b..6c72f3fa 100644 --- a/sparc/calculator.py +++ b/sparc/calculator.py @@ -157,12 +157,12 @@ def resort(self): """ return self.sparc_bundle.sorting["resort"] - def check_state(self, atoms, tol=1e-14): + def check_state(self, atoms, tol=1e-9): """Updated check_state method. By default self.atoms (cached from output files) contains the initial_magmoms, so we add a zero magmoms to the atoms for comparison if it does not exist. - tol=1e-14 should be enough if the values are written as ":.16f" + reading a result from the .out file has only precision up to 10 digits """ atoms_copy = atoms.copy() if "initial_magmoms" not in atoms_copy.arrays: @@ -174,6 +174,16 @@ def check_state(self, atoms, tol=1e-14): ) # First we check for default changes system_changes = FileIOCalculator.check_state(self, atoms_copy, tol=tol) + # A few hard-written rules. Wrapping should only affect the position + if "positions" in system_changes: + atoms_copy.wrap() + new_system_changes = FileIOCalculator.check_state( + self, atoms_copy, tol=tol + ) + # TODO: make sure such check only happens for PBC + # the position is wrapped, accept as the same structure + if "positions" not in new_system_changes: + system_changes.remove("positions") return system_changes def _make_command(self, extras=""): @@ -210,8 +220,16 @@ def _make_command(self, extras=""): self.command = command_env return f"{self.command} {extras}" - def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes): + def calculate( + self, atoms=None, properties=["energy"], system_changes=all_changes + ): """Perform a calculation step""" + # For v1.0.0, we'll only allow pbc=True to make ourselves easier + # TODO: need to have more flexible support for pbc types and check_state + if not all(atoms.pbc): + raise NotImplementedError( + "Non-pbc atoms input has not been tested in the api. Please use pbc=True for now." + ) Calculator.calculate(self, atoms, properties, system_changes) self.write_input(self.atoms, properties, system_changes) self.execute() @@ -266,7 +284,9 @@ def _check_input_exclusion(self, input_parameters, atoms=None): ) # Rule 2: LATVEC_SCALE, CELL - if ("LATVEC_SCALE" in input_parameters) and ("CELL" in input_parameters): + if ("LATVEC_SCALE" in input_parameters) and ( + "CELL" in input_parameters + ): # TODO: change to ExclusionParameterError raise ValueError( "LATVEC_SCALE and CELL cannot be specified simultaneously!" @@ -276,7 +296,12 @@ def _check_input_exclusion(self, input_parameters, atoms=None): # LATVEC, LATVEC_SCALE or CELL # TODO: make sure the rule makes sense for molecules if atoms is not None: - if any([p in input_parameters for p in ["LATVEC", "LATVEC_SCALE", "CELL"]]): + if any( + [ + p in input_parameters + for p in ["LATVEC", "LATVEC_SCALE", "CELL"] + ] + ): raise ValueError( "When passing an ase atoms object, LATVEC, LATVEC_SCALE or CELL cannot be set simultaneously!" ) @@ -292,7 +317,10 @@ def _check_minimal_input(self, input_parameters): raise ValueError(f"Parameter {param} is not provided.") # At least one from ECUT, MESH_SPACING and FD_GRID must be provided if not any( - [param in input_parameters for param in ("ECUT", "MESH_SPACING", "FD_GRID")] + [ + param in input_parameters + for param in ("ECUT", "MESH_SPACING", "FD_GRID") + ] ): raise ValueError( "You should provide at least one of ECUT, MESH_SPACING or FD_GRID." @@ -367,7 +395,10 @@ def execute(self): errorcode = self.proc.returncode if errorcode > 0: - msg = f"SPARC failed with command {command}" f"with error code {errorcode}" + msg = ( + f"SPARC failed with command {command}" + f"with error code {errorcode}" + ) raise RuntimeError(msg) return @@ -385,7 +416,9 @@ def read_results(self): """Parse from the SparcBundle""" # TODO: try use cache? # self.sparc_bundle.read_raw_results() - last = self.sparc_bundle.convert_to_ase(indices=-1, include_all_files=False) + last = self.sparc_bundle.convert_to_ase( + indices=-1, include_all_files=False + ) self.atoms = last.copy() self.results.update(last.calc.results) @@ -484,7 +517,10 @@ def _convert_special_params(self, atoms=None): "Must have an active atoms object to convert h --> gpts!" ) if any( - [p in self.valid_params for p in ("FD_GRID", "ECUT", "MESH_SPACING")] + [ + p in self.valid_params + for p in ("FD_GRID", "ECUT", "MESH_SPACING") + ] ): warn( "You have specified one of FD_GRID, ECUT or MESH_SPACING, " @@ -502,7 +538,9 @@ def _convert_special_params(self, atoms=None): converted_sparc_params["FD_GRID"] = gpts else: # TODO: customize error - raise ValueError(f"Input parameter gpts has invalid value {gpts}") + raise ValueError( + f"Input parameter gpts has invalid value {gpts}" + ) # kpts if "kpts" in params: @@ -512,7 +550,9 @@ def _convert_special_params(self, atoms=None): converted_sparc_params["KPOINT_GRID"] = kpts else: # TODO: customize error - raise ValueError(f"Input parameter kpts has invalid value {kpts}") + raise ValueError( + f"Input parameter kpts has invalid value {kpts}" + ) # nbands if "nbands" in params: @@ -523,7 +563,9 @@ def _convert_special_params(self, atoms=None): converted_sparc_params["NSTATES"] = nbands else: # TODO: customize error - raise ValueError(f"Input parameter nbands has invalid value {nbands}") + raise ValueError( + f"Input parameter nbands has invalid value {nbands}" + ) # convergence is a dict if "convergence" in params: @@ -578,7 +620,9 @@ def interpret_grid_input(self, atoms, **kwargs): def interpret_kpoint_input(self, atoms, **kwargs): return None - @deprecated("Please use SPARC.set instead for setting downsampling parameter") + @deprecated( + "Please use SPARC.set instead for setting downsampling parameter" + ) def interpret_downsampling_input(self, atoms, **kwargs): return None diff --git a/sparc/io.py b/sparc/io.py index dd465e81..693512a4 100644 --- a/sparc/io.py +++ b/sparc/io.py @@ -62,7 +62,12 @@ class SparcBundle: psp_env = ["SPARC_PSP_PATH", "SPARC_PP_PATH"] def __init__( - self, directory, mode="r", atoms=None, label=None, psp_dir=None, + self, + directory, + mode="r", + atoms=None, + label=None, + psp_dir=None, ): self.directory = Path(directory) self.mode = mode.lower() @@ -378,8 +383,7 @@ def convert_to_ase(self, index=-1, include_all_files=False, **kwargs): if images is not None: if calc_results is not None: - images = self._make_singlepoint( - calc_results, images, entry) + images = self._make_singlepoint(calc_results, images, entry) res_images.extend(images) if isinstance(index, int): @@ -403,9 +407,7 @@ def _make_singlepoint(self, calc_results, images, raw_results): sp.results.update(res) sp.name = "sparc" sp.kpts = ( - raw_results["inpt"] - .get("params", {}) - .get("KPOINT_GRID", None) + raw_results["inpt"].get("params", {}).get("KPOINT_GRID", None) ) # There may be a better way handling the parameters... sp.parameters = raw_results["inpt"].get("params", {}) @@ -502,19 +504,19 @@ def _extract_geopt_results(self, raw_results, index=":"): if "ase_cell" in result: atoms.set_cell(result["ase_cell"]) else: - # For geopt and RELAX=2 (cell relaxation), + # For geopt and RELAX=2 (cell relaxation), # the positions may not be written in .geopt file relax_flag = raw_results["inpt"]["params"].get("RELAX_FLAG", 0) if relax_flag != 2: raise ValueError( - ".geopt file missing positions while RELAX!=2. " - "Please check your setup ad output files." - ) + ".geopt file missing positions while RELAX!=2. " + "Please check your setup ad output files." + ) if "ase_cell" not in result: raise ValueError( - "Cannot recover positions from .geopt file due to missing cell information. " - "Please check your setup ad output files." - ) + "Cannot recover positions from .geopt file due to missing cell information. " + "Please check your setup ad output files." + ) atoms.set_cell(result["ase_cell"], scale_atoms=True) calc_results.append(partial_result) ase_images.append(atoms) diff --git a/sparc/sparc_parsers/out.py b/sparc/sparc_parsers/out.py index 6eaa1b34..60002b12 100644 --- a/sparc/sparc_parsers/out.py +++ b/sparc/sparc_parsers/out.py @@ -66,10 +66,14 @@ def _read_sparc_version(header): date_str = match[0].strip().replace(",", " ") # Accept both abbreviate and full month name try: - date_version = datetime.strptime(date_str, "%B %d %Y").strftime("%Y.%m.%d") + date_version = datetime.strptime(date_str, "%B %d %Y").strftime( + "%Y.%m.%d" + ) except ValueError: try: - date_version = datetime.strptime(date_str, "%b %d %Y").strftime("%Y.%m.%d") + date_version = datetime.strptime(date_str, "%b %d %Y").strftime( + "%Y.%m.%d" + ) except ValueError: warn("Cannot fetch SPARC version information!") date_version = None @@ -145,7 +149,9 @@ def _read_scfs(contents): conv_header = re.split(r"\s{3,}", conv_lines[0]) # In some cases the ionic step ends with a warning message # To be flexible, we only extract lines starting with a number - conv_array = np.genfromtxt([l for l in conv_lines if l.split()[0].isdigit()], dtype=float) + conv_array = np.genfromtxt( + [l for l in conv_lines if l.split()[0].isdigit()], dtype=float + ) # TODO: the meaning of the header should me split to the width conv_dict = {} From 0e50c9571abebfa9c069785e09cef77ae8c94261 Mon Sep 17 00:00:00 2001 From: "T.Tian" Date: Sat, 19 Aug 2023 09:20:14 -0700 Subject: [PATCH 3/3] add tests for cached results --- tests/test_calculator.py | 41 +++++++++++++++++++++++++++++++++---- tests/test_geopt_parser.py | 2 +- tests/test_output_parser.py | 1 + tests/test_read_sparc.py | 3 ++- 4 files changed, 41 insertions(+), 6 deletions(-) diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 84266c53..e783b48e 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -4,10 +4,10 @@ def test_h_parameter(): - """Parameter h will be overwritten by any of FD_GRID, MESH_SPACING, ECUT - """ + """Parameter h will be overwritten by any of FD_GRID, MESH_SPACING, ECUT""" from sparc.calculator import SPARC from ase.build import bulk + atoms = bulk("Al", cubic=True) with tempfile.TemporaryDirectory() as tmpdir: calc = SPARC(h=0.2, directory=tmpdir) @@ -39,6 +39,7 @@ def test_h_parameter(): def test_conflict_param(): from sparc.calculator import SPARC from ase.build import bulk + atoms = bulk("Al", cubic=True) with tempfile.TemporaryDirectory() as tmpdir: calc = SPARC(h=0.2, directory=tmpdir, FD_GRID=[25, 25, 25], ECUT=25) @@ -47,8 +48,9 @@ def test_conflict_param(): calc.write_input(atoms) with tempfile.TemporaryDirectory() as tmpdir: - calc = SPARC(h=0.2, directory=tmpdir, FD_GRID=[ - 25, 25, 25], MESH_SPACING=0.4) + calc = SPARC( + h=0.2, directory=tmpdir, FD_GRID=[25, 25, 25], MESH_SPACING=0.4 + ) # FD_GRID and ECUT are conflict, but only detected during the writing with pytest.raises(Exception): calc.write_input(atoms) @@ -57,6 +59,7 @@ def test_conflict_param(): def test_cell_param(): from sparc.calculator import SPARC from ase.build import bulk + atoms = bulk("Al", cubic=True) with tempfile.TemporaryDirectory() as tmpdir: calc = SPARC(h=0.2, directory=tmpdir, CELL=[10, 10, 10]) @@ -68,6 +71,7 @@ def test_cell_param(): def test_unknown_params(): from sparc.calculator import SPARC from ase.build import bulk + atoms = bulk("Al", cubic=True) with tempfile.TemporaryDirectory() as tmpdir: with pytest.raises(Exception): @@ -77,9 +81,38 @@ def test_unknown_params(): def test_label(): from sparc.calculator import SPARC from ase.build import bulk + atoms = bulk("Al", cubic=True) with tempfile.TemporaryDirectory() as tmpdir: calc = SPARC(h=0.2, directory=tmpdir, label="test_label") assert calc.label == "test_label" calc.write_input(atoms) assert (Path(tmpdir) / "test_label.inpt").exists() + + +def test_cache_results(): + # Test if the calculation results are cached (same structure) + from sparc.calculator import SPARC + from ase.build import molecule + from pathlib import Path + + nh3 = molecule("NH3", cell=(8, 8, 8), pbc=True) + nh3.rattle() + + dummy_calc = SPARC() + try: + cmd = dummy_calc._make_command() + except EnvironmentError: + print("Skip test since no sparc command found") + return + with tempfile.TemporaryDirectory() as tmpdir: + calc = SPARC( + h=0.3, kpts=(1, 1, 1), xc="pbe", print_forces=True, directory=tmpdir + ) + nh3.calc = calc + forces = nh3.get_forces() + # make sure no more calculations are needed + assert len(calc.check_state(nh3)) == 0 + energy = nh3.get_potential_energy() + static_files = list(Path(tmpdir).glob("*.static*")) + assert len(static_files) == 1 diff --git a/tests/test_geopt_parser.py b/tests/test_geopt_parser.py index 788c44e0..2dc27c6b 100644 --- a/tests/test_geopt_parser.py +++ b/tests/test_geopt_parser.py @@ -53,4 +53,4 @@ def test_geopt_parser_relax2(): assert "positions" not in step assert "stress" in step assert "cell" in step - assert "latvec" in step \ No newline at end of file + assert "latvec" in step diff --git a/tests/test_output_parser.py b/tests/test_output_parser.py index f4bb4363..4ddf3e42 100644 --- a/tests/test_output_parser.py +++ b/tests/test_output_parser.py @@ -10,6 +10,7 @@ def test_output_date_parser(): from sparc.sparc_parsers.out import _read_sparc_version + header1 = """*************************************************************************** * SPARC (version Feb 03, 2023) * * Copyright (c) 2020 Material Physics & Mechanics Group, Georgia Tech * diff --git a/tests/test_read_sparc.py b/tests/test_read_sparc.py index 05bcc530..ee89d49c 100644 --- a/tests/test_read_sparc.py +++ b/tests/test_read_sparc.py @@ -6,6 +6,7 @@ curdir = Path(__file__).parent test_output_dir = curdir / "outputs" + def test_read_sparc_all(): from sparc.io import read_sparc from sparc.common import repo_dir @@ -14,4 +15,4 @@ def test_read_sparc_all(): if bundle.name not in ("Al_multi_geopt.sparc",): results = read_sparc(bundle) else: - results = read_sparc(bundle, include_all_files=True) \ No newline at end of file + results = read_sparc(bundle, include_all_files=True)