Skip to content

Commit

Permalink
Black pass
Browse files Browse the repository at this point in the history
  • Loading branch information
cdavro committed Oct 3, 2024
1 parent 862f971 commit 31cf9ab
Show file tree
Hide file tree
Showing 52 changed files with 5,719 additions and 1,344 deletions.
42 changes: 33 additions & 9 deletions arcann_training/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,19 @@
parser = argparse.ArgumentParser(description="Deepmd iterative program suite")
parser.add_argument("step_name", type=str, help="Step name")
parser.add_argument("phase_name", type=str, help="Phase name")
parser.add_argument("-v", "--verbose", type=int, default=0, help="verbosity, 0 (default) or 1 (debug)")
parser.add_argument("-i", "--input", type=str, default="input.json", help="name of the input file (with ext)")
parser.add_argument("-c", "--cluster", type=str, default=None, help="name of the fake cluster")
parser.add_argument(
"-v", "--verbose", type=int, default=0, help="verbosity, 0 (default) or 1 (debug)"
)
parser.add_argument(
"-i",
"--input",
type=str,
default="input.json",
help="name of the input file (with ext)",
)
parser.add_argument(
"-c", "--cluster", type=str, default=None, help="name of the fake cluster"
)

if __name__ == "__main__":
args = parser.parse_args()
Expand Down Expand Up @@ -58,15 +68,21 @@
arcann_logger.info(f"-" * 88)
arcann_logger.info(f"-" * 88)
arcann_logger.info(f"ARCANN TRAINING PROGRAM SUITE")
arcann_logger.info(f"Launching: {step_name.capitalize()} - {phase_name.capitalize()}")
arcann_logger.info(
f"Launching: {step_name.capitalize()} - {phase_name.capitalize()}"
)
arcann_logger.info(f"-" * 88)
arcann_logger.info(f"-" * 88)

steps = ["initialization", "training", "exploration", "labeling", "test"]
valid_phases = {}
for step in steps:
step_path = deepmd_iterative_path / step
files = [f.stem for f in step_path.iterdir() if f.is_file() and f.suffix == ".py" and f.stem not in ["__init__", "utils"]]
files = [
f.stem
for f in step_path.iterdir()
if f.is_file() and f.suffix == ".py" and f.stem not in ["__init__", "utils"]
]
valid_phases[step] = files

if step_name not in steps:
Expand All @@ -76,7 +92,9 @@
exit(exit_code)

elif phase_name not in valid_phases.get(step_name, []):
arcann_logger.error(f"Invalid phase for step {step_name}. Valid phases are: {valid_phases[step_name]}")
arcann_logger.error(
f"Invalid phase for step {step_name}. Valid phases are: {valid_phases[step_name]}"
)
arcann_logger.error(f"Aborting...")
exit_code = 1
exit(exit_code)
Expand All @@ -85,7 +103,9 @@
else:
try:
submodule = importlib.import_module(submodule_name)
exit_code = submodule.main(step_name, phase_name, deepmd_iterative_path, fake_cluster, input_fn)
exit_code = submodule.main(
step_name, phase_name, deepmd_iterative_path, fake_cluster, input_fn
)
del submodule, submodule_name
except Exception as e:
exit_code = 1
Expand All @@ -96,9 +116,13 @@
arcann_logger.info(f"-" * 88)
arcann_logger.info(f"-" * 88)
if exit_code == 0:
arcann_logger.info(f"{step_name.capitalize()} - {phase_name.capitalize()} finished")
arcann_logger.info(
f"{step_name.capitalize()} - {phase_name.capitalize()} finished"
)
else:
arcann_logger.error(f"{step_name.capitalize()} - {phase_name.capitalize()} encountered an error")
arcann_logger.error(
f"{step_name.capitalize()} - {phase_name.capitalize()} encountered an error"
)
arcann_logger.info(f"-" * 88)
arcann_logger.info(f"-" * 88)

