Skip to content

Commit

Permalink
Control the number of submitted BaseWorkChains
Browse files Browse the repository at this point in the history
The user can now optionally control the number of BaseWorkChains which are submitted at a time.
This is especially useful in case one deals with multiple atoms and/or multiple q-points.
  • Loading branch information
Timo Reents committed Dec 8, 2023
1 parent de596eb commit 98d6277
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 18 deletions.
21 changes: 21 additions & 0 deletions src/aiida_quantumespresso_hp/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""General utilies."""
from __future__ import annotations

from typing import List


def set_tot_magnetization(input_parameters: dict, tot_magnetization: float) -> bool:
"""Set the total magnetization based on its value and the input parameters.
Expand Down Expand Up @@ -37,3 +39,22 @@ def is_perturb_only_atom(parameters: dict) -> int | None:
break

return match


def distribute_base_wcs(n_atoms: int, n_total: int) -> List[int]:
    """Distribute the number of q-point base workchains to be launched over the number of atoms.

    Each atom gets the integer share of ``n_total``; the remainder is handed out
    one-by-one to the first atoms. Atoms that would receive zero workchains are
    dropped from the result.

    :param n_atoms: The number of atoms.
    :param n_total: The number of base workchains to be launched.
    :return: The number of base workchains to be launched for each atom.
    """
    share, leftover = divmod(n_total, n_atoms)
    counts = [share + 1] * leftover + [share] * (n_atoms - leftover)
    return [count for count in counts if count]
5 changes: 5 additions & 0 deletions src/aiida_quantumespresso_hp/workflows/hp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def define(cls, spec):
'for any non-periodic directions.')
spec.input('parallelize_atoms', valid_type=orm.Bool, default=lambda: orm.Bool(False))
spec.input('parallelize_qpoints', valid_type=orm.Bool, default=lambda: orm.Bool(False))
spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False)
spec.outline(
cls.validate_qpoints,
if_(cls.should_parallelize_atoms)(
Expand Down Expand Up @@ -106,6 +107,8 @@ def get_builder_from_protocol(cls, code, protocol=None, parent_scf_folder=None,
data['parallelize_atoms'] = orm.Bool(inputs['parallelize_atoms'])
if 'parallelize_qpoints' in inputs:
data['parallelize_qpoints'] = orm.Bool(inputs['parallelize_qpoints'])
if 'max_concurrent_base_workchains' in inputs:
data['max_concurrent_base_workchains'] = orm.Int(inputs['max_concurrent_base_workchains'])

builder = cls.get_builder()
builder._data = data # pylint: disable=protected-access
Expand Down Expand Up @@ -163,6 +166,8 @@ def run_parallel_workchain(self):
inputs.clean_workdir = self.inputs.clean_workdir
inputs.parallelize_qpoints = self.inputs.parallelize_qpoints
inputs.hp.qpoints = self.ctx.qpoints
if 'max_concurrent_base_workchains' in self.inputs:
inputs.max_concurrent_base_workchains = self.inputs.max_concurrent_base_workchains
running = self.submit(HpParallelizeAtomsWorkChain, **inputs)
self.report(f'running in parallel, launching HpParallelizeAtomsWorkChain<{running.pk}>')
return ToContext(workchain=running)
Expand Down
33 changes: 24 additions & 9 deletions src/aiida_quantumespresso_hp/workflows/hp/parallelize_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
"""Work chain to launch a Quantum Espresso hp.x calculation parallelizing over the Hubbard atoms."""
from aiida import orm
from aiida.common import AttributeDict
from aiida.engine import WorkChain
from aiida.engine import WorkChain, while_
from aiida.plugins import CalculationFactory, WorkflowFactory

from aiida_quantumespresso_hp.utils.general import distribute_base_wcs

PwCalculation = CalculationFactory('quantumespresso.pw')
HpCalculation = CalculationFactory('quantumespresso.hp')
HpBaseWorkChain = WorkflowFactory('quantumespresso.hp.base')
Expand All @@ -21,12 +23,15 @@ def define(cls, spec):
super().define(spec)
spec.expose_inputs(HpBaseWorkChain, exclude=('only_initialization', 'clean_workdir'))
spec.input('parallelize_qpoints', valid_type=orm.Bool, default=lambda: orm.Bool(False))
spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False)
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')
spec.outline(
cls.run_init,
cls.inspect_init,
cls.run_atoms,
while_(cls.should_run_atoms)(
cls.run_atoms,
),
cls.inspect_atoms,
cls.run_final,
cls.inspect_final,
Expand Down Expand Up @@ -66,18 +71,27 @@ def inspect_init(self):
self.report(f'initialization work chain {workchain} failed with status {workchain.exit_status}, aborting.')
return self.exit_codes.ERROR_INITIALIZATION_WORKCHAIN_FAILED

def run_atoms(self):
"""Run a separate `HpBaseWorkChain` for each of the defined Hubbard atoms."""
workchain = self.ctx.initialization

output_params = workchain.outputs.parameters.get_dict()
hubbard_sites = output_params['hubbard_sites']
self.ctx.hubbard_sites = list(output_params['hubbard_sites'].items())

def should_run_atoms(self):
    """Return whether there are more atoms to run.

    Relies on ``self.ctx.hubbard_sites`` being consumed by ``run_atoms``; the
    loop terminates once the list is empty.
    """
    return bool(self.ctx.hubbard_sites)

def run_atoms(self):
"""Run a separate `HpBaseWorkChain` for each of the defined Hubbard atoms."""
parallelize_qpoints = self.inputs.parallelize_qpoints.value
workflow = HpParallelizeQpointsWorkChain if parallelize_qpoints else HpBaseWorkChain

for site_index, site_kind in hubbard_sites.items():
n_base_parallel = [-1] * len(self.ctx.hubbard_sites)
if 'max_concurrent_base_workchains' in self.inputs:
n_base_parallel = distribute_base_wcs(
len(self.ctx.hubbard_sites), self.inputs.max_concurrent_base_workchains.value
)

self.report(f'{n_base_parallel}')
for n_q in n_base_parallel:
site_index, site_kind = self.ctx.hubbard_sites.pop(0)
do_only_key = f'perturb_only_atom({site_index})'
key = f'atom_{site_index}'

Expand All @@ -87,7 +101,8 @@ def run_atoms(self):
inputs.hp.parameters['INPUTHP'][do_only_key] = True
inputs.hp.parameters = orm.Dict(dict=inputs.hp.parameters)
inputs.metadata.call_link_label = key

if parallelize_qpoints and n_q != -1:
inputs.max_concurrent_base_workchains = orm.Int(n_q)
node = self.submit(workflow, **inputs)
self.to_context(**{key: node})
name = workflow.__name__
Expand Down
21 changes: 14 additions & 7 deletions src/aiida_quantumespresso_hp/workflows/hp/parallelize_qpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Work chain to launch a Quantum Espresso hp.x calculation parallelizing over the Hubbard atoms."""
from aiida import orm
from aiida.common import AttributeDict
from aiida.engine import WorkChain
from aiida.engine import WorkChain, while_
from aiida.plugins import CalculationFactory, WorkflowFactory

from aiida_quantumespresso_hp.utils.general import is_perturb_only_atom
Expand All @@ -29,12 +29,15 @@ def define(cls, spec):
# yapf: disable
super().define(spec)
spec.expose_inputs(HpBaseWorkChain, exclude=('only_initialization', 'clean_workdir'))
spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False)
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')
spec.outline(
cls.run_init,
cls.inspect_init,
cls.run_qpoints,
while_(cls.should_run_qpoints)(
cls.run_qpoints,
),
cls.inspect_qpoints,
cls.run_final,
cls.results
Expand Down Expand Up @@ -75,14 +78,18 @@ def inspect_init(self):
self.report(f'initialization work chain {workchain} failed with status {workchain.exit_status}, aborting.')
return self.exit_codes.ERROR_INITIALIZATION_WORKCHAIN_FAILED

def run_qpoints(self):
"""Run a separate `HpBaseWorkChain` for each of the q points."""
workchain = self.ctx.initialization
self.ctx.qpoints = list(range(workchain.outputs.parameters.dict.number_of_qpoints))

number_of_qpoints = workchain.outputs.parameters.dict.number_of_qpoints
def should_run_qpoints(self):
    """Return whether there are more q points to run.

    Relies on ``self.ctx.qpoints`` being consumed by ``run_qpoints``; the loop
    terminates once the list is empty.
    """
    return bool(self.ctx.qpoints)

for qpoint_index in range(number_of_qpoints):
def run_qpoints(self):
"""Run a separate `HpBaseWorkChain` for each of the q points."""
n_base_parallel = self.inputs.max_concurrent_base_workchains.value if 'max_concurrent_base_workchains' in self.inputs else len(self.ctx.qpoints)

for _ in self.ctx.qpoints[:n_base_parallel]:
qpoint_index = self.ctx.qpoints.pop(0)
key = f'qpoint_{qpoint_index + 1}' # to keep consistency with QE
inputs = AttributeDict(self.exposed_inputs(HpBaseWorkChain))
inputs.clean_workdir = self.inputs.clean_workdir
Expand Down
12 changes: 12 additions & 0 deletions tests/utils/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,15 @@ def test_is_perturb_only_atom():

parameters = {'perturb_only_atom(1)': False}
assert is_perturb_only_atom(parameters) is None


def test_distribute_base_wcs():
    """Test the `distribute_base_wcs` function.

    Covers even and uneven splits, the case with fewer workchains than atoms
    (zero-count atoms must be dropped), and remainder distribution when the
    total exceeds the number of atoms.
    """
    from aiida_quantumespresso_hp.utils.general import distribute_base_wcs

    assert distribute_base_wcs(1, 1) == [1]
    assert distribute_base_wcs(1, 2) == [2]
    assert distribute_base_wcs(2, 1) == [1]
    assert distribute_base_wcs(2, 2) == [1, 1]
    assert distribute_base_wcs(2, 3) == [2, 1]
    assert distribute_base_wcs(7, 5) == [1] * 5
    # Remainder spread one-by-one over the first atoms.
    assert distribute_base_wcs(3, 7) == [3, 2, 2]
    # Exact multiple: every atom gets the same share.
    assert distribute_base_wcs(4, 8) == [2, 2, 2, 2]
    # Quotient is zero: only the first `n_total` atoms get a workchain.
    assert distribute_base_wcs(3, 2) == [1, 1]
6 changes: 4 additions & 2 deletions tests/workflows/hp/test_parallelize_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def test_run_atoms(generate_workchain_atoms, generate_hp_workchain_node):
"""Test `HpParallelizeAtomsWorkChain.run_atoms`."""
process = generate_workchain_atoms()
process.ctx.initialization = generate_hp_workchain_node()

output_params = process.ctx.initialization.outputs.parameters.get_dict()
process.ctx.hubbard_sites = list(output_params['hubbard_sites'].items())
process.run_atoms()

assert 'atom_1' in process.ctx
Expand All @@ -81,7 +82,8 @@ def test_run_atoms_with_qpoints(generate_workchain_atoms, generate_hp_workchain_
"""Test `HpParallelizeAtomsWorkChain.run_atoms` with q point parallelization."""
process = generate_workchain_atoms()
process.ctx.initialization = generate_hp_workchain_node()

output_params = process.ctx.initialization.outputs.parameters.get_dict()
process.ctx.hubbard_sites = list(output_params['hubbard_sites'].items())
process.run_atoms()

# Don't know how to test something like the following
Expand Down
1 change: 1 addition & 0 deletions tests/workflows/hp/test_parallelize_qpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test_run_qpoints(generate_workchain_qpoints, generate_hp_workchain_node):
"""Test `HpParallelizeQpointsWorkChain.run_qpoints`."""
process = generate_workchain_qpoints()
process.ctx.initialization = generate_hp_workchain_node()
process.ctx.qpoints = list(range(process.ctx.initialization.outputs.parameters.dict.number_of_qpoints))

process.run_qpoints()
# to keep consistency with QE we start from 1
Expand Down
2 changes: 2 additions & 0 deletions tests/workflows/protocols/test_hubbard/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ relax:
max_wallclock_seconds: 43200
resources:
num_machines: 1
num_mpiprocs_per_machine: 1
withmpi: true
parameters:
CELL:
Expand Down Expand Up @@ -69,6 +70,7 @@ scf:
max_wallclock_seconds: 43200
resources:
num_machines: 1
num_mpiprocs_per_machine: 1
withmpi: true
parameters:
CONTROL:
Expand Down

0 comments on commit 98d6277

Please sign in to comment.