Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 26, 2024
1 parent 8197e72 commit c06f616
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 129 deletions.
23 changes: 18 additions & 5 deletions src/aiida_sssp_workflow/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aiida.cmdline.params import options, types
from aiida.cmdline.utils import echo
from aiida.engine import ProcessBuilder, run_get_node, submit
from aiida.plugins import DataFactory, WorkflowFactory
from aiida.plugins import WorkflowFactory

from aiida_pseudo.data.pseudo.upf import UpfData
from aiida_sssp_workflow.cli import cmd_root
Expand All @@ -25,6 +25,7 @@

VerificationWorkChain = WorkflowFactory("sssp_workflow.verification")


def guess_properties_list(property: list) -> Tuple[List[str], str]:
# if the property is not specified, use the default list with all properties calculated.
# otherwise, use the specified properties.
Expand All @@ -43,21 +44,27 @@ def guess_properties_list(property: list) -> Tuple[List[str], str]:

return properties_list, extra_desc


def guess_is_convergence(properties_list: list) -> bool:
"""Check if it is a convergence test"""

return any([c for c in properties_list if c.startswith("convergence")])


def guess_is_full_convergence(properties_list: list) -> bool:
"""Check if all properties are run for convergence test"""

return len([c for c in properties_list if c.startswith("convergence")]) == len(DEFAULT_CONVERGENCE_PROPERTIES_LIST)
return len([c for c in properties_list if c.startswith("convergence")]) == len(
DEFAULT_CONVERGENCE_PROPERTIES_LIST
)


def guess_is_measure(properties_list: list) -> bool:
"""Check if it is a measure test"""

return any([c for c in properties_list if c.startswith("measure")])


def guess_is_ph(properties_list: list) -> bool:
"""Check if it has a measure test"""