Expand Down
22 changes: 18 additions & 4 deletions arcann_training/common/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def check_atomsk(atomsk_path: str = None) -> str:
if Path(atomsk_path).is_file():
return str(Path(atomsk_path).resolve())
else:
logger.warning(f"Atomsk path '{atomsk_path}' is invalid. Deleting the atomsk_path variable. Checking environment variable and system path...")
logger.warning(
f"Atomsk path '{atomsk_path}' is invalid. Deleting the atomsk_path variable. Checking environment variable and system path..."
)
del atomsk_path

# Check if ATOMSK_PATH is defined and is valid
Expand Down Expand Up @@ -127,7 +129,9 @@ def check_vmd(vmd_path: str = None) -> str:
if vmd_path is not None and vmd_path != "" and Path(vmd_path).is_file():
return str(Path(vmd_path).resolve())
else:
logger.warning(f"VMD path '{vmd_path}' is invalid. Deleting the vmd_path variable. Checking environment variable and system path...")
logger.warning(
f"VMD path '{vmd_path}' is invalid. Deleting the vmd_path variable. Checking environment variable and system path..."
)
del vmd_path

# Check if VMD_PATH is defined and is valid
Expand Down Expand Up @@ -201,7 +205,12 @@ def check_dcd_is_valid(dcd_path: Path, vmd_bin: Path) -> bool:
quit
"""
# Run VMD script from command line
result = subprocess.run([vmd_bin, "-dispdev", "text", "-e", "/dev/stdin"], input=vmd_script, text=True, capture_output=True)
result = subprocess.run(
[vmd_bin, "-dispdev", "text", "-e", "/dev/stdin"],
input=vmd_script,
text=True,
capture_output=True,
)

# Check if the output contains "Unable to load file "
if "Unable to load file " in result.stdout:
Expand Down Expand Up @@ -239,7 +248,12 @@ def check_nc_is_valid(nc_path: Path, vmd_bin: Path) -> bool:
quit
"""
# Run VMD script from command line
result = subprocess.run([vmd_bin, "-dispdev", "text", "-e", "/dev/stdin"], input=vmd_script, text=True, capture_output=True)
result = subprocess.run(
[vmd_bin, "-dispdev", "text", "-e", "/dev/stdin"],
input=vmd_script,
text=True,
capture_output=True,
)

# Check if the output contains "Unable to load file "
if "Unable to load file " in result.stdout:
Expand Down
8 changes: 6 additions & 2 deletions arcann_training/common/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def change_directory(directory_path: Path) -> None:

