Simplify (#12)
Technici4n authored Nov 8, 2024
1 parent 7b7bf60 commit f3e6e68
Showing 3 changed files with 63 additions and 83 deletions.
109 changes: 50 additions & 59 deletions src/aiida_dftk/calculations.py
@@ -4,6 +4,7 @@
import os
import json
import typing as ty
from pathlib import Path

from aiida import orm
from aiida.common import datastructures, exceptions
@@ -15,11 +16,9 @@
class DftkCalculation(CalcJob):
"""`CalcJob` implementation for DFTK."""

_DEFAULT_PREFIX = 'DFTK'
_DEFAULT_INPUT_EXTENSION = 'json'
_DEFAULT_STDOUT_EXTENSION = 'txt'
_DEFAULT_SCFRES_SUMMARY_NAME = 'self_consistent_field.json'
_SUPPORTED_POSTSCF = ['compute_forces_cart', 'compute_stresses_cart','compute_bands']
SCFRES_SUMMARY_NAME = 'self_consistent_field.json'
# TODO: don't limit postscf
_SUPPORTED_POSTSCF = ['compute_forces_cart', 'compute_stresses_cart', 'compute_bands']
_PSEUDO_SUBFOLDER = './pseudo/'
_MIN_OUTPUT_BUFFER_TIME = 60
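
Aside: dropping the leading underscore marks the summary-file name as public API, so the parser can read it from this class instead of keeping its own copy. The pattern in miniature (a sketch, with class bodies trimmed to the relevant lines):

class DftkCalculation:
    SCFRES_SUMMARY_NAME = 'self_consistent_field.json'  # public: shared with the parser

class DftkParser:
    def summary_file(self) -> str:
        # Single source of truth; no duplicate constant to drift out of sync.
        return DftkCalculation.SCFRES_SUMMARY_NAME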

@@ -37,24 +36,25 @@ def define(cls, spec):
"""Define the process specification."""
super().define(spec)
# Inputs
spec.input('metadata.options.prefix', valid_type=str, default=cls._DEFAULT_PREFIX)
spec.input('metadata.options.stdout_extension', valid_type=str, default=cls._DEFAULT_STDOUT_EXTENSION)
spec.input('metadata.options.withmpi', valid_type=bool, default=True)
spec.input('metadata.options.max_wallclock_seconds', valid_type=int, default=1800)

spec.input('structure', valid_type=orm.StructureData, help='structure')
spec.input_namespace('pseudos', valid_type=UpfData, help='The pseudopotentials.', dynamic=True)
spec.input('kpoints', valid_type=orm.KpointsData, help='kpoint mesh or kpoint path')
spec.input('parameters', valid_type=orm.Dict, help='input parameters')
spec.input('settings', valid_type=orm.Dict, required=False, help='Various special settings.')
spec.input('parent_folder', valid_type=orm.RemoteData, required=False, help='A remote folder used for restarts.')

options = spec.inputs['metadata']['options']

options['parser_name'].default = 'dftk'
options['input_filename'].default = f'run_dftk.json'
options['max_wallclock_seconds'].default = 1800

# TODO: Why is this here?
options['resources'].default = {'num_machines': 1, 'num_mpiprocs_per_machine': 1}
options['input_filename'].default = f'{cls._DEFAULT_PREFIX}.{cls._DEFAULT_INPUT_EXTENSION}'
options['withmpi'].default = True

# Exit codes
# TODO: Log file should be removed in favor of using stdout. Needs a change in AiidaDFTK.jl.
# TODO: Code 100 is already used in the super class!
spec.exit_code(100, 'ERROR_MISSING_LOG_FILE', message='The output file containing DFTK logs is missing.')
spec.exit_code(101, 'ERROR_MISSING_SCFRES_FILE', message='The output file containing SCF results is missing.')
spec.exit_code(102, 'ERROR_MISSING_FORCES_FILE', message='The output file containing forces is missing.')
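
On the code-100 TODO: aiida-core's CalcJob base class already reserves status 100 for ERROR_NO_RETRIEVED_FOLDER, so the definition here shadows it. One way to audit for such clashes (a sketch; assumes aiida-core and this plugin are installed, and building the spec needs no configured profile):

from aiida_dftk.calculations import DftkCalculation

for label, exit_code in sorted(DftkCalculation.spec().exit_codes.items(), key=lambda kv: kv[1].status):
    print(f'{exit_code.status:4d}  {label}')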
@@ -67,7 +67,9 @@ def define(cls, spec):

# Outputs
spec.output('output_parameters', valid_type=orm.Dict, help='output parameters')
# TODO: doesn't seem to be used?
spec.output('output_structure', valid_type=orm.Dict, required=False, help='output structure')
# TODO: doesn't seem to be used?
spec.output(
'output_kpoints', valid_type=orm.KpointsData, required=False, help='kpoints array, if generated by DFTK'
)
@@ -82,21 +84,22 @@

spec.default_output_node = 'output_parameters'

def validate_options(self):
def _validate_options(self):
"""Validate the options input.
Check that the withmpi option is set to True if the number of mpiprocs is greater than 1.
Check max_wallclock_seconds is greater than the min_output_buffer_time.
"""
options = self.inputs.metadata.options
if options.withmpi is False and options.resources.get('num_mpiprocs_per_machine', 1) > 1:
# TODO: does aiida not already check this?
raise exceptions.InputValidationError('MPI is required when num_mpiprocs_per_machine > 1.')
if options.max_wallclock_seconds < self._MIN_OUTPUT_BUFFER_TIME:
raise exceptions.InputValidationError(
f'max_wallclock_seconds must be greater than {self._MIN_OUTPUT_BUFFER_TIME}.'
)

def validate_inputs(self):
def _validate_inputs(self):
"""Validate input parameters.
Check that the post-SCF function(s) are supported.
@@ -107,21 +110,20 @@ def validate_inputs(self):
if postscf['$function'] not in self._SUPPORTED_POSTSCF:
raise exceptions.InputValidationError(f"Unsupported postscf function: {postscf['$function']}")

def validate_pseudos(self):
"""Valdiate the pseudopotentials.
def _validate_pseudos(self):
"""Validate the pseudopotentials.
Check that there is a one-to-one map of kinds in the structure to pseudopotentials.
"""
kinds = [kind.name for kind in self.inputs.structure.kinds]
if set(kinds) != set(self.inputs.pseudos.keys()):
pseudos_str = ', '.join(list(self.inputs.pseudos.keys()))
kinds_str = ', '.join(list(kinds))
kinds = set(kind.name for kind in self.inputs.structure.kinds)
pseudos = set(self.inputs.pseudos.keys())
if kinds != pseudos:
raise exceptions.InputValidationError(
'Mismatch between the defined pseudos and the list of kinds of the structure.\n'
f'Pseudos: {pseudos_str};\nKinds:{kinds_str}'
f'Pseudos: {pseudos};\nKinds:{kinds}'
)
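
Using sets makes the comparison independent of ordering and duplicates, and makes a mismatch easy to pinpoint. A standalone illustration (kind names are made up):

kinds = {'Si', 'O'}      # kind names found in the structure
pseudos = {'Si', 'Zr'}   # labels of the supplied pseudopotentials
if kinds != pseudos:
    print('kinds without a pseudo:', kinds - pseudos)   # {'O'}
    print('pseudos without a kind:', pseudos - kinds)   # {'Zr'}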

def validate_kpoints(self):
def _validate_kpoints(self):
"""Validate the k-points intput.
Check that the input k-points provide a k-points mesh.
@@ -159,13 +161,11 @@ def _generate_inputdata(
local_copy_pseudo_list.append((pseudo.uuid, pseudo.filename, f'{self._PSEUDO_SUBFOLDER}{pseudo.filename}'))
data['basis_kwargs']['kgrid'], data['basis_kwargs']['kshift'] = kpoints.get_kpoints_mesh()

# set the maxtime for the SCF cycle
# if max_wallclock_seconds is smaller than 600 seconds, set the maxtime as max_wallclock_seconds - MIN_OUTPUT_BUFFER_TIME
# else set the maxtime as int(0.95 * max_wallclock_seconds)
if self.inputs.metadata.options.max_wallclock_seconds < self._MIN_OUTPUT_BUFFER_TIME * 10:
maxtime = self.inputs.metadata.options.max_wallclock_seconds - self._MIN_OUTPUT_BUFFER_TIME
else:
maxtime = int(0.9 * self.inputs.metadata.options.max_wallclock_seconds)
# set the maxtime for the SCF cycle, keeping a margin of _MIN_OUTPUT_BUFFER_TIME or 10% of the wall time, whichever margin is larger
maxtime = min(
self.inputs.metadata.options.max_wallclock_seconds - self._MIN_OUTPUT_BUFFER_TIME,
0.9 * self.inputs.metadata.options.max_wallclock_seconds,
)
data['scf']['maxtime'] = maxtime

DftkCalculation._merge_dicts(data, parameters.get_dict())
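
A quick worked check of the maxtime expression above: min(T - 60, 0.9 * T) is the same as reserving a buffer of max(60, 0.1 * T) seconds (wall times below are illustrative):

for T in (300, 600, 3600):
    maxtime = min(T - 60, 0.9 * T)
    print(f'wall time {T:5d} s -> maxtime {maxtime:6.0f} s, buffer {T - maxtime:5.0f} s')
# 300 s  -> buffer 60 s  (the fixed 60 s buffer dominates)
# 600 s  -> buffer 60 s  (the two rules coincide at T = 600)
# 3600 s -> buffer 360 s (the 10% buffer dominates)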
@@ -178,6 +178,11 @@ def _generate_cmdline_params(self) -> ty.List[str]:
cmd_params.extend(['-e', 'using AiidaDFTK; AiidaDFTK.run()', self.metadata.options.input_filename])
return cmd_params
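
The result is Julia's one-liner entry point followed by the input file. A sketch of what ends up being executed (the julia executable and any project flags come from the configured AiiDA code and are omitted here):

cmdline_params = ['-e', 'using AiidaDFTK; AiidaDFTK.run()', 'run_dftk.json']
# Roughly the shell invocation:
#   julia -e 'using AiidaDFTK; AiidaDFTK.run()' run_dftk.json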

@staticmethod
def get_log_file(input_filename: str) -> str:
"""Gets the name of the log file based on the name of the input file."""
return Path(input_filename).stem + '.log'
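
Behaviour of the new helper, restated standalone: the stem of the input file name determines the log name.

from pathlib import Path

def get_log_file(input_filename: str) -> str:
    return Path(input_filename).stem + '.log'

assert get_log_file('run_dftk.json') == 'run_dftk.log'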

def _generate_retrieve_list(self, parameters: orm.Dict) -> list:
"""Generate the list of files to retrieve based on the type of calculation requested in the input parameters.
@@ -190,10 +195,9 @@ def _generate_retrieve_list(self, parameters: orm.Dict) -> list:
f"{item['$function']}.json" if item['$function'] == 'compute_bands' else f"{item['$function']}.hdf5"
for item in parameters['postscf']
]
retrieve_list.append(f'{self._DEFAULT_PREFIX}.log')
retrieve_list.append(DftkCalculation.get_log_file(self.inputs.metadata.options.input_filename))
retrieve_list.append('timings.json')
retrieve_list.append(f'{self._DEFAULT_PREFIX}.{self._DEFAULT_STDOUT_EXTENSION}')
retrieve_list.append(f'{self._DEFAULT_SCFRES_SUMMARY_NAME}')
retrieve_list.append(f'{self.SCFRES_SUMMARY_NAME}')
return retrieve_list
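
For a hypothetical postscf list with forces and bands, and 'run_dftk.json' as input filename, the list works out as follows (same comprehension as above):

postscf = [{'$function': 'compute_forces_cart'}, {'$function': 'compute_bands'}]
retrieve_list = [
    f"{item['$function']}.json" if item['$function'] == 'compute_bands' else f"{item['$function']}.hdf5"
    for item in postscf
]
retrieve_list += ['run_dftk.log', 'timings.json', 'self_consistent_field.json']
# ['compute_forces_cart.hdf5', 'compute_bands.json',
#  'run_dftk.log', 'timings.json', 'self_consistent_field.json']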

def prepare_for_submission(self, folder):
@@ -203,22 +207,18 @@
the calculation.
:return: `aiida.common.datastructures.CalcInfo` instance
"""
# Process the `settings` so that capitalization isn't an issue
settings = self.inputs.settings.get_dict()


self.validate_options()
self.validate_inputs()
self.validate_pseudos()
self.validate_kpoints()
self._validate_options()
self._validate_inputs()
self._validate_pseudos()
self._validate_kpoints()

# Create lists which specify files to copy and symlink
remote_copy_list = []
remote_symlink_list = []

# Generate the input file content
arguments = [self.inputs.parameters, self.inputs.structure, self.inputs.pseudos, self.inputs.kpoints]
input_filecontent, local_copy_list = self._generate_inputdata(*arguments)
input_filecontent, local_copy_list = self._generate_inputdata(self.inputs.parameters, self.inputs.structure, self.inputs.pseudos, self.inputs.kpoints)

# write input file
input_filename = folder.get_abs_path(self.metadata.options.input_filename)
@@ -227,25 +227,17 @@

# List the files (scfres.jld2) to copy or symlink in the case of a restart
if 'parent_folder' in self.inputs:
# Symlink by default if on the same computer, otherwise copy by default
# Symlink if on the same computer, otherwise copy
same_computer = self.inputs.code.computer.uuid == self.inputs.parent_folder.computer.uuid
if settings.pop('PARENT_FOLDER_SYMLINK', same_computer):
remote_symlink_list.append(
(
self.inputs.parent_folder.computer.uuid,
os.path.join(self.inputs.parent_folder.get_remote_path(), self.inputs.parameters['scf']['checkpointfile']),
self.inputs.parameters['scf']['checkpointfile']
)
)

checkpointfile_info = (
self.inputs.parent_folder.computer.uuid,
os.path.join(self.inputs.parent_folder.get_remote_path(), self.inputs.parameters['scf']['checkpointfile']),
self.inputs.parameters['scf']['checkpointfile']
)
if same_computer:
remote_symlink_list.append(checkpointfile_info)
else:
remote_copy_list.append(
(
self.inputs.parent_folder.computer.uuid,
os.path.join(self.inputs.parent_folder.get_remote_path(), self.inputs.parameters['scf']['checkpointfile']),
self.inputs.parameters['scf']['checkpointfile']
)
)
remote_copy_list.append(checkpointfile_info)
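
The simplification drops the old PARENT_FOLDER_SYMLINK setting: the checkpoint file is now symlinked whenever code and parent folder sit on the same computer, and copied otherwise. A standalone restatement of the decision (UUID, path and flag values are placeholders):

import os

remote_copy_list, remote_symlink_list = [], []
checkpointfile = 'scfres.jld2'                   # as named in parameters['scf']['checkpointfile']
parent_remote_path = '/scratch/user/prev_calc'   # placeholder remote path
same_computer = True                             # code.computer.uuid == parent_folder.computer.uuid

checkpointfile_info = (
    'aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee',      # placeholder computer UUID
    os.path.join(parent_remote_path, checkpointfile),
    checkpointfile,
)
(remote_symlink_list if same_computer else remote_copy_list).append(checkpointfile_info)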

# prepare command line parameters
cmdline_params = self._generate_cmdline_params()
@@ -257,7 +249,6 @@
codeinfo = datastructures.CodeInfo()
codeinfo.code_uuid = self.inputs.code.uuid
codeinfo.cmdline_params = cmdline_params
codeinfo.stdout_name = f'{self._DEFAULT_PREFIX}.{self._DEFAULT_STDOUT_EXTENSION}'

# Set up the `CalcInfo` so AiiDA knows what to do with everything
calcinfo = datastructures.CalcInfo()
18 changes: 8 additions & 10 deletions src/aiida_dftk/parsers.py
@@ -23,9 +23,7 @@
class DftkParser(Parser):
"""`Parser` implementation for DFTK."""

# TODO: I don't like this!
_DEFAULT_SCFRES_SUMMARY_NAME = 'self_consistent_field.json'
_DEFAULT_LOG_FILE_NAME = 'DFTK.log'
# TODO: DEFAULT_ prefix should be removed. I don't think that these names can be changed.
_DEFAULT_ENERGY_UNIT = 'hartree'
_DEFAULT_FORCE_FUNCNAME = 'compute_forces_cart'
_DEFAULT_FORCE_UNIT = 'hartree/bohr'
@@ -36,17 +34,17 @@ class DftkParser(Parser):

def parse(self, **kwargs):
"""Parse DFTK output files."""
# TODO: log recovery doesn't even seem to work? :(
if self._DEFAULT_LOG_FILE_NAME not in self.retrieved.base.repository.list_object_names():
log_file_name = DftkCalculation.get_log_file(self.node.get_options()["input_filename"])
if log_file_name not in self.retrieved.base.repository.list_object_names():
return self.exit_codes.ERROR_MISSING_LOG_FILE
# TODO: how to make this log available? This unfortunately doesn't output to the process report.
# TODO: maybe DFTK could log in a way that allows us to map its log levels to aiida's
self.logger.info(self.retrieved.base.repository.get_object_content(self._DEFAULT_LOG_FILE_NAME))
self.logger.info(self.retrieved.base.repository.get_object_content(log_file_name))

# TODO: double check this if
# if ran_out_of_walltime (terminated abnormally)
if self.node.exit_status == DftkCalculation.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME.status:
# if _DEFAULT_SCFRES_SUMMARY_NAME is not in the list self.retrieved.list_object_names(), SCF terminated abnormally
if self._DEFAULT_SCFRES_SUMMARY_NAME not in self.retrieved.list_object_names():
# if SCF summary file is not in the list of retrieved files, SCF terminated abnormally
if DftkCalculation.SCFRES_SUMMARY_NAME not in self.retrieved.list_object_names():
return self.exit_codes.ERROR_SCF_OUT_OF_WALLTIME
# POSTSCF terminated abnormally
else:
@@ -55,7 +53,7 @@ def parse(self, **kwargs):
# Check retrieve list to know which files the calculation is expected to have produced.
try:
self._parse_optional_result(
self._DEFAULT_SCFRES_SUMMARY_NAME,
DftkCalculation.SCFRES_SUMMARY_NAME,
self.exit_codes.ERROR_MISSING_SCFRES_FILE,
self._parse_output_parameters,
)
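
The call site implies a small helper that parses a file when it was retrieved and otherwise surfaces the given exit code. A hedged sketch of that shape in plain Python; the real _parse_optional_result may differ:

def parse_optional_result(available_files, filename, missing_exit_code, parse_fn):
    """Return missing_exit_code if filename was not retrieved, else parse it."""
    if filename not in available_files:
        return missing_exit_code
    return parse_fn(filename)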
19 changes: 5 additions & 14 deletions src/aiida_dftk/workflows/base.py
@@ -38,7 +38,6 @@ def define(cls, spec):

spec.outline(
cls.setup,
cls.validate_parameters,
cls.validate_kpoints,
cls.validate_pseudos,
cls.validate_resources,
@@ -71,16 +70,7 @@ def setup(self):
self.ctx.restart_calc = None
self.ctx.inputs = AttributeDict(self.exposed_inputs(DftkCalculation, 'dftk'))

def validate_parameters(self):
"""Validate inputs that might depend on each other and cannot be validated by the spec.
Also define dictionary `inputs` in the context, that will contain the inputs for the calculation that will be
launched in the `run_calculation` step.
"""
#super().setup()
self.ctx.inputs.parameters = self.ctx.inputs.parameters.get_dict()
self.ctx.inputs.settings = self.ctx.inputs.settings.get_dict() if 'settings' in self.ctx.inputs else {}

# TODO: We probably want to handle the kpoint distance on the Julia side instead.
def validate_kpoints(self):
"""Validate the inputs related to k-points.
@@ -121,6 +111,7 @@ def validate_pseudos(self):
return self.exit_codes.ERROR_INVALID_INPUT_PSEUDO_POTENTIALS # pylint: disable=no-member


# TODO: This is weird. Shouldn't aiida already handle this internally?
def validate_resources(self):
"""Validate the inputs related to the resources.
@@ -154,8 +145,7 @@ def report_error_handled(self, calculation, action):
:param calculation: the failed calculation node
:param action: a string message with the action taken
"""
arguments = [calculation.process_label, calculation.pk, calculation.exit_status, calculation.exit_message]
self.report('{}<{}> failed with exit status {}: {}'.format(*arguments))
self.report(f'{calculation.process_label}<{calculation.pk}> failed with exit status {calculation.exit_status}: {calculation.exit_message}')
self.report(f'Action taken: {action}')
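
For context, report_error_handled is the logging half of a process handler; the stub below still returns None. A minimal sketch of what an implemented version might look like inside the work chain (the restart action is assumed; process_handler and ProcessHandlerReport are aiida-core's API for this):

from aiida.engine import ProcessHandlerReport, process_handler
from aiida_dftk.calculations import DftkCalculation

@process_handler(priority=500, exit_codes=[DftkCalculation.exit_codes.ERROR_SCF_CONVERGENCE_NOT_REACHED])
def handle_scf_convergence_not_reached(self, calculation):
    """Log the failure, then restart from the last checkpoint (assumed recoverable)."""
    self.report_error_handled(calculation, 'restarting from the last checkpoint')
    return ProcessHandlerReport(True)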

@process_handler(priority=500, exit_codes=[DftkCalculation.exit_codes.ERROR_SCF_CONVERGENCE_NOT_REACHED])
@@ -164,12 +154,13 @@ def handle_scf_convergence_not_reached(self, _):
return None

# Just as a blueprint, delete after ^ is implemented
# TODO: What exactly is this doing?
@process_handler(priority=580, exit_codes=[
DftkCalculation.exit_codes.ERROR_SCF_CONVERGENCE_NOT_REACHED,
DftkCalculation.exit_codes.ERROR_POSTSCF_OUT_OF_WALLTIME
])
def handle_recoverable_SCF_unconverged_and_POSTSCF_out_of_walltime_(self, calculation):
"""Handle `RROR_SCF_CONVERGENCE_NOT_REACHED` and `ERROR_POSTSCF_OUT_OF_WALLTIME` exit code: calculations shut down neatly and we can simply restart."""
"""Handle `ERROR_SCF_CONVERGENCE_NOT_REACHED` and `ERROR_POSTSCF_OUT_OF_WALLTIME` exit code: calculations shut down neatly and we can simply restart."""
try:
self.ctx.inputs.structure = calculation.outputs.output_structure
except exceptions.NotExistent:
