diff --git a/src/aiida_quantumespresso_hp/utils/general.py b/src/aiida_quantumespresso_hp/utils/general.py index 99c9fd0..8e97b0c 100644 --- a/src/aiida_quantumespresso_hp/utils/general.py +++ b/src/aiida_quantumespresso_hp/utils/general.py @@ -2,6 +2,8 @@ """General utilies.""" from __future__ import annotations +from typing import List + def set_tot_magnetization(input_parameters: dict, tot_magnetization: float) -> bool: """Set the total magnetization based on its value and the input parameters. @@ -37,3 +39,22 @@ def is_perturb_only_atom(parameters: dict) -> int | None: break return match + + +def distribute_base_wcs(n_atoms: int, n_total: int) -> List[int]: + """Distribute the number of q-point base workchains to be launched over the number of atoms. + + :param n_atoms: The number of atoms. + :param n_total: The number of base workchains to be launched. + :return: The number of base workchains to be launched for each atom. + """ + quotient = n_total // n_atoms + remainder = n_total % n_atoms + n_distributed = [quotient] * n_atoms + + for i in range(remainder): + n_distributed[i] += 1 + + n_distributed = [x for x in n_distributed if x != 0] + + return n_distributed diff --git a/src/aiida_quantumespresso_hp/workflows/hp/main.py b/src/aiida_quantumespresso_hp/workflows/hp/main.py index 6b4a689..577b7fb 100644 --- a/src/aiida_quantumespresso_hp/workflows/hp/main.py +++ b/src/aiida_quantumespresso_hp/workflows/hp/main.py @@ -48,6 +48,7 @@ def define(cls, spec): 'for any non-periodic directions.') spec.input('parallelize_atoms', valid_type=orm.Bool, default=lambda: orm.Bool(False)) spec.input('parallelize_qpoints', valid_type=orm.Bool, default=lambda: orm.Bool(False)) + spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False) spec.outline( cls.validate_qpoints, if_(cls.should_parallelize_atoms)( @@ -106,6 +107,8 @@ def get_builder_from_protocol(cls, code, protocol=None, parent_scf_folder=None, data['parallelize_atoms'] = orm.Bool(inputs['parallelize_atoms']) if 'parallelize_qpoints' in inputs: data['parallelize_qpoints'] = orm.Bool(inputs['parallelize_qpoints']) + if 'max_concurrent_base_workchains' in inputs: + data['max_concurrent_base_workchains'] = orm.Int(inputs['max_concurrent_base_workchains']) builder = cls.get_builder() builder._data = data # pylint: disable=protected-access @@ -163,6 +166,8 @@ def run_parallel_workchain(self): inputs.clean_workdir = self.inputs.clean_workdir inputs.parallelize_qpoints = self.inputs.parallelize_qpoints inputs.hp.qpoints = self.ctx.qpoints + if 'max_concurrent_base_workchains' in self.inputs: + inputs.max_concurrent_base_workchains = self.inputs.max_concurrent_base_workchains running = self.submit(HpParallelizeAtomsWorkChain, **inputs) self.report(f'running in parallel, launching HpParallelizeAtomsWorkChain<{running.pk}>') return ToContext(workchain=running) diff --git a/src/aiida_quantumespresso_hp/workflows/hp/parallelize_atoms.py b/src/aiida_quantumespresso_hp/workflows/hp/parallelize_atoms.py index eef2c33..ba2d822 100644 --- a/src/aiida_quantumespresso_hp/workflows/hp/parallelize_atoms.py +++ b/src/aiida_quantumespresso_hp/workflows/hp/parallelize_atoms.py @@ -2,9 +2,11 @@ """Work chain to launch a Quantum Espresso hp.x calculation parallelizing over the Hubbard atoms.""" from aiida import orm from aiida.common import AttributeDict -from aiida.engine import WorkChain +from aiida.engine import WorkChain, while_ from aiida.plugins import CalculationFactory, WorkflowFactory +from aiida_quantumespresso_hp.utils.general import distribute_base_wcs + PwCalculation = CalculationFactory('quantumespresso.pw') HpCalculation = CalculationFactory('quantumespresso.hp') HpBaseWorkChain = WorkflowFactory('quantumespresso.hp.base') @@ -21,12 +23,15 @@ def define(cls, spec): super().define(spec) spec.expose_inputs(HpBaseWorkChain, exclude=('only_initialization', 'clean_workdir')) spec.input('parallelize_qpoints', valid_type=orm.Bool, default=lambda: orm.Bool(False)) + spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False) spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False), help='If `True`, work directories of all called calculation will be cleaned at the end of execution.') spec.outline( cls.run_init, cls.inspect_init, - cls.run_atoms, + while_(cls.should_run_atoms)( + cls.run_atoms, + ), cls.inspect_atoms, cls.run_final, cls.inspect_final, @@ -66,18 +71,27 @@ def inspect_init(self): self.report(f'initialization work chain {workchain} failed with status {workchain.exit_status}, aborting.') return self.exit_codes.ERROR_INITIALIZATION_WORKCHAIN_FAILED - def run_atoms(self): - """Run a separate `HpBaseWorkChain` for each of the defined Hubbard atoms.""" - workchain = self.ctx.initialization - output_params = workchain.outputs.parameters.get_dict() - hubbard_sites = output_params['hubbard_sites'] + self.ctx.hubbard_sites = list(output_params['hubbard_sites'].items()) + def should_run_atoms(self): + """Return whether there are more atoms to run.""" + return len(self.ctx.hubbard_sites) > 0 + + def run_atoms(self): + """Run a separate `HpBaseWorkChain` for each of the defined Hubbard atoms.""" parallelize_qpoints = self.inputs.parallelize_qpoints.value workflow = HpParallelizeQpointsWorkChain if parallelize_qpoints else HpBaseWorkChain - for site_index, site_kind in hubbard_sites.items(): + n_base_parallel = [-1] * len(self.ctx.hubbard_sites) + if 'max_concurrent_base_workchains' in self.inputs: + n_base_parallel = distribute_base_wcs( + len(self.ctx.hubbard_sites), self.inputs.max_concurrent_base_workchains.value + ) + self.report(f'{n_base_parallel}') + for n_q in n_base_parallel: + site_index, site_kind = self.ctx.hubbard_sites.pop(0) do_only_key = f'perturb_only_atom({site_index})' key = f'atom_{site_index}' @@ -87,7 +101,8 @@ def run_atoms(self): inputs.hp.parameters['INPUTHP'][do_only_key] = True inputs.hp.parameters = orm.Dict(dict=inputs.hp.parameters) inputs.metadata.call_link_label = key - + if parallelize_qpoints and n_q != -1: + inputs.max_concurrent_base_workchains = orm.Int(n_q) node = self.submit(workflow, **inputs) self.to_context(**{key: node}) name = workflow.__name__ diff --git a/src/aiida_quantumespresso_hp/workflows/hp/parallelize_qpoints.py b/src/aiida_quantumespresso_hp/workflows/hp/parallelize_qpoints.py index fd00ca9..993ed3b 100644 --- a/src/aiida_quantumespresso_hp/workflows/hp/parallelize_qpoints.py +++ b/src/aiida_quantumespresso_hp/workflows/hp/parallelize_qpoints.py @@ -2,7 +2,7 @@ """Work chain to launch a Quantum Espresso hp.x calculation parallelizing over the Hubbard atoms.""" from aiida import orm from aiida.common import AttributeDict -from aiida.engine import WorkChain +from aiida.engine import WorkChain, while_ from aiida.plugins import CalculationFactory, WorkflowFactory from aiida_quantumespresso_hp.utils.general import is_perturb_only_atom @@ -29,12 +29,15 @@ def define(cls, spec): # yapf: disable super().define(spec) spec.expose_inputs(HpBaseWorkChain, exclude=('only_initialization', 'clean_workdir')) + spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False) spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False), help='If `True`, work directories of all called calculation will be cleaned at the end of execution.') spec.outline( cls.run_init, cls.inspect_init, - cls.run_qpoints, + while_(cls.should_run_qpoints)( + cls.run_qpoints, + ), cls.inspect_qpoints, cls.run_final, cls.results @@ -75,14 +78,18 @@ def inspect_init(self): self.report(f'initialization work chain {workchain} failed with status {workchain.exit_status}, aborting.') return self.exit_codes.ERROR_INITIALIZATION_WORKCHAIN_FAILED - def run_qpoints(self): - """Run a separate `HpBaseWorkChain` for each of the q points.""" - workchain = self.ctx.initialization + self.ctx.qpoints = list(range(workchain.outputs.parameters.dict.number_of_qpoints)) - number_of_qpoints = workchain.outputs.parameters.dict.number_of_qpoints + def should_run_qpoints(self): + """Return whether there are more q points to run.""" + return len(self.ctx.qpoints) > 0 - for qpoint_index in range(number_of_qpoints): + def run_qpoints(self): + """Run a separate `HpBaseWorkChain` for each of the q points.""" + n_base_parallel = self.inputs.max_concurrent_base_workchains.value if 'max_concurrent_base_workchains' in self.inputs else len(self.ctx.qpoints) + for _ in self.ctx.qpoints[:n_base_parallel]: + qpoint_index = self.ctx.qpoints.pop(0) key = f'qpoint_{qpoint_index + 1}' # to keep consistency with QE inputs = AttributeDict(self.exposed_inputs(HpBaseWorkChain)) inputs.clean_workdir = self.inputs.clean_workdir diff --git a/tests/utils/test_general.py b/tests/utils/test_general.py index c7e3567..865e307 100644 --- a/tests/utils/test_general.py +++ b/tests/utils/test_general.py @@ -31,3 +31,15 @@ def test_is_perturb_only_atom(): parameters = {'perturb_only_atom(1)': False} assert is_perturb_only_atom(parameters) is None + + +def test_distribute_base_wcs(): + """Test the `distribute_base_wcs` function.""" + from aiida_quantumespresso_hp.utils.general import distribute_base_wcs + + assert distribute_base_wcs(1, 1) == [1] + assert distribute_base_wcs(1, 2) == [2] + assert distribute_base_wcs(2, 1) == [1] + assert distribute_base_wcs(2, 2) == [1, 1] + assert distribute_base_wcs(2, 3) == [2, 1] + assert distribute_base_wcs(7, 5) == [1] * 5 diff --git a/tests/workflows/hp/test_parallelize_atoms.py b/tests/workflows/hp/test_parallelize_atoms.py index e19e328..895b552 100644 --- a/tests/workflows/hp/test_parallelize_atoms.py +++ b/tests/workflows/hp/test_parallelize_atoms.py @@ -69,7 +69,8 @@ def test_run_atoms(generate_workchain_atoms, generate_hp_workchain_node): """Test `HpParallelizeAtomsWorkChain.run_atoms`.""" process = generate_workchain_atoms() process.ctx.initialization = generate_hp_workchain_node() - + output_params = process.ctx.initialization.outputs.parameters.get_dict() + process.ctx.hubbard_sites = list(output_params['hubbard_sites'].items()) process.run_atoms() assert 'atom_1' in process.ctx @@ -81,7 +82,8 @@ def test_run_atoms_with_qpoints(generate_workchain_atoms, generate_hp_workchain_ """Test `HpParallelizeAtomsWorkChain.run_atoms` with q point parallelization.""" process = generate_workchain_atoms() process.ctx.initialization = generate_hp_workchain_node() - + output_params = process.ctx.initialization.outputs.parameters.get_dict() + process.ctx.hubbard_sites = list(output_params['hubbard_sites'].items()) process.run_atoms() # Don't know how to test something like the following diff --git a/tests/workflows/hp/test_parallelize_qpoints.py b/tests/workflows/hp/test_parallelize_qpoints.py index 00910d2..d50319b 100644 --- a/tests/workflows/hp/test_parallelize_qpoints.py +++ b/tests/workflows/hp/test_parallelize_qpoints.py @@ -73,6 +73,7 @@ def test_run_qpoints(generate_workchain_qpoints, generate_hp_workchain_node): """Test `HpParallelizeQpointsWorkChain.run_qpoints`.""" process = generate_workchain_qpoints() process.ctx.initialization = generate_hp_workchain_node() + process.ctx.qpoints = list(range(process.ctx.initialization.outputs.parameters.dict.number_of_qpoints)) process.run_qpoints() # to keep consistency with QE we start from 1 diff --git a/tests/workflows/protocols/test_hubbard/test_default.yml b/tests/workflows/protocols/test_hubbard/test_default.yml index ee18383..fb2575e 100644 --- a/tests/workflows/protocols/test_hubbard/test_default.yml +++ b/tests/workflows/protocols/test_hubbard/test_default.yml @@ -30,6 +30,7 @@ relax: max_wallclock_seconds: 43200 resources: num_machines: 1 + num_mpiprocs_per_machine: 1 withmpi: true parameters: CELL: @@ -69,6 +70,7 @@ scf: max_wallclock_seconds: 43200 resources: num_machines: 1 + num_mpiprocs_per_machine: 1 withmpi: true parameters: CONTROL: