diff --git a/README.md b/README.md index 0de5a8f..2411ad4 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,21 @@ To manage the repo we use [hatch]{.title-ref} please install it ``` bash pip install hatch +hatch shell # activate shell as dev environment hatch test # run tests hatch fmt # run formatting hatch run docs:build # build docs hatch run docs:serve # live preview of doc for development ``` +### Tests +``` bash +pip install hatch +verdi devel launch-add # creates required codes +verdi presto +hatch test +``` + ## Resources - diff --git a/pyproject.toml b/pyproject.toml index 3675b1e..4f8aa62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,12 +23,14 @@ classifiers = [ 'Topic :: Scientific/Engineering :: Atmospheric Science', ] keywords = ["wc", "workflow"," icon", "aiida", "aiida-workgraph"] -requires-python = '>=3.12' +requires-python = '~=3.12' dependencies = [ "isoduration", "pydantic", "pydantic-yaml", - "aiida-core>=2.5" + "aiida-core>=2.5", + 'aiida-workgraph[widget]==0.3.14', + 'node_graph==0.0.11', ] [tool.pytest.ini_options] @@ -55,7 +57,11 @@ ignore = [ [tool.hatch.version] path = "src/wcflow/__init__.py" +[project.optional-dependencies] +tests = ["pytest"] + [tool.hatch.envs.hatch-test] +extras = ["tests"] extra-dependencies = [ "ipdb" ] diff --git a/src/wcflow/_utils.py b/src/wcflow/_utils.py new file mode 100644 index 0000000..c4ca7af --- /dev/null +++ b/src/wcflow/_utils.py @@ -0,0 +1,19 @@ +import aiida.orm + + +class OrmUtils: + @staticmethod + def convert_to_class(type_: str) -> aiida.orm.utils.node.AbstractNodeMeta: + if type_ == "file": + return aiida.orm.SinglefileData + elif type_ == "dir": + return aiida.orm.FolderData + elif type_ == "int": + return aiida.orm.Int + elif type_ == "float": + return aiida.orm.Float + elif type_ == "remote": + return aiida.orm.RemoteData + else: + raise ValueError(f"Type {type_} is unknown.") + diff --git a/src/wcflow/core.py b/src/wcflow/core.py index 0458a99..1195c92 100644 --- a/src/wcflow/core.py +++ b/src/wcflow/core.py @@ -2,7 +2,6 @@ from datetime import datetime from itertools import chain -from os.path import expandvars from pathlib import Path from typing import TYPE_CHECKING @@ -24,20 +23,25 @@ def __init__( lag: list[Duration], date: list[datetime], arg_option: str | None, + port_name: str | None, ): self._name = name self._src = src - self._path = Path(expandvars(self._src)) self._type = type - if self._type not in ["file", "dir"]: - msg = f"Data type {self._type!r} not supported. Please use 'file' or 'dir'." - raise ValueError(msg) - if len(lag) > 0 and len(date) > 0: - msg = "Either 'lag' or 'date' can be nonempty. Not both." - raise ValueError(msg) + # type probably not needed + #if self._type not in ["file", "dir", ""]: + # msg = f"Data type {self._type!r} not supported. Please use 'file' or 'dir'." + # raise ValueError(msg) + + # check if not already done by parsing? best inherit these + #if len(lag) > 0 and len(date) > 0: + # msg = "Either 'lag' or 'date' can be nonempty. Not both." 
+ # raise ValueError(msg) + + self._port_name = self._name if port_name is None else port_name # COMMENT I think we should just disallow multiple lags, and enforce the user to write multiple lags # I am not sure how this work with icon as it does not need positional arguments @@ -67,10 +71,6 @@ def type(self) -> str: def src(self) -> str: return self._src - @property - def path(self) -> Path: - return self._path - @property def lag(self) -> list[Duration]: return self._lag @@ -83,6 +83,10 @@ def date(self) -> list[datetime]: def arg_option(self) -> str | None: return self._arg_option + @property + def port_name(self) -> str: + return self._port_name + class Data(_DataBase): def __init__( @@ -93,8 +97,9 @@ def __init__( lag: list[Duration], date: list[datetime], arg_option: str | None, + port_name: str | None ): - super().__init__(name, type, src, lag, date, arg_option) + super().__init__(name, type, src, lag, date, arg_option, port_name) self._task: Task | None = None def unroll(self, unrolled_task: UnrolledTask) -> Generator[UnrolledData, None, None]: @@ -130,7 +135,6 @@ def task(self, task: Task): raise ValueError(msg) self._task = task - class UnrolledData(_DataBase): """ Data that are created during the unrolling of a cycle. @@ -139,7 +143,7 @@ class UnrolledData(_DataBase): @classmethod def from_data(cls, data: Data, unrolled_task: UnrolledTask, unrolled_date: datetime): - return cls(unrolled_task, unrolled_date, data.name, data.type, data.src, data.lag, data.date, data.arg_option) + return cls(unrolled_task, unrolled_date, data.name, data.type, data.src, data.lag, data.date, data.arg_option, data.port_name) def __init__( self, @@ -151,8 +155,9 @@ def __init__( lag: list[Duration], date: list[datetime], arg_option: str | None, + port_name: str ): - super().__init__(name, type, src, lag, date, arg_option) + super().__init__(name, type, src, lag, date, arg_option, port_name) self._unrolled_task = unrolled_task self._unrolled_date = unrolled_date @@ -323,13 +328,19 @@ def __init__( outputs: list[Data], depends: list[Dependency], command_option: str | None, + plugin: str, + code: str, + computer: str, ): self._name = name - self._command = expandvars(command) + self._command = command self._inputs = inputs self._outputs = outputs self._depends = depends self._command_option = command_option + self._plugin = plugin + self._code = code + self._computer = computer @property def name(self) -> str: @@ -355,6 +366,18 @@ def command_option(self) -> str | None: def depends(self) -> list[Dependency]: return self._depends + @property + def plugin(self): + return self._plugin + + @property + def code(self): + return self._code + + @property + def computer(self): + return self._computer + class Task(_TaskBase): """A task that is created during the unrolling of a cycle.""" @@ -367,8 +390,11 @@ def __init__( outputs: list[Data], depends: list[Dependency], command_option: str | None, + plugin: str, + code: str, + computer: str, ): - super().__init__(name, command, inputs, outputs, depends, command_option) + super().__init__(name, command, inputs, outputs, depends, command_option, plugin, code, computer) for input_ in inputs: input_.task = self for output in outputs: @@ -409,7 +435,7 @@ class UnrolledTask(_TaskBase): @classmethod def from_task(cls, task: Task, unrolled_cycle: UnrolledCycle): return cls( - unrolled_cycle, task.name, task.command, task.inputs, task.outputs, task.depends, task.command_option + unrolled_cycle, task.name, task.command, task.inputs, task.outputs, task.depends, 
task.command_option, task.plugin, task.code, task.computer ) def __init__( @@ -421,8 +447,11 @@ def __init__( outputs: list[Data], depends: list[Dependency], command_option: str | None, + plugin: str, + code: str, + computer: str, ): - super().__init__(name, command, inputs, outputs, depends, command_option) + super().__init__(name, command, inputs, outputs, depends, command_option, plugin, code, computer) self._unrolled_cycle = unrolled_cycle self._unrolled_inputs = list(self.unroll_inputs()) self._unrolled_outputs = list(self.unroll_outputs()) @@ -624,9 +653,10 @@ def workflow(self) -> Workflow: class Workflow: - def __init__(self, name: str, cycles: list[Cycle]): + def __init__(self, name: str, cycles: list[Cycle], computer: str): self._name = name self._cycles = cycles + self._computer = computer for cycle in self._cycles: cycle.workflow = self self._validate_cycles() @@ -659,6 +689,10 @@ def name(self) -> str: def cycles(self) -> list[Cycle]: return self._cycles + @property + def computer(self) -> str: + return self._computer + def is_available_on_init(self, data: UnrolledData) -> bool: """Determines if the data is available on init of the workflow.""" diff --git a/src/wcflow/parsing/_yaml_data_models.py b/src/wcflow/parsing/_yaml_data_models.py index 3a2cf9f..c29323d 100644 --- a/src/wcflow/parsing/_yaml_data_models.py +++ b/src/wcflow/parsing/_yaml_data_models.py @@ -1,6 +1,5 @@ from __future__ import annotations -import time from datetime import datetime from os.path import expandvars from pathlib import Path @@ -83,17 +82,14 @@ class ConfigTask(_NamedBaseModel): To create an instance of a task defined in a workflow file """ - command: str - command_option: str | None = None - host: str | None = None - account: str | None = None + # aiida-shell specifics + command: str | None = None + # these are global commands_options, I think we can remove arg_options to simplify + command_options: str | None = None + # general plugin: str | None = None - config: str | None = None - uenv: dict | None = None - nodes: int | None = None - walltime: str | None = None - src: str | None = None - conda_env: str | None = None + computer: str | None = None + code: str | None = None def __init__(self, /, **data): # We have to treat root special as it does not typically define a command @@ -107,31 +103,38 @@ def expand_env_vars(cls, value: str) -> str: """Expands any environment variables in the value""" return expandvars(value) - @field_validator("walltime") - @classmethod - def convert_to_struct_time(cls, value: str | None) -> time.struct_time | None: - """Converts a string of form "%H:%M:%S" to a time.time_struct""" - return None if value is None else time.strptime(value, "%H:%M:%S") - class ConfigData(_NamedBaseModel): """ To create an instance of a data defined in a workflow file. """ - type: str - src: str + type: str | None = None + src: str | int | dict | None = None format: str | None = None + computer: str | None = None @field_validator("type") @classmethod def is_file_or_dir(cls, value: str) -> str: """.""" - if value not in ["file", "dir"]: - msg = "Must be one of 'file' or 'dir'." - raise ValueError(msg) + # Is this actually needed? We can refer everything from the plugin + #if value not in ["file", "dir", "int"]: + # msg = "Must be one of 'file' or 'dir'." 
+ # raise ValueError(msg) return value + @field_validator("src") + @classmethod + def expand_env_vars(cls, value: str | int | dict | None) -> str | int | dict | None: + """Expands any environment variables in the value""" + if isinstance(value, str): + return expandvars(value) + elif isinstance(value, dict): + raise NotImplementedError + else: + return value + class ConfigCycleTaskDepend(_NamedBaseModel, _LagDateBaseModel): """ @@ -142,6 +145,7 @@ class ConfigCycleTaskDepend(_NamedBaseModel, _LagDateBaseModel): cycle_name: str | None = None +# NOT NEEDED, but can we make dict with arbitrary class ConfigCycleTaskInput(_NamedBaseModel, _LagDateBaseModel): """ To create an instance of an input in a task in a cycle defined in a workflow file. @@ -156,13 +160,15 @@ class ConfigCycleTaskInput(_NamedBaseModel, _LagDateBaseModel): """ arg_option: str | None = None + # lag? + port_name: str | None = None class ConfigCycleTaskOutput(_NamedBaseModel): """ To create an instance of an output in a task in a cycle defined in a workflow file. """ - + port_name: str | None = None class ConfigCycleTask(_NamedBaseModel): """ @@ -175,7 +181,7 @@ class ConfigCycleTask(_NamedBaseModel): @field_validator("inputs", mode="before") @classmethod - def convert_cycle_task_inputs(cls, values) -> list[ConfigCycleTaskInput]: + def convert_cycle_task_inputs(cls, values) -> list[dict]: inputs = [] if values is None: return inputs @@ -260,6 +266,7 @@ class ConfigWorkflow(BaseModel): data: list[ConfigData] data_dict: dict = {} task_dict: dict = {} + computer: str | None = None @field_validator("start_date", "end_date", mode="before") @classmethod @@ -278,7 +285,7 @@ def to_core_workflow(self): self.task_dict = {task.name: task for task in self.tasks} core_cycles = [self._to_core_cycle(cycle) for cycle in self.cycles] - return core.Workflow(self.name, core_cycles) + return core.Workflow(self.name, core_cycles, computer=self.computer) def _to_core_cycle(self, cycle: ConfigCycle) -> core.Cycle: core_tasks = [self._to_core_task(task) for task in cycle.tasks] @@ -295,14 +302,14 @@ def _to_core_task(self, cycle_task: ConfigCycleTask) -> core.Task: if (data := self.data_dict.get(input_.name)) is None: msg = f"Task {cycle_task.name!r} has input {input_.name!r} that is not specied in the data section." raise ValueError(msg) - core_data = core.Data(input_.name, data.type, data.src, input_.lag, input_.date, input_.arg_option) + core_data = core.Data(input_.name, data.type, data.src, input_.lag, input_.date, input_.arg_option, input_.port_name) inputs.append(core_data) for output in cycle_task.outputs: if (data := self.data_dict.get(output.name)) is None: msg = f"Task {cycle_task.name!r} has output {output.name!r} that is not specied in the data section." 
raise ValueError(msg) - core_data = core.Data(output.name, data.type, data.src, [], [], None) + core_data = core.Data(output.name, data.type, data.src, [], [], None, output.port_name) outputs.append(core_data) for depend in cycle_task.depends: @@ -315,7 +322,10 @@ def _to_core_task(self, cycle_task: ConfigCycleTask) -> core.Task: inputs, outputs, dependencies, - self.task_dict[cycle_task.name].command_option, + self.task_dict[cycle_task.name].command_options, + self.task_dict[cycle_task.name].plugin, + self.task_dict[cycle_task.name].code, + self.task_dict[cycle_task.name].computer, ) diff --git a/src/wcflow/workgraph.py b/src/wcflow/workgraph.py new file mode 100644 index 0000000..261b235 --- /dev/null +++ b/src/wcflow/workgraph.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import aiida.common +import aiida.orm +import aiida_workgraph.engine.utils # type: ignore[import-untyped] +from aiida.plugins.factories import CalculationFactory +from aiida_workgraph import WorkGraph # type: ignore[import-untyped] + +from ._utils import OrmUtils + +if TYPE_CHECKING: + from aiida_workgraph.socket import TaskSocket # type: ignore[import-untyped] + + from wcflow import core + + +# This is hack to aiida-workgraph, merging this into aiida-workgraph properly would require +# some major refactor see issue https://github.com/aiidateam/aiida-workgraph/issues/168 +# It might be better to give up on the graph like construction and just create the task +# directly with inputs, arguments and outputs +def _prepare_for_shell_task(task: dict, kwargs: dict) -> dict: + """Prepare the inputs for ShellTask""" + from aiida.common import lang + from aiida.orm import AbstractCode + from aiida_shell.launch import convert_nodes_single_file_data, prepare_code + + command = kwargs.pop("command", None) + resolve_command = kwargs.pop("resolve_command", False) + metadata = kwargs.pop("metadata", {}) + # setup code + if isinstance(command, str): + computer = (metadata or {}).get("options", {}).pop("computer", None) + code = prepare_code(command, computer, resolve_command) + else: + lang.type_check(command, AbstractCode) + code = command + # update the tasks with links + nodes = convert_nodes_single_file_data(kwargs.pop("nodes", {})) + # find all keys in kwargs start with "nodes." 
+    for key in list(kwargs.keys()):
+        if key.startswith("nodes."):
+            nodes[key[6:]] = kwargs.pop(key)
+    metadata.update({"call_link_label": task["name"]})
+
+    default_outputs = {"remote_folder", "remote_stash", "retrieved", "_outputs", "_wait", "stdout", "stderr"}
+    task_outputs = {task["outputs"][i]["name"] for i in range(len(task["outputs"]))}
+    task_outputs = task_outputs.union(set(kwargs.pop("outputs", [])))
+    missing_outputs = task_outputs.difference(default_outputs)
+    return {
+        "code": code,
+        "nodes": nodes,
+        "filenames": kwargs.pop("filenames", {}),
+        "arguments": kwargs.pop("arguments", []),
+        "outputs": list(missing_outputs),
+        "parser": kwargs.pop("parser", None),
+        "metadata": metadata or {},
+    }
+
+
+aiida_workgraph.engine.utils.prepare_for_shell_task = _prepare_for_shell_task
+
+
+class AiidaWorkGraph:
+    def __init__(self, core_workflow: core.Workflow):
+        # Needed for date indexing
+        self._core_workflow = core_workflow
+
+        self._validate_workflow()
+
+        self._workgraph = WorkGraph(core_workflow.name)
+
+        # stores the input data available on initialization
+        self._aiida_data_nodes: dict[str, aiida_workgraph.orm.Data] = {}
+        # stores the output sockets of tasks
+        self._aiida_socket_nodes: dict[str, TaskSocket] = {}
+        self._aiida_task_nodes: dict[str, aiida_workgraph.Task] = {}
+
+        self._add_aiida_initial_data_nodes()
+        self._add_aiida_task_nodes()
+        self._add_aiida_links()
+
+    def _validate_workflow(self):
+        """Checks that the defined workflow only references valid key names."""
+        for cycle in self._core_workflow.unrolled_cycles:
+            try:
+                aiida.common.validate_link_label(cycle.name)
+            except ValueError as exception:
+                msg = f"Raised error when validating cycle name '{cycle.name}': {exception.args[0]}"
+                raise ValueError(msg) from exception
+            for task in cycle.unrolled_tasks:
+                try:
+                    aiida.common.validate_link_label(task.name)
+                except ValueError as exception:
+                    msg = f"Raised error when validating task name '{task.name}': {exception.args[0]}"
+                    raise ValueError(msg) from exception
+                for input_ in task.unrolled_inputs:
+                    try:
+                        aiida.common.validate_link_label(input_.name)
+                    except ValueError as exception:
+                        msg = f"Raised error when validating input name '{input_.name}': {exception.args[0]}"
+                        raise ValueError(msg) from exception
+                for output in task.unrolled_outputs:
+                    try:
+                        aiida.common.validate_link_label(output.name)
+                    except ValueError as exception:
+                        msg = f"Raised error when validating output name '{output.name}': {exception.args[0]}"
+                        raise ValueError(msg) from exception
+        # TODO: warn if an output node is overwritten by another task before it is used, if that can actually happen.
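+        # Illustrative note (assumption about AiiDA's behaviour, not checked here):
+        # ``validate_link_label`` only accepts alphanumeric/underscore labels, so
+        # characters such as '-', ':' or spaces in names and dates are normalised
+        # through ``parse_to_aiida_label`` below, e.g.
+        #   parse_to_aiida_label("icon_output_2026-01-01 00:00:00")
+        #   == "icon_output_2026_01_01_00_00_00"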
+ + def _add_aiida_initial_data_nodes(self): + """ + Nodes that correspond to data that are available on initialization of the workflow + """ + for cycle in self._core_workflow.unrolled_cycles: + for task in cycle.unrolled_tasks: + for input_ in task.unrolled_inputs: + if self._core_workflow.is_available_on_init(input_): + self._add_aiida_input_data_node(input_, task.plugin) + + @staticmethod + def parse_to_aiida_label(label: str) -> str: + return label.replace("-", "_").replace(" ", "_").replace(":", "_") + + @staticmethod + def get_aiida_label_from_unrolled_data(data: core.UnrolledData) -> str: + """ """ + return AiidaWorkGraph.parse_to_aiida_label(f"{data.name}_{data.unrolled_date}") + + @staticmethod + def get_aiida_label_from_unrolled_task(task: core.UnrolledTask) -> str: + """ """ + return AiidaWorkGraph.parse_to_aiida_label( + f"{task.unrolled_cycle.name}_" f"{task.unrolled_cycle.unrolled_date}_" f"{task.name}" + ) + + def _add_aiida_input_data_node(self, input_: core.UnrolledData, plugin_entry_point: str): + """ + Create an :class:`aiida.orm.Data` instance from this wc data instance. + + :param input: ... + """ + if plugin_entry_point == "shell": + orm_class = OrmUtils.convert_to_class(input_.type) + else: + # TODO catch if entry_point does not exist + workflow_class = CalculationFactory(plugin_entry_point) + + if (port := workflow_class.spec().inputs.get(input_.port_name, None)) is None: + raise ValueError(f"Port name {input_.port_name} for input {input_} does not exist. Here a list of supported port name {workflow_class.spec().inputs.keys()}") + + orm_class = OrmUtils.convert_to_class(input_.type) + + # Otherwise port.valid_type might be non-iterable + valid_types = port.valid_type if isinstance(port.valid_type, (list, tuple)) else (port.valid_type,) + if orm_class not in valid_types: + raise ValueError(f"For input {input_} the type {input_.type} was translated to {orm_class} and is not a supported type for port {input_.port_name}. 
Supported types are {valid_types}.") + + label = AiidaWorkGraph.get_aiida_label_from_unrolled_data(input_) + if orm_class is aiida.orm.FolderData: + orm_instance = aiida.orm.FolderData(tree=input_.src, label=label) + elif orm_class is aiida.orm.RemoteData: + # TODO: Add load or create logic for RemoteData from Rico's aiida-icon-clm repo here + computer = aiida.orm.load_computer(self._core_workflow.computer) + orm_instance = aiida.orm.RemoteData(remote_path=input_.src, label=label, computer=computer) + + else: + orm_instance = orm_class(input_.src, label=label) + + self._aiida_data_nodes[label] = orm_instance + + def _add_aiida_task_nodes(self): + for cycle in self._core_workflow.unrolled_cycles: + + for task in cycle.unrolled_tasks: + if task.plugin == "shell": + self._add_aiida_task_node(task) + else: + self._add_aiida_plugin_task_node(task) + # after creation we can link the wait_on tasks + for cycle in self._core_workflow.unrolled_cycles: + for task in cycle.unrolled_tasks: + self._link_wait_on_to_task(task) + + def _add_aiida_task_node(self, task: core.UnrolledTask): + label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task) + workgraph_task = self._workgraph.tasks.new( + "ShellJob", + name=label, + command=task.command, + # TODO pickling error + #metadata={"computer": aiida.orm.load_computer(task.computer)}, + #computer=task.computer, + #metadata={"computer": task.computer}, + #code=task.code, + code=aiida.orm.load_code(task.code), + ) + workgraph_task.set({"arguments": []}) + workgraph_task.set({"nodes": {}}) + self._aiida_task_nodes[label] = workgraph_task + + def _link_wait_on_to_task(self, task: core.UnrolledTask): + label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task) + workgraph_task = self._aiida_task_nodes[label] + wait_on_tasks = [] + for depend in task.unrolled_depends: + wait_on_task_label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(depend.depend_on_task) + wait_on_tasks.append(self._aiida_task_nodes[wait_on_task_label]) + workgraph_task.wait = wait_on_tasks + + def _add_aiida_plugin_task_node(self, task: core.UnrolledTask): + label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task) + workflow_class = CalculationFactory(task.plugin) + + full_label = f'{task.code}@{task.computer}' + + workgraph_task = self._workgraph.tasks.new( + workflow_class, + name=label, + #metadata={"computer": aiida.orm.load_computer(task.computer)}, + code=aiida.orm.load_code(full_label), + #computer=task.computer, + #metadata={"computer": task.computer}, + #code=task.code, + ) + self._aiida_task_nodes[label] = workgraph_task + + def _add_aiida_links(self): + for cycle in self._core_workflow.unrolled_cycles: + self._add_aiida_links_from_cycle(cycle) + + def _add_aiida_links_from_cycle(self, cycle: core.UnrolledCycle): + for task in cycle.unrolled_tasks: + for input_ in task.unrolled_inputs: + if task.plugin == "shell": + self._link_input_to_task(input_) + else: + self._plugin_link_input_to_task(input_) + for output in task.unrolled_outputs: + if task.plugin == "task": + self._link_output_to_task(output) + else: + self._plugin_link_output_to_task(output) + + def _link_input_to_task(self, input_: core.UnrolledData): + task_label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(input_.unrolled_task) + input_label = AiidaWorkGraph.get_aiida_label_from_unrolled_data(input_) + workgraph_task = self._aiida_task_nodes[task_label] + workgraph_task.inputs.new("Any", f"nodes.{input_label}") + workgraph_task.kwargs.append(f"nodes.{input_label}") + + # resolve data + if 
(data_node := self._aiida_data_nodes.get(input_label)) is not None:
+            if (nodes := workgraph_task.inputs.get("nodes")) is None:
+                msg = f"Workgraph task {workgraph_task.name!r} did not initialize input nodes in the workgraph before linking. This is a bug in the code, please contact the developers by making an issue."
+                raise ValueError(msg)
+            nodes.value.update({f"{input_label}": data_node})
+        elif (output_socket := self._aiida_socket_nodes.get(input_label)) is not None:
+            self._workgraph.links.new(output_socket, workgraph_task.inputs[f"nodes.{input_label}"])
+        else:
+            msg = f"Input data node {input_label!r} was neither found in socket nodes nor in data nodes. The task {task_label!r} must have dependencies on inputs before they are created."
+            raise ValueError(msg)
+
+        # resolve arg_option
+        if (workgraph_task_arguments := workgraph_task.inputs.get("arguments")) is None:
+            msg = f"Workgraph task {workgraph_task.name!r} did not initialize arguments nodes in the workgraph before linking. This is a bug in the code, please contact the developers by making an issue."
+            raise ValueError(msg)
+        if input_.arg_option is not None:
+            workgraph_task_arguments.value.append(f"{input_.arg_option}")
+        workgraph_task_arguments.value.append(f"{{{input_label}}}")
+
+    def _link_output_to_task(self, output: core.UnrolledData):
+        workgraph_task = self._aiida_task_nodes[AiidaWorkGraph.get_aiida_label_from_unrolled_task(output.unrolled_task)]
+        output_label = AiidaWorkGraph.get_aiida_label_from_unrolled_data(output)
+        output_socket = workgraph_task.outputs.new("Any", output.src)
+        self._aiida_socket_nodes[output_label] = output_socket
+
+    def _plugin_link_input_to_task(self, input_: core.UnrolledData):
+        task_label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(input_.unrolled_task)
+        input_label = AiidaWorkGraph.get_aiida_label_from_unrolled_data(input_)
+        workgraph_task = self._aiida_task_nodes[task_label]
+
+        if (data_node := self._aiida_data_nodes.get(input_label)) is not None:
+            workgraph_task.set({input_.port_name: data_node})
+        elif (output_socket := self._aiida_socket_nodes.get(input_label)) is not None:
+            workgraph_task.set({input_.port_name: output_socket})
+        else:
+            msg = f"Input data node {input_label!r} was neither found in socket nodes nor in data nodes. The task {task_label!r} must have dependencies on inputs before they are created."
+ raise ValueError(msg) + + def _plugin_link_output_to_task(self, output: core.UnrolledData): + task_label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(output.unrolled_task) + workgraph_task = self._aiida_task_nodes[task_label] + output_socket = workgraph_task.outputs[output.port_name] + output_label = AiidaWorkGraph.get_aiida_label_from_unrolled_data(output) + self._aiida_socket_nodes[output_label] = output_socket + + def run( + self, + inputs: None | dict[str, Any] = None, + metadata: None | dict[str, Any] = None, + ) -> dict[str, Any]: + a = self._workgraph + return self._workgraph.run(inputs=inputs, metadata=metadata) + + def submit( + self, + *, + inputs: None | dict[str, Any] = None, + wait: bool = False, + timeout: int = 60, + metadata: None | dict[str, Any] = None, + ) -> dict[str, Any]: + return self._workgraph.submit(inputs=inputs, wait=wait, timeout=timeout, metadata=metadata) diff --git a/tests/files/configs/test_config_small.yml b/tests/files/configs/test_config_small.yml index cb7e2ef..ff652c3 100644 --- a/tests/files/configs/test_config_small.yml +++ b/tests/files/configs/test_config_small.yml @@ -2,40 +2,46 @@ start_date: 2026-01-01T00:00 end_date: 2026-06-01T00:00 cycles: - - bimonthly_tasks: - period: P2M + - single_cycle: tasks: - - icon: + - adder1: inputs: - - icon_restart: - arg_option: --restart - lag: -P2M + - a: + port_name: x + - b: + port_name: y outputs: - - icon_output - - icon_restart - - lastly: - tasks: - - cleanup: - depends: - - icon: - date: 2026-05-01T00:00 + - sum1: + port_name: sum + - adder2: + inputs: + - sum1: + port_name: x + - c: + port_name: y + outputs: + - sum2: + port_name: sum tasks: - - icon: - plugin: shell - command: $PWD/tests/files/scripts/icon.py - - postproc: - plugin: shell - command: $PWD/tests/files/scripts/postproc.py - - cleanup: - plugin: shell - command: $PWD/tests/files/scripts/cleanup.py + - adder1: + plugin: core.arithmetic.add + code: bash + computer: localhost + - adder2: + plugin: core.arithmetic.add + code: bash + computer: localhost data: - - icon_input: - type: file - src: $PWD/tests/files/data/input - - icon_output: - type: file - src: output - - icon_restart: - type: file - src: restart + - a: + type: int + src: 5 + - b: + type: int + src: 2 + - c: + type: int + src: 3 + - sum1: + type: int + - sum2: + type: int \ No newline at end of file diff --git a/tests/files/configs/test_config_small_icon_localhost.yml b/tests/files/configs/test_config_small_icon_localhost.yml new file mode 100644 index 0000000..9f72b0f --- /dev/null +++ b/tests/files/configs/test_config_small_icon_localhost.yml @@ -0,0 +1,70 @@ +--- +start_date: 2026-01-01T00:00 +end_date: 2026-06-01T00:00 +computer: localhost +cycles: + - single_cycle: + # period: P2M + tasks: + - icon_task: + inputs: + - icon_master_namelist: + port_name: master_namelist + - icon_model_namelist: + port_name: model_namelist + - grid_file: + port_name: dynamics_grid_file + - ecrad_data: + port_name: ecrad_data + - cloud_data: + port_name: cloud_opt_props + - wet_whatever: + port_name: dmin_wetgrowth_lookup + - rrtmg_sw: + port_name: rrtmg_sw + outputs: + - latest_restart_file: + port_name: latest_restart_file + - finish_status: + port_name: finish_status +tasks: + - icon_task: + plugin: aiida_icon.icon + code: icon + computer: localhost + # icon_master_namelist: + # src: /home/geiger_j/aiida_projects/aiida-icon-clm/git-repos/aiida-icon/examples/exclaim_R02B04/icon_master.namelist + # modify: + # model_min_rank: 2 + +data: + - icon_master_namelist: + src: 
/home/geiger_j/aiida_projects/aiida-icon-clm/git-repos/aiida-icon/examples/exclaim_R02B04/icon_master.namelist + type: file + # modify: + # model_min_rank: 2 + - icon_model_namelist: + src: /home/geiger_j/aiida_projects/aiida-icon-clm/git-repos/aiida-icon/tests/data/simple_icon_run/inputs/model.namelist + type: file + - grid_file: + src: /home/geiger_j/aiida_projects/aiida-icon-clm/git-repos/aiida-icon/tests/data/simple_icon_run/inputs/icon_grid_simple.nc + type: remote + - ecrad_data: + src: /home/geiger_j/aiida_projects/aiida-icon-clm/git-repos/aiida-icon/tests/data/simple_icon_run/inputs/ecrad_data.nc + type: remote + - cloud_data: + src: /home/geiger_j/aiida_projects/aiida-icon-clm/git-repos/aiida-icon/tests/data/simple_icon_run/inputs/ECHAM6_CldOptProps.nc + type: remote + - wet_whatever: + src: /home/geiger_j/aiida_projects/aiida-icon-clm/git-repos/aiida-icon/tests/data/simple_icon_run/inputs/dmin_wetgrowth_lookup.nc + type: remote + - rrtmg_sw: + src: /home/geiger_j/aiida_projects/aiida-icon-clm/git-repos/aiida-icon/tests/data/simple_icon_run/inputs/rrtmg_sw.nc + type: remote + - latest_restart_file: + type: file + src: . + - finish_status: + type: file + src: . + diff --git a/tests/test_wc_workflow.py b/tests/test_wc_workflow.py index a1ba38f..721dcfb 100644 --- a/tests/test_wc_workflow.py +++ b/tests/test_wc_workflow.py @@ -1,15 +1,28 @@ +import aiida import pytest from wcflow.parsing import load_workflow_config +from wcflow.workgraph import AiidaWorkGraph +aiida.load_profile() @pytest.fixture def config_file_small(): return "files/configs/" -@pytest.mark.parametrize( - "config_file", ["tests/files/configs/test_config_small.yml", "tests/files/configs/test_config_large.yml"] -) # , "tests/files/configs/test_config_large.yml"]) -def test_parse_config_file(config_file): - load_workflow_config(config_file) +#@pytest.mark.parametrize( +# "config_file", ["tests/files/configs/test_config_small.yml", "tests/files/configs/test_config_large.yml"] +#) +#def test_convert_config_file(config_file): +# config_workflow = load_workflow_config(config_file) +# core_workflow = config_workflow.to_core_workflow() +# AiidaWorkGraph(core_workflow) + + +@pytest.mark.parametrize("config_file", ["tests/files/configs/test_config_small.yml"]) +def test_run_config_file(config_file): + config_workflow = load_workflow_config(config_file) + core_workflow = config_workflow.to_core_workflow() + aiida_workflow = AiidaWorkGraph(core_workflow) + aiida_workflow.run()
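+
+
+# Illustrative sketch (not executed here): the same workflow can also be submitted
+# to the AiiDA daemon instead of run blockingly. This assumes a configured profile
+# (e.g. created with ``verdi presto``) and that the codes referenced in the config,
+# such as ``bash@localhost``, already exist:
+#
+#   config_workflow = load_workflow_config("tests/files/configs/test_config_small.yml")
+#   core_workflow = config_workflow.to_core_workflow()
+#   aiida_workflow = AiidaWorkGraph(core_workflow)
+#   aiida_workflow.submit(wait=True, timeout=600)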