Skip to content

Commit

Permalink
add locate_api method and explicitly use validator
Browse files Browse the repository at this point in the history
  • Loading branch information
alchem0x2A committed Jan 16, 2024
1 parent 39cf9f1 commit 72bbe20
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 31 deletions.
2 changes: 1 addition & 1 deletion sparc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class SparcAPI:
def __init__(self, json_api=None):
"""Initialize the API from a json file"""
#TODO: like ase io, adapt to both file and fio
# TODO: like ase io, adapt to both file and fio
if json_api is None:
json_api = Path(default_json_api)
else:
Expand Down
14 changes: 9 additions & 5 deletions sparc/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .api import SparcAPI
from .io import SparcBundle
from .utils import _find_default_sparc, deprecated, h2gpts
from .utils import _find_default_sparc, deprecated, h2gpts, locate_api

# Below are a list of ASE-compatible calculator input parameters that are
# in Angstrom/eV units
Expand All @@ -24,7 +24,6 @@
"nbands",
]


defaultAPI = SparcAPI()


Expand Down Expand Up @@ -54,6 +53,8 @@ def __init__(
command=None,
psp_dir=None,
log="sparc.log",
sparc_json_file=None,
sparc_doc_path=None,
**kwargs,
):
# Initialize the calculator but without restart.
Expand All @@ -73,12 +74,14 @@ def __init__(
if label is None:
label = "SPARC" if restart is None else None

self.validator = locate_api(json_file=sparc_json_file, doc_path=sparc_doc_path)
self.sparc_bundle = SparcBundle(
directory=Path(self.directory),
mode="w",
atoms=self.atoms,
label=label,
psp_dir=psp_dir,
validator=self.validator,
)

# Try restarting from an old calculation and set results
Expand Down Expand Up @@ -432,14 +435,15 @@ def get_fermi_level(self):

def _detect_sparc_version(self):
"""Run a short sparc test to determine which sparc is used"""
# TODO: complete the implementation
command = self._make_command()

return None

def _sanitize_kwargs(self, kwargs):
"""Convert known parameters from"""
# print(kwargs)
# TODO: versioned validator
validator = defaultAPI
validator = self.validator
valid_params = {}
special_params = self.default_params.copy()
# TODO: how about overwriting the default parameters?
Expand Down Expand Up @@ -469,7 +473,7 @@ def _convert_special_params(self, atoms=None):
h <--> gpts <--> FD_GRID, only when None of FD_GRID / ECUT or MESH_SPACING is provided
"""
converted_sparc_params = {}
validator = defaultAPI
validator = self.validator
params = self.special_params.copy()

# xc --> EXCHANGE_CORRELATION
Expand Down
15 changes: 9 additions & 6 deletions sparc/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,22 @@
from ase.calculators.singlepoint import SinglePointDFTCalculator
from ase.units import GPa, Hartree

# various io formatters
from .api import SparcAPI
from .common import psp_dir as default_psp_dir
from .download_data import is_psp_download_complete
from .sparc_parsers.aimd import _read_aimd
from .sparc_parsers.atoms import atoms_to_dict, dict_to_atoms
from .sparc_parsers.geopt import _read_geopt
from .sparc_parsers.inpt import _read_inpt, _write_inpt

# various io formatters
from .sparc_parsers.ion import _read_ion, _write_ion
from .sparc_parsers.out import _read_out
from .sparc_parsers.pseudopotential import copy_psp_file, parse_psp8_header
from .sparc_parsers.static import _read_static
from .utils import deprecated, string2index

# from .sparc_parsers.ion import read_ion, write_ion
defaultAPI = SparcAPI()


class SparcBundle:
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
atoms=None,
label=None,
psp_dir=None,
validator=defaultAPI,
):
self.directory = Path(directory)
self.mode = mode.lower()
Expand All @@ -85,6 +87,7 @@ def __init__(
# Sorting should be consistent across the whole bundle!
self.sorting = None
self.last_image = -1
self.validator = validator

def _find_files(self):
"""Find all files matching '{label}.*'"""
Expand Down Expand Up @@ -187,8 +190,8 @@ def _read_ion_and_inpt(self):
This method should be rarely used
"""
f_ion, f_inpt = self._indir(".ion"), self._indir(".inpt")
ion_data = _read_ion(f_ion)
inpt_data = _read_inpt(f_inpt)
ion_data = _read_ion(f_ion, validator=self.validator)
inpt_data = _read_inpt(f_inpt, validator=self.validator)
merged_data = {**ion_data, **inpt_data}
return dict_to_atoms(merged_data)

Expand Down Expand Up @@ -254,8 +257,8 @@ def _write_ion_and_inpt(
target_fname = copy_psp_file(origin_psp, target_dir)
block["PSEUDO_POT"] = target_fname

_write_ion(self._indir(".ion"), data_dict)
_write_inpt(self._indir(".inpt"), data_dict)
_write_ion(self._indir(".ion"), data_dict, validator=self.validator)
_write_inpt(self._indir(".inpt"), data_dict, validator=self.validator)
return

def read_raw_results(self, include_all_files=False):
Expand Down
3 changes: 0 additions & 3 deletions sparc/sparc_parsers/aimd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
from ..api import SparcAPI
from .utils import strip_comments

# TODO: should allow user to select the api
defaultAPI = SparcAPI()


@reader
def _read_aimd(fileobj):
Expand Down
8 changes: 4 additions & 4 deletions sparc/sparc_parsers/inpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@


@reader
def _read_inpt(fileobj):
def _read_inpt(fileobj, validator=defaultAPI):
contents = fileobj.read()
# label = get_label(fileobj, ".ion")
data, comments = strip_comments(contents)
# We do not read the cell at this time!

# find the index for all atom type lines. They should be at the
# top of their block
inpt_blocks = read_block_input(data, validator=defaultAPI)
inpt_blocks = read_block_input(data, validator=validator)
return {"inpt": {"params": inpt_blocks, "comments": comments}}


@writer
def _write_inpt(fileobj, data_dict):
def _write_inpt(fileobj, data_dict, validator=defaultAPI):
if "inpt" not in data_dict:
raise ValueError("Your dict does not contain inpt section!")

Expand All @@ -45,7 +45,7 @@ def _write_inpt(fileobj, data_dict):
params = inpt_dict["params"]
for key, val in params.items():
# TODO: can we add a multiline argument?
val_string = defaultAPI.convert_value_to_string(key, val)
val_string = validator.convert_value_to_string(key, val)
if (val_string.count("\n") > 0) or (
key
in [
Expand Down
13 changes: 7 additions & 6 deletions sparc/sparc_parsers/ion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@
strip_comments,
)

# TODO: should allow user to select the api
defaultAPI = SparcAPI()


class InvalidSortingComment(ValueError):
def __init__(self, message):
self.message = message


defaultAPI = SparcAPI()


@reader
def _read_ion(fileobj):
def _read_ion(fileobj, validator=defaultAPI):
"""
Read information from the .ion file. Note, this method does not return an atoms object,
but rather return a dict. Thus the label option is not necessary to keep
Expand All @@ -54,7 +54,7 @@ def _read_ion(fileobj):
# find the index for all atom type lines. They should be at the top of their block
atom_type_bounds = [i for i, x in enumerate(data) if "ATOM_TYPE" in x] + [len(data)]
atom_blocks = [
read_block_input(data[start:end], validator=defaultAPI)
read_block_input(data[start:end], validator=validator)
for start, end in zip(atom_type_bounds[:-1], atom_type_bounds[1:])
]

Expand All @@ -71,6 +71,7 @@ def _read_ion(fileobj):
def _write_ion(
fileobj,
data_dict,
validator=defaultAPI,
):
"""
Writes the ion file content from the atom_dict
Expand Down Expand Up @@ -133,7 +134,7 @@ def _write_ion(
if val is None:
continue

val_string = defaultAPI.convert_value_to_string(key, val)
val_string = validator.convert_value_to_string(key, val)
# print(val_string)
# TODO: make sure 1 line is accepted
# TODO: write pads to vector lines
Expand Down
11 changes: 5 additions & 6 deletions sparc/utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
"""Utilities that are loosely related to core sparc functionalities
"""
import os
import io
import tempfile
import os
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Union
from warnings import warn


import numpy as np

from .api import SparcAPI
from .docparser import SPARCDocParser



def deprecated(message):
def decorator(func):
def new_func(*args, **kwargs):
Expand Down Expand Up @@ -127,6 +125,7 @@ def cprint(content, color=None, bold=False, underline=False, **kwargs):
print(output, **kwargs)
return


def locate_api(json_file=None, doc_path=None):
"""Find the default api in the following order
1) User-provided json file path
Expand All @@ -137,10 +136,10 @@ def locate_api(json_file=None, doc_path=None):
if json_file is not None:
api = SparcAPI(json_file)
return api

if doc_path is None:
doc_path = os.environ.get("SPARC_DOC_PATH", None)

if doc_path is not None:
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
Expand Down

0 comments on commit 72bbe20

Please sign in to comment.