Expand Down Expand Up @@ -175,15 +182,19 @@ def launch(
is_ph = guess_is_ph(properties_list)

if is_ph and not ph_code:
echo.echo_critical("ph_code must be provided since we run on it for phonon frequencies.")
echo.echo_critical(
"ph_code must be provided since we run on it for phonon frequencies."
)

if is_convergence and len(configuration) > 1:
echo.echo_critical(
"Only one configuration is allowed for convergence workflow."
)

if is_measure and not is_full_convergence:
echo.echo_warning("Full convergence tests are not run, so we use maximum cutoffs for transferability verification.")
echo.echo_warning(
"Full convergence tests are not run, so we use maximum cutoffs for transferability verification."
)

# Load the curent AiiDA profile and log to user
_profile = aiida.load_profile()
Expand Down Expand Up @@ -211,7 +222,9 @@ def launch(
clean_workdir=clean_workdir,
)

builder.metadata.label = f"({protocol} at {pw_code.computer.label} - {conf_label}) {pseudo.stem}"
builder.metadata.label = (
f"({protocol} at {pw_code.computer.label} - {conf_label}) {pseudo.stem}"
)
builder.metadata.description = f"""Calculation is run on protocol: {protocol}; on {pw_code.computer.label}; on configuration {conf_label}; on pseudo {pseudo.stem}."""

builder.pw_code = pw_code
Expand Down
1 change: 0 additions & 1 deletion src/aiida_sssp_workflow/protocol/criteria.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@ standard:
bounds: [0.0, 20] # when error eta_c < 20 meV
eps: 1.0e-3
unit: meV/atom

1 change: 1 addition & 0 deletions src/aiida_sssp_workflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_default_mpi_options(
"withmpi": with_mpi,
}


def serialize_data(data):
from aiida.orm import (
AbstractCode,
Expand Down
23 changes: 12 additions & 11 deletions src/aiida_sssp_workflow/utils/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiida_sssp_workflow.utils.pseudo import DualType, get_dual_type


def get_protocol(category: str, name: str | None=None):
def get_protocol(category: str, name: str | None = None):
"""Load and read protocol from faml file to a verbose dict
if name not set, return whole protocol."""
import_path = resources.path("aiida_sssp_workflow.protocol", f"{category}.yml")
Expand All @@ -17,24 +17,25 @@ def get_protocol(category: str, name: str | None=None):
else:
return protocol_dict

def generate_cutoff_list(protocol_name: str, element: str, pp_type: str) -> List[Tuple[int, int]]:
"""From the control protocol name, get the cutoff list
"""

def generate_cutoff_list(
protocol_name: str, element: str, pp_type: str
) -> List[Tuple[int, int]]:
"""From the control protocol name, get the cutoff list"""
match get_dual_type(pp_type, element):
case DualType.NC:
dual_type = 'nc_dual_scan'
dual_type = "nc_dual_scan"
case DualType.AUGLOW:
dual_type = 'nonnc_dual_scan'
dual_type = "nonnc_dual_scan"
case DualType.AUGHIGH:
dual_type = 'nonnc_high_dual_scan'
dual_type = "nonnc_high_dual_scan"

dual_scan_list = get_protocol('control', protocol_name)[dual_type]
dual_scan_list = get_protocol("control", protocol_name)[dual_type]
if len(dual_scan_list) > 0:
max_dual = int(max(dual_scan_list))
else:
max_dual = 8

ecutwfc_list = get_protocol('control', protocol_name)['wfc_scan']

return [(e, e*max_dual) for e in ecutwfc_list]
ecutwfc_list = get_protocol("control", protocol_name)["wfc_scan"]

return [(e, e * max_dual) for e in ecutwfc_list]
15 changes: 9 additions & 6 deletions src/aiida_sssp_workflow/utils/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,21 @@ class PseudoInfo(BaseModel):
# source_lib: str
# ...


class DualType(Enum):
NC = "nc"
AUGLOW = "charge augmentation low"
AUGHIGH = "charge augmentation high"


def get_dual_type(pp_type: str, element: str) -> DualType:
if element in HIGH_DUAL_ELEMENTS and pp_type != 'nc':
return DualType.AUGHIGH
elif pp_type == 'nc':
return DualType.NC
else:
return DualType.AUGLOW
if element in HIGH_DUAL_ELEMENTS and pp_type != "nc":
return DualType.AUGHIGH
elif pp_type == "nc":
return DualType.NC
else:
return DualType.AUGLOW


def extract_pseudo_info(pseudo_text: str) -> PseudoInfo:
"""Giving a pseudo, extract the pseudo info and return as a `PseudoInfo` object"""
Expand Down
1 change: 0 additions & 1 deletion src/aiida_sssp_workflow/workflows/convergence/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ def get_builder(
if ret := is_valid_cutoff_list(cutoff_list):
raise ValueError(ret)


builder.cutoff_list = orm.List(list=cutoff_list)
builder.clean_workdir = orm.Bool(clean_workdir)

Expand Down
15 changes: 7 additions & 8 deletions src/aiida_sssp_workflow/workflows/transferability/eos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Workchain to calculate delta factor of specific psp"""

from typing import Tuple
from pathlib import Path

Expand Down Expand Up @@ -172,9 +173,7 @@ def get_pseudos(self, configuration) -> dict:

def _setup_protocol(self):
"""unzip and parse protocol parameters to context"""
protocol = get_protocol(
category="eos", name=self.inputs.protocol.value
)
protocol = get_protocol(category="eos", name=self.inputs.protocol.value)
self.ctx.protocol = protocol

@property
Expand Down Expand Up @@ -205,12 +204,12 @@ def get_builder(
) -> ProcessBuilder:
"""Return a builder to run this EOS convergence workchain"""
builder = super().get_builder(
code,
code,
pseudo,
protocol,
cutoffs,
mpi_options,
parallelization,
protocol,
cutoffs,
mpi_options,
parallelization,
clean_workdir,
)

Expand Down
Loading

0 comments on commit c06f616

Please sign in to comment.