# Unittested
@catch_errors_decorator
def check_directory(directory_path: Path, abort_on_error: bool = True, error_msg: str = "default") -> None:
def check_directory(
directory_path: Path, abort_on_error: bool = True, error_msg: str = "default"
) -> None:
"""
Check if the given directory exists and logs a warning or raises an error if it does not.
Expand Down Expand Up @@ -161,7 +163,9 @@ def check_file_existence(
if expected_existence:
message = f"File not found: `{file_path.name}` not in `{file_path.parent}`"
if abort_on_error:
raise FileNotFoundError(message if error_msg == "default" else error_msg)
raise FileNotFoundError(
message if error_msg == "default" else error_msg
)
else:
logger.warning(message if error_msg == "default" else error_msg)
else:
Expand Down
51 changes: 41 additions & 10 deletions arcann_training/common/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def add_key_value_to_dict(dictionary: Dict, key: str, value: Any) -> None:

# Unittested
@catch_errors_decorator
def get_key_in_dict(key: str, input_json: Dict, previous_json: Dict, default_json: Dict) -> Any:
def get_key_in_dict(
key: str, input_json: Dict, previous_json: Dict, default_json: Dict
) -> Any:
"""
Get the value of the key from input JSON, previous JSON or default JSON, and validate its type.
Expand All @@ -119,7 +121,9 @@ def get_key_in_dict(key: str, input_json: Dict, previous_json: Dict, default_jso

# Check if the key is present in any of the JSON, and set the value accordingly.
if key in input_json:
if ( input_json[key] == "default" or input_json[key] == None ) and key in default_json:
if (
input_json[key] == "default" or input_json[key] == None
) and key in default_json:
value = default_json[key]
else:
value = input_json[key]
Expand All @@ -143,7 +147,12 @@ def get_key_in_dict(key: str, input_json: Dict, previous_json: Dict, default_jso

# Unittested
@catch_errors_decorator
def backup_and_overwrite_json_file(json_dict: Dict, file_path: Path, enable_logging: bool = True, read_only: bool = False) -> None:
def backup_and_overwrite_json_file(
json_dict: Dict,
file_path: Path,
enable_logging: bool = True,
read_only: bool = False,
) -> None:
"""
Write a dictionary to a JSON file after creating a backup of the existing file.
Expand Down Expand Up @@ -230,14 +239,18 @@ def load_default_json_file(file_path: Path) -> Dict:
return json.loads(file_content)
else:
# If the file cannot be found, return an empty dictionary and log a warning
logging.warning(f"Default file '{file_path.name}' not found in '{file_path.parent}'.")
logging.warning(
f"Default file '{file_path.name}' not found in '{file_path.parent}'."
)
logging.warning(f"Check your installation")
return {}


# Unittested
@catch_errors_decorator
def load_json_file(file_path: Path, abort_on_error: bool = True, enable_logging: bool = True) -> Dict:
def load_json_file(
file_path: Path, abort_on_error: bool = True, enable_logging: bool = True
) -> Dict:
"""
Load a JSON file from the given file path and return its contents as a dictionary.
Expand Down Expand Up @@ -287,13 +300,20 @@ def load_json_file(file_path: Path, abort_on_error: bool = True, enable_logging:
else:
# If logging is enabled, log information about the creation of the empty dictionary
if enable_logging:
logging.info(f"Creating an empty dictionary: '{file_path.name}' in '{file_path.parent}'.")
logging.info(
f"Creating an empty dictionary: '{file_path.name}' in '{file_path.parent}'."
)
return {}


# Unittested
@catch_errors_decorator
def write_json_file(json_dict: Dict, file_path: Path, enable_logging: bool = True, read_only: bool = False) -> None:
def write_json_file(
json_dict: Dict,
file_path: Path,
enable_logging: bool = True,
read_only: bool = False,
) -> None:
"""
Writes a dictionary to a JSON file, optionally logging the action and setting the file to read-only.
Expand Down Expand Up @@ -341,7 +361,11 @@ def write_json_file(json_dict: Dict, file_path: Path, enable_logging: bool = Tru

# Collapse arrays/lists in the JSON to a single line
pattern = r"(\[)(\s*([^\]]*)\s*)(\])"
replacement = lambda m: m.group(1) + re.sub(r"\s+", " ", m.group(3)).rstrip() + m.group(4)
replacement = (
lambda m: m.group(1)
+ re.sub(r"\s+", " ", m.group(3)).rstrip()
+ m.group(4)
)
json_str = re.sub(pattern, replacement, json_str)
json_str = re.sub(r"\],\s+\[", "], [", json_str)
json_str = re.sub(r"\]\s+\]", "]]", json_str)
Expand Down Expand Up @@ -396,14 +420,21 @@ def convert_control_to_input(control_json: Dict, main_json: Dict) -> Dict:
# Iterate over keys in main_json["systems_auto"]
for system_auto in main_json.get("systems_auto", {}):
if system_auto in control_json.get("systems_auto", {}):
input_json[key].append(control_json["systems_auto"][system_auto].get(key, None))
input_json[key].append(
control_json["systems_auto"][system_auto].get(key, None)
)

return input_json


# TODO: Add tests for this function
@catch_errors_decorator
def replace_values_by_key_name(d: Union[Dict[str, Any], List[Any]], key_name: str, new_value: Any, parent_key: str = "") -> None:
def replace_values_by_key_name(
d: Union[Dict[str, Any], List[Any]],
key_name: str,
new_value: Any,
parent_key: str = "",
) -> None:
"""
Recursively finds and replaces the values of all keys (and subkeys) with the specified name within a dictionary or list of dictionaries.
Expand Down
9 changes: 8 additions & 1 deletion arcann_training/common/lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,14 @@ def read_lammps_data(
if num_atom_types == None:
error_msg = "The number of atom types was not found."
raise ValueError(error_msg)
if xlo == None or xhi == None or ylo == None or yhi == None or zlo == None or zhi == None:
if (
xlo == None
or xhi == None
or ylo == None
or yhi == None
or zlo == None
or zhi == None
):
error_msg = f"Invalid box coordinates."
raise ValueError(error_msg)
if len(masses) == 0:
Expand Down
20 changes: 15 additions & 5 deletions arcann_training/common/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@

# Unittested
@catch_errors_decorator
def exclude_substring_from_string_list(input_list: List[str], substring: str) -> List[str]:
def exclude_substring_from_string_list(
input_list: List[str], substring: str
) -> List[str]:
"""
Remove all strings containing a given substring from a list of strings.
Expand Down Expand Up @@ -74,7 +76,9 @@ def exclude_substring_from_string_list(input_list: List[str], substring: str) ->

# Unittested
@catch_errors_decorator
def replace_substring_in_string_list(input_list: List[str], substring_in: str, substring_out: str) -> List[str]:
def replace_substring_in_string_list(
input_list: List[str], substring_in: str, substring_out: str
) -> List[str]:
"""
Replace a specified substring with a new substring in each string of a list.
Expand Down Expand Up @@ -110,13 +114,17 @@ def replace_substring_in_string_list(input_list: List[str], substring_in: str, s
# if not substring_out:
# raise ValueError("Invalid input. substring_out must be a non-empty string.")

output_list = [string.replace(substring_in, substring_out).strip() for string in input_list]
output_list = [
string.replace(substring_in, substring_out).strip() for string in input_list
]
return output_list


# Unittested
@catch_errors_decorator
def string_list_to_textfile(file_path: Path, string_list: List[str], read_only: bool = False) -> None:
def string_list_to_textfile(
file_path: Path, string_list: List[str], read_only: bool = False
) -> None:
"""
Write a list of strings to a text file.
Expand Down Expand Up @@ -155,7 +163,9 @@ def string_list_to_textfile(file_path: Path, string_list: List[str], read_only:
error_msg = f"'{file_path}' must be a '{type(Path(''))}'."
raise TypeError(error_msg)

if not isinstance(string_list, list) or not all(isinstance(s, str) for s in string_list):
if not isinstance(string_list, list) or not all(
isinstance(s, str) for s in string_list
):
error_msg = f"'{string_list}' must be a '{type([])}' of '{type('')}.'"
raise TypeError(error_msg)

Expand Down
11 changes: 9 additions & 2 deletions arcann_training/common/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def setup_logging(verbose: int = 0) -> Dict:
},
},
"handlers": {
"console": {"class": "logging.StreamHandler", "level": "INFO", "formatter": "simple", "stream": "ext://sys.stdout"}, # Use standard output (or sys.stderr)
"console": {
"class": "logging.StreamHandler",
"level": "INFO",
"formatter": "simple",
"stream": "ext://sys.stdout",
}, # Use standard output (or sys.stderr)
"file": {
"class": "logging.FileHandler",
"level": "INFO",
Expand All @@ -50,7 +55,9 @@ def setup_logging(verbose: int = 0) -> Dict:
"mode": "a", # Append mode
},
},
"loggers": {"": {"handlers": ["console", "file"], "level": "INFO", "propagate": True}},
"loggers": {
"": {"handlers": ["console", "file"], "level": "INFO", "propagate": True}
},
}

if verbose >= 1:
Expand Down
Loading

0 comments on commit 31cf9ab

Please sign in to comment.