Simplify #12 (Merged)
Merged 1 commit on Nov 8, 2024
109 changes: 50 additions & 59 deletions src/aiida_dftk/calculations.py
@@ -4,6 +4,7 @@
import os
import json
import typing as ty
from pathlib import Path

from aiida import orm
from aiida.common import datastructures, exceptions
@@ -15,11 +16,9 @@
class DftkCalculation(CalcJob):
"""`CalcJob` implementation for DFTK."""

_DEFAULT_PREFIX = 'DFTK'
_DEFAULT_INPUT_EXTENSION = 'json'
_DEFAULT_STDOUT_EXTENSION = 'txt'
_DEFAULT_SCFRES_SUMMARY_NAME = 'self_consistent_field.json'
_SUPPORTED_POSTSCF = ['compute_forces_cart', 'compute_stresses_cart','compute_bands']
SCFRES_SUMMARY_NAME = 'self_consistent_field.json'
# TODO: don't limit postscf
_SUPPORTED_POSTSCF = ['compute_forces_cart', 'compute_stresses_cart', 'compute_bands']
_PSEUDO_SUBFOLDER = './pseudo/'
_MIN_OUTPUT_BUFFER_TIME = 60
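
For orientation, a hedged sketch (not part of this PR) of what the `parameters` input looks like with the `$function` convention that the validation and retrieve-list code below relies on; the exact keys are assumptions based on how they are read elsewhere in this file:

# Hypothetical input parameters; each postscf entry names one of the
# functions listed in _SUPPORTED_POSTSCF via its '$function' key.
from aiida import orm

parameters = orm.Dict(dict={
    'scf': {'checkpointfile': 'scfres.jld2'},  # filename assumed for illustration
    'postscf': [
        {'$function': 'compute_forces_cart'},
        {'$function': 'compute_bands'},
    ],
})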

@@ -37,24 +36,25 @@ def define(cls, spec):
"""Define the process specification."""
super().define(spec)
# Inputs
spec.input('metadata.options.prefix', valid_type=str, default=cls._DEFAULT_PREFIX)
spec.input('metadata.options.stdout_extension', valid_type=str, default=cls._DEFAULT_STDOUT_EXTENSION)
spec.input('metadata.options.withmpi', valid_type=bool, default=True)
spec.input('metadata.options.max_wallclock_seconds', valid_type=int, default=1800)

spec.input('structure', valid_type=orm.StructureData, help='structure')
spec.input_namespace('pseudos', valid_type=UpfData, help='The pseudopotentials.', dynamic=True)
spec.input('kpoints', valid_type=orm.KpointsData, help='kpoint mesh or kpoint path')
spec.input('parameters', valid_type=orm.Dict, help='input parameters')
spec.input('settings', valid_type=orm.Dict, required=False, help='Various special settings.')
spec.input('parent_folder', valid_type=orm.RemoteData, required=False, help='A remote folder used for restarts.')

options = spec.inputs['metadata']['options']

options['parser_name'].default = 'dftk'
options['input_filename'].default = 'run_dftk.json'
options['max_wallclock_seconds'].default = 1800

# TODO: Why is this here?
options['resources'].default = {'num_machines': 1, 'num_mpiprocs_per_machine': 1}
options['input_filename'].default = f'{cls._DEFAULT_PREFIX}.{cls._DEFAULT_INPUT_EXTENSION}'
options['withmpi'].default = True

# Exit codes
# TODO: Log file should be removed in favor of using stdout. Needs a change in AiidaDFTK.jl.
# TODO: Code 100 is already used in the super class!
spec.exit_code(100, 'ERROR_MISSING_LOG_FILE', message='The output file containing DFTK logs is missing.')
spec.exit_code(101, 'ERROR_MISSING_SCFRES_FILE', message='The output file containing SCF results is missing.')
spec.exit_code(102, 'ERROR_MISSING_FORCES_FILE', message='The output file containing forces is missing.')
@@ -67,7 +67,9 @@ def define(cls, spec):

# Outputs
spec.output('output_parameters', valid_type=orm.Dict, help='output parameters')
# TODO: doesn't seem to be used?
spec.output('output_structure', valid_type=orm.Dict, required=False, help='output structure')
# TODO: doesn't seem to be used?
spec.output(
'output_kpoints', valid_type=orm.KpointsData, required=False, help='kpoints array, if generated by DFTK'
)
@@ -82,21 +84,22 @@

spec.default_output_node = 'output_parameters'

def validate_options(self):
def _validate_options(self):
"""Validate the options input.

Check that the withmpi option is set to True if the number of mpiprocs is greater than 1.
Check max_wallclock_seconds is greater than the min_output_buffer_time.
"""
options = self.inputs.metadata.options
if options.withmpi is False and options.resources.get('num_mpiprocs_per_machine', 1) > 1:
# TODO: does aiida not already check this?
raise exceptions.InputValidationError('MPI is required when num_mpiprocs_per_machine > 1.')
if options.max_wallclock_seconds < self._MIN_OUTPUT_BUFFER_TIME:
raise exceptions.InputValidationError(
f'max_wallclock_seconds must be greater than {self._MIN_OUTPUT_BUFFER_TIME}.'
)

def validate_inputs(self):
def _validate_inputs(self):
"""Validate input parameters.

Check that the post-SCF function(s) are supported.
@@ -107,21 +110,20 @@ def validate_inputs(self):
if postscf['$function'] not in self._SUPPORTED_POSTSCF:
raise exceptions.InputValidationError(f"Unsupported postscf function: {postscf['$function']}")

def validate_pseudos(self):
"""Valdiate the pseudopotentials.
def _validate_pseudos(self):
"""Validate the pseudopotentials.

Check that there is a one-to-one map of kinds in the structure to pseudopotentials.
"""
kinds = [kind.name for kind in self.inputs.structure.kinds]
if set(kinds) != set(self.inputs.pseudos.keys()):
pseudos_str = ', '.join(list(self.inputs.pseudos.keys()))
kinds_str = ', '.join(list(kinds))
kinds = set(kind.name for kind in self.inputs.structure.kinds)
pseudos = set(self.inputs.pseudos.keys())
if kinds != pseudos:
raise exceptions.InputValidationError(
'Mismatch between the defined pseudos and the list of kinds of the structure.\n'
f'Pseudos: {pseudos_str};\nKinds:{kinds_str}'
f'Pseudos: {pseudos};\nKinds: {kinds}'
)
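
A minimal illustration of the set comparison above, with assumed kind names:

# A structure containing Si and O but a pseudo supplied only for Si fails:
kinds = {'Si', 'O'}
pseudos = {'Si'}
assert kinds != pseudos  # _validate_pseudos raises InputValidationError in this case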

def validate_kpoints(self):
def _validate_kpoints(self):
"""Validate the k-points intput.

Check that the input k-points provide a k-points mesh.
@@ -159,13 +161,11 @@ def _generate_inputdata(
local_copy_pseudo_list.append((pseudo.uuid, pseudo.filename, f'{self._PSEUDO_SUBFOLDER}{pseudo.filename}'))
data['basis_kwargs']['kgrid'], data['basis_kwargs']['kshift'] = kpoints.get_kpoints_mesh()

# set the maxtime for the SCF cycle
# if max_wallclock_seconds is smaller than 600 seconds, set the maxtime as max_wallclock_seconds - MIN_OUTPUT_BUFFER_TIME
# else set the maxtime as int(0.95 * max_wallclock_seconds)
if self.inputs.metadata.options.max_wallclock_seconds < self._MIN_OUTPUT_BUFFER_TIME * 10:
maxtime = self.inputs.metadata.options.max_wallclock_seconds - self._MIN_OUTPUT_BUFFER_TIME
else:
maxtime = int(0.9 * self.inputs.metadata.options.max_wallclock_seconds)
# Set the maxtime for the SCF cycle, leaving a margin of _MIN_OUTPUT_BUFFER_TIME or 10% of the wallclock time, whichever margin is larger.
maxtime = min(
self.inputs.metadata.options.max_wallclock_seconds - self._MIN_OUTPUT_BUFFER_TIME,
0.9 * self.inputs.metadata.options.max_wallclock_seconds,
)
data['scf']['maxtime'] = maxtime

DftkCalculation._merge_dicts(data, parameters.get_dict())
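
A worked sketch of the maxtime margin logic above, with assumed wallclock values; for short jobs the fixed buffer dominates, for long jobs the 10% margin does:

_MIN_OUTPUT_BUFFER_TIME = 60
for wall in (300, 7200):
    maxtime = min(wall - _MIN_OUTPUT_BUFFER_TIME, 0.9 * wall)
    print(wall, maxtime)
# 300 -> 240 (the fixed buffer leaves the larger margin)
# 7200 -> 6480.0 (the 10% margin wins; note the float value)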
@@ -178,6 +178,11 @@ def _generate_cmdline_params(self) -> ty.List[str]:
cmd_params.extend(['-e', 'using AiidaDFTK; AiidaDFTK.run()', self.metadata.options.input_filename])
return cmd_params
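
For illustration, with the default input filename the parameters extended here come out as follows (anything added earlier in the method is outside this hunk and not shown):

# cmd_params ends with: ['-e', 'using AiidaDFTK; AiidaDFTK.run()', 'run_dftk.json']
# so the configured julia executable is invoked roughly as:
#   julia -e 'using AiidaDFTK; AiidaDFTK.run()' run_dftk.json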

@staticmethod
def get_log_file(input_filename: str) -> str:
"""Gets the name of the log file based on the name of the input file."""
return Path(input_filename).stem + '.log'
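
Usage sketch for the new helper; the filenames are examples:

assert DftkCalculation.get_log_file('run_dftk.json') == 'run_dftk.log'
assert DftkCalculation.get_log_file('inputs/custom.json') == 'custom.log'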

def _generate_retrieve_list(self, parameters: orm.Dict) -> list:
"""Generate the list of files to retrieve based on the type of calculation requested in the input parameters.

@@ -190,10 +195,9 @@ def _generate_retrieve_list(self, parameters: orm.Dict) -> list:
f"{item['$function']}.json" if item['$function'] == 'compute_bands' else f"{item['$function']}.hdf5"
for item in parameters['postscf']
]
retrieve_list.append(f'{self._DEFAULT_PREFIX}.log')
retrieve_list.append(DftkCalculation.get_log_file(self.inputs.metadata.options.input_filename))
retrieve_list.append('timings.json')
retrieve_list.append(f'{self._DEFAULT_PREFIX}.{self._DEFAULT_STDOUT_EXTENSION}')
retrieve_list.append(f'{self._DEFAULT_SCFRES_SUMMARY_NAME}')
retrieve_list.append(self.SCFRES_SUMMARY_NAME)
return retrieve_list
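
With the hypothetical postscf parameters sketched earlier and the default input filename, the retrieve list would come out as:

# ['compute_forces_cart.hdf5', 'compute_bands.json',
#  'run_dftk.log', 'timings.json', 'self_consistent_field.json']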

def prepare_for_submission(self, folder):
@@ -203,22 +207,18 @@ def prepare_for_submission(self, folder):
the calculation.
:return: `aiida.common.datastructures.CalcInfo` instance
"""
# Process the `settings` so that capitalization isn't an issue
settings = self.inputs.settings.get_dict()


self.validate_options()
self.validate_inputs()
self.validate_pseudos()
self.validate_kpoints()
self._validate_options()
self._validate_inputs()
self._validate_pseudos()
self._validate_kpoints()

# Create lists which specify files to copy and symlink
remote_copy_list = []
remote_symlink_list = []

# Generate the input file content
arguments = [self.inputs.parameters, self.inputs.structure, self.inputs.pseudos, self.inputs.kpoints]
input_filecontent, local_copy_list = self._generate_inputdata(*arguments)
input_filecontent, local_copy_list = self._generate_inputdata(self.inputs.parameters, self.inputs.structure, self.inputs.pseudos, self.inputs.kpoints)

# write input file
input_filename = folder.get_abs_path(self.metadata.options.input_filename)
@@ -227,25 +227,17 @@

# List the files (scfres.jld2) to copy or symlink in the case of a restart
if 'parent_folder' in self.inputs:
# Symlink by default if on the same computer, otherwise copy by default
# Symlink if on the same computer, otherwise copy
same_computer = self.inputs.code.computer.uuid == self.inputs.parent_folder.computer.uuid
if settings.pop('PARENT_FOLDER_SYMLINK', same_computer):
remote_symlink_list.append(
(
self.inputs.parent_folder.computer.uuid,
os.path.join(self.inputs.parent_folder.get_remote_path(), self.inputs.parameters['scf']['checkpointfile']),
self.inputs.parameters['scf']['checkpointfile']
)
)

checkpointfile_info = (
self.inputs.parent_folder.computer.uuid,
os.path.join(self.inputs.parent_folder.get_remote_path(), self.inputs.parameters['scf']['checkpointfile']),
self.inputs.parameters['scf']['checkpointfile']
)
if same_computer:
remote_symlink_list.append(checkpointfile_info)
else:
remote_copy_list.append(
(
self.inputs.parent_folder.computer.uuid,
os.path.join(self.inputs.parent_folder.get_remote_path(), self.inputs.parameters['scf']['checkpointfile']),
self.inputs.parameters['scf']['checkpointfile']
)
)
remote_copy_list.append(checkpointfile_info)
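
A hedged sketch of wiring up a restart from the caller's side; `outputs.remote_folder` is the standard `CalcJob` remote working directory, the other values are assumptions:

builder = DftkCalculation.get_builder()
builder.parent_folder = previous_calc.outputs.remote_folder  # RemoteData of the earlier run
builder.parameters = orm.Dict(dict={'scf': {'checkpointfile': 'scfres.jld2'}})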

# prepare command line parameters
cmdline_params = self._generate_cmdline_params()
@@ -257,7 +249,6 @@
codeinfo = datastructures.CodeInfo()
codeinfo.code_uuid = self.inputs.code.uuid
codeinfo.cmdline_params = cmdline_params
codeinfo.stdout_name = f'{self._DEFAULT_PREFIX}.{self._DEFAULT_STDOUT_EXTENSION}'

# Set up the `CalcInfo` so AiiDA knows what to do with everything
calcinfo = datastructures.CalcInfo()
18 changes: 8 additions & 10 deletions src/aiida_dftk/parsers.py
@@ -23,9 +23,7 @@
class DftkParser(Parser):
"""`Parser` implementation for DFTK."""

# TODO: I don't like this!
_DEFAULT_SCFRES_SUMMARY_NAME = 'self_consistent_field.json'
_DEFAULT_LOG_FILE_NAME = 'DFTK.log'
# TODO: DEFAULT_ prefix should be removed. I don't think that these names can be changed.
_DEFAULT_ENERGY_UNIT = 'hartree'
_DEFAULT_FORCE_FUNCNAME = 'compute_forces_cart'
_DEFAULT_FORCE_UNIT = 'hartree/bohr'
@@ -36,17 +34,17 @@ class DftkParser(Parser):

def parse(self, **kwargs):
"""Parse DFTK output files."""
# TODO: log recovery doesn't even seem to work? :(
if self._DEFAULT_LOG_FILE_NAME not in self.retrieved.base.repository.list_object_names():
log_file_name = DftkCalculation.get_log_file(self.node.get_options()["input_filename"])
if log_file_name not in self.retrieved.base.repository.list_object_names():
return self.exit_codes.ERROR_MISSING_LOG_FILE
# TODO: how to make this log available? This unfortunately doesn't output to the process report.
# TODO: maybe DFTK could log in a way that allows us to map its log levels to aiida's
self.logger.info(self.retrieved.base.repository.get_object_content(self._DEFAULT_LOG_FILE_NAME))
self.logger.info(self.retrieved.base.repository.get_object_content(log_file_name))

# TODO: double check this if
# if ran_out_of_walltime (terminated abnormally)
if self.node.exit_status == DftkCalculation.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME.status:
# if _DEFAULT_SCFRES_SUMMARY_NAME is not in the list self.retrieved.list_object_names(), SCF terminated illy
if self._DEFAULT_SCFRES_SUMMARY_NAME not in self.retrieved.list_object_names():
# if the SCF summary file is not among the retrieved files, the SCF terminated abnormally
if DftkCalculation.SCFRES_SUMMARY_NAME not in self.retrieved.list_object_names():
return self.exit_codes.ERROR_SCF_OUT_OF_WALLTIME
# POSTSCF terminated abnormally
else:
@@ -55,7 +53,7 @@ def parse(self, **kwargs):
# Check retrieve list to know which files the calculation is expected to have produced.
try:
self._parse_optional_result(
self._DEFAULT_SCFRES_SUMMARY_NAME,
DftkCalculation.SCFRES_SUMMARY_NAME,
self.exit_codes.ERROR_MISSING_SCFRES_FILE,
self._parse_output_parameters,
)
19 changes: 5 additions & 14 deletions src/aiida_dftk/workflows/base.py
@@ -38,7 +38,6 @@ def define(cls, spec):

spec.outline(
cls.setup,
cls.validate_parameters,
cls.validate_kpoints,
cls.validate_pseudos,
cls.validate_resources,
@@ -71,16 +70,7 @@ def setup(self):
self.ctx.restart_calc = None
self.ctx.inputs = AttributeDict(self.exposed_inputs(DftkCalculation, 'dftk'))

def validate_parameters(self):
"""Validate inputs that might depend on each other and cannot be validated by the spec.

Also define dictionary `inputs` in the context, that will contain the inputs for the calculation that will be
launched in the `run_calculation` step.
"""
#super().setup()
self.ctx.inputs.parameters = self.ctx.inputs.parameters.get_dict()
self.ctx.inputs.settings = self.ctx.inputs.settings.get_dict() if 'settings' in self.ctx.inputs else {}

# TODO: We probably want to handle the kpoint distance on the Julia side instead.
def validate_kpoints(self):
"""Validate the inputs related to k-points.

@@ -121,6 +111,7 @@ def validate_pseudos(self):
return self.exit_codes.ERROR_INVALID_INPUT_PSEUDO_POTENTIALS # pylint: disable=no-member


# TODO: This is weird. Shouldn't aiida already handle this internally?
def validate_resources(self):
"""Validate the inputs related to the resources.

@@ -154,8 +145,7 @@ def report_error_handled(self, calculation, action):
:param calculation: the failed calculation node
:param action: a string message with the action taken
"""
arguments = [calculation.process_label, calculation.pk, calculation.exit_status, calculation.exit_message]
self.report('{}<{}> failed with exit status {}: {}'.format(*arguments))
self.report(f'{calculation.process_label}<{calculation.pk}> failed with exit status {calculation.exit_status}: {calculation.exit_message}')
self.report(f'Action taken: {action}')
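
For illustration, the rewritten report renders roughly as (pk, exit status, and message assumed):

# DftkCalculation<1234> failed with exit status 500: <exit message>
# Action taken: restarting the calculation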

@process_handler(priority=500, exit_codes=[DftkCalculation.exit_codes.ERROR_SCF_CONVERGENCE_NOT_REACHED])
@@ -164,12 +154,13 @@ def handle_scf_convergence_not_reached(self, _):
return None

# Just as a blueprint, delete after ^ is implemented
# TODO: What exactly is this doing?
@process_handler(priority=580, exit_codes=[
DftkCalculation.exit_codes.ERROR_SCF_CONVERGENCE_NOT_REACHED,
DftkCalculation.exit_codes.ERROR_POSTSCF_OUT_OF_WALLTIME
])
def handle_recoverable_SCF_unconverged_and_POSTSCF_out_of_walltime_(self, calculation):
"""Handle `RROR_SCF_CONVERGENCE_NOT_REACHED` and `ERROR_POSTSCF_OUT_OF_WALLTIME` exit code: calculations shut down neatly and we can simply restart."""
"""Handle `ERROR_SCF_CONVERGENCE_NOT_REACHED` and `ERROR_POSTSCF_OUT_OF_WALLTIME` exit code: calculations shut down neatly and we can simply restart."""
try:
self.ctx.inputs.structure = calculation.outputs.output_structure
except exceptions.NotExistent: