Skip to content

Commit

Permalink
PythonJob supports exit code output (aiidateam#300)
Browse files Browse the repository at this point in the history
This PR adds a built-in output socket, `exit_code`, which serves as a mechanism for error handling and status reporting during task execution. This exit_code is an `aiida.engine.ExitCode` with a status and a message. The status is an integer value where 0 indicates successful completion, and any non-zero value signals that an error occurred.

* Move all function related to PythonJob to the PythonJob Task, so that we can handle serialization and deserialization correctly for PythonJob Task
* Update the outputs of the task when loading the task in the error handler, so that we can use the outputs to update the inputs for the next run.
  • Loading branch information
GeigerJ2 committed Sep 13, 2024
1 parent 15cd36d commit 0e281d0
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 64 deletions.
13 changes: 12 additions & 1 deletion aiida_workgraph/calculations/python_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Parser for an `PythonJob` job."""
from aiida.parsers.parser import Parser
from aiida_workgraph.orm import general_serializer
from aiida.engine import ExitCode


class PythonParser(Parser):
Expand Down Expand Up @@ -31,13 +32,14 @@ def parse(self, **kwargs):
"remote_folder",
"remote_stash",
"retrieved",
"exit_code",
]
]
# first we remove nested outputs, e.g., "add_multiply.add"
top_level_output_list = [
output for output in self.output_list if "." not in output["name"]
]

exit_code = 0
try:
with self.retrieved.base.repository.open("results.pickle", "rb") as handle:
results = pickle.load(handle)
Expand All @@ -49,6 +51,8 @@ def parse(self, **kwargs):
results[i], top_level_output_list[i]
)
elif isinstance(results, dict) and len(top_level_output_list) > 1:
# pop the exit code if it exists
exit_code = results.pop("exit_code", 0)
for output in top_level_output_list:
if output.get("required", False):
if output["name"] not in results:
Expand All @@ -62,6 +66,7 @@ def parse(self, **kwargs):
f"Found extra results that are not included in the output: {results.keys()}"
)
elif isinstance(results, dict) and len(top_level_output_list) == 1:
exit_code = results.pop("exit_code", 0)
# if output name in results, use it
if top_level_output_list[0]["name"] in results:
top_level_output_list[0]["value"] = self.serialize_output(
Expand All @@ -84,6 +89,12 @@ def parse(self, **kwargs):
)
for output in top_level_output_list:
self.out(output["name"], output["value"])
if exit_code:
if isinstance(exit_code, dict):
exit_code = ExitCode(exit_code["status"], exit_code["message"])
elif isinstance(exit_code, int):
exit_code = ExitCode(exit_code)
return exit_code
except OSError:
return self.exit_codes.ERROR_READING_OUTPUT_FILE
except ValueError as exception:
Expand Down
4 changes: 4 additions & 0 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def build_task_from_AiiDA(
def build_pythonjob_task(func: Callable) -> Task:
"""Build PythonJob task from function."""
from aiida_workgraph.calculations.python import PythonJob
from aiida_workgraph.tasks.pythonjob import PythonJob as PythonJobTask
from copy import deepcopy

# if the function is not a task, build a task from the function
Expand Down Expand Up @@ -310,6 +311,7 @@ def build_pythonjob_task(func: Callable) -> Task:
for output in tdata_py["outputs"]:
if output not in outputs:
outputs.append(output)
outputs.append({"identifier": "workgraph.any", "name": "exit_code"})
# change "copy_files" link_limit to 1e6
for input in inputs:
if input["name"] == "copy_files":
Expand All @@ -322,6 +324,8 @@ def build_pythonjob_task(func: Callable) -> Task:
tdata["outputs"] = outputs
tdata["kwargs"] = kwargs
tdata["task_type"] = "PYTHONJOB"
tdata["identifier"] = "workgraph.pythonjob"
tdata["node_class"] = PythonJobTask
task = create_task(tdata)
task.is_aiida_component = True
return task, tdata
Expand Down
10 changes: 9 additions & 1 deletion aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,12 +532,20 @@ def update_workgraph_from_base(self) -> None:
def get_task(self, name: str):
    """Reconstruct a Task object from the stored context data.

    The task's output sockets are refreshed from the results recorded in
    the context, so callers (e.g. error handlers) see current values."""
    tdata = self.ctx._tasks[name]
    task = Task.from_dict(tdata)
    # sync each output socket with any result recorded for this task
    for socket in task.outputs:
        socket.value = get_nested_dict(
            tdata["results"],
            socket.name,
            default=socket.value,
        )
    return task

def update_task(self, task: Task):
    """Write a task's properties back into the context and reset it.

    Used by error handlers to apply modified task parameters before the
    task is run again."""
    serialized = task.to_dict()
    self.ctx._tasks[task.name]["properties"] = serialized["properties"]
    self.reset_task(task.name)

def get_task_state_info(self, name: str, key: str) -> str:
Expand Down
8 changes: 6 additions & 2 deletions aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,15 @@ def from_dict(cls, data: Dict[str, Any], task_pool: Optional[Any] = None) -> "Ta
Returns:
Node: An instance of Node initialized with the provided data."""
from aiida_workgraph.tasks import task_pool
from aiida.orm.utils.serialize import deserialize_unsafe

task = super().from_dict(data, node_pool=task_pool)
task = GraphNode.from_dict(data, node_pool=task_pool)
task.context_mapping = data.get("context_mapping", {})
task.waiting_on.add(data.get("wait", []))
task.process = data.get("process", None)
process = data.get("process", None)
if process and isinstance(process, str):
process = deserialize_unsafe(process)
task.process = process

return task

Expand Down
69 changes: 69 additions & 0 deletions aiida_workgraph/tasks/pythonjob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Any, Dict
from aiida import orm
from aiida_workgraph.orm.serializer import general_serializer
from aiida_workgraph.task import Task


class PythonJob(Task):
    """PythonJob Task.

    Serializes the wrapped function's inputs to AiiDA ``Data`` nodes when the
    task is dumped, and loads the raw Python values back when the task data is
    deserialized.
    """

    identifier = "workgraph.pythonjob"

    @classmethod
    def get_function_kwargs(cls, data: Dict[str, Any]) -> set:
        """Return the names of the wrapped function's keyword inputs.

        All kwargs listed before ``computer`` are inputs of the function
        itself and are expected to hold AiiDA Data nodes.
        """
        input_kwargs = set()
        for name in data["metadata"]["kwargs"]:
            # kwargs from "computer" onwards configure the PythonJob itself,
            # not the wrapped function
            if name == "computer":
                break
            input_kwargs.add(name)
        return input_kwargs

    # NOTE(review): the first parameter is named ``cls`` but there is no
    # @classmethod decorator — confirm against the base class whether this
    # is meant to be a classmethod or an instance method.
    def update_from_dict(cls, data: Dict[str, Any], **kwargs) -> "PythonJob":
        """Overwrite the update_from_dict method to handle the PythonJob data.

        Deserializes the function inputs in ``data`` before delegating to the
        base implementation."""
        cls.deserialize_pythonjob_data(data)
        return super().update_from_dict(data)

    def to_dict(self) -> Dict[str, Any]:
        """Dump the task, serializing the function inputs to AiiDA Data nodes."""
        data = super().to_dict()
        self.serialize_pythonjob_data(data)
        return data

    @classmethod
    def serialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None:
        """Serialize the function-input properties for PythonJob, in place."""
        input_kwargs = cls.get_function_kwargs(tdata)
        for name in input_kwargs:
            prop = tdata["properties"][name]
            # skip unset values (None, or the empty-dict placeholder)
            if not (
                prop["value"] is None
                or isinstance(prop["value"], dict)
                and prop["value"] == {}
            ):
                prop["value"] = general_serializer(prop["value"])

    @classmethod
    def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None:
        """
        Process the task data dictionary for a PythonJob.

        Loads the original Python data from the AiiDA Data node for the
        args and kwargs of the function, mutating ``tdata`` in place.

        Args:
            tdata (Dict[str, Any]): The input data dictionary.
        Raises:
            ValueError: If a function input is set but is not an AiiDA Data
                node.
        """
        input_kwargs = cls.get_function_kwargs(tdata)

        for name in input_kwargs:
            if name in tdata["properties"]:
                value = tdata["properties"][name]["value"]
                if isinstance(value, orm.Data):
                    value = value.value
                elif value is not None and value != {}:
                    raise ValueError(
                        f"There is something wrong with the input {name}"
                    )
                tdata["properties"][name]["value"] = value
61 changes: 4 additions & 57 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,38 +218,6 @@ def get_dict_from_builder(builder: Any) -> Dict:
return builder


def get_pythonjob_data(tdata: Dict[str, Any]) -> Dict[str, Any]:
    """
    Process the task data dictionary for a PythonJob.

    Loads the original Python data from the AiiDA Data node for the
    args and kwargs of the function, mutating ``tdata`` in place.

    Args:
        tdata (Dict[str, Any]): The input data dictionary.
    Returns:
        Dict[str, Any]: The processed data dictionary (same object, mutated).
    """

    def _unwrap(name: str) -> None:
        # Replace a stored AiiDA Data node with its raw Python value.
        # Check membership BEFORE subscripting to avoid a KeyError on
        # inputs that have no property entry.
        prop = tdata["properties"].get(name)
        if prop is not None and prop["value"] is not None:
            prop["value"] = prop["value"].value

    for name in tdata["metadata"]["args"]:
        _unwrap(name)
    for name in tdata["metadata"]["kwargs"]:
        # kwargs from "computer" onwards configure the job itself, not the
        # wrapped function; stop before touching them (even when the value
        # of "computer" is None).
        if name == "computer":
            break
        _unwrap(name)
    return tdata


def serialize_workgraph_data(wgdata: Dict[str, Any]) -> Dict[str, Any]:
from aiida.orm.utils.serialize import serialize

Expand All @@ -270,8 +238,6 @@ def get_workgraph_data(process: Union[int, orm.Node]) -> Optional[Dict[str, Any]
return
for name, task in wgdata["tasks"].items():
wgdata["tasks"][name] = deserialize_unsafe(task)
if wgdata["tasks"][name]["metadata"]["node_type"].upper() == "PYTHONJOB":
get_pythonjob_data(wgdata["tasks"][name])
wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
return wgdata

Expand Down Expand Up @@ -388,33 +354,14 @@ def serialize_properties(wgdata):
save it to the node.base.extras. yaml can not handle the function
defined in a scope, e.g., local function in another function.
So, if a function is used as input, we need to serialize the function.
For PythonJob, serialize the function inputs."""
from aiida_workgraph.orm.serializer import general_serializer
"""
from aiida_workgraph.orm.function_data import PickledLocalFunction
import inspect

for _, task in wgdata["tasks"].items():
if task["metadata"]["node_type"].upper() == "PYTHONJOB":
# get the names kwargs for the PythonJob, which are the inputs before _wait
input_kwargs = []
for input in task["inputs"]:
if input["name"] == "_wait":
break
input_kwargs.append(input["name"])
for name in input_kwargs:
prop = task["properties"][name]
# if value is not None, not {}
if not (
prop["value"] is None
or isinstance(prop["value"], dict)
and prop["value"] == {}
):
prop["value"] = general_serializer(prop["value"])
else:
for _, prop in task["properties"].items():
if inspect.isfunction(prop["value"]):
prop["value"] = PickledLocalFunction(prop["value"]).store()
for _, prop in task["properties"].items():
if inspect.isfunction(prop["value"]):
prop["value"] = PickledLocalFunction(prop["value"]).store()


def generate_bash_to_create_python_env(
Expand Down
62 changes: 59 additions & 3 deletions docs/source/built-in/pythonjob.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"id": "c6b83fb5",
"metadata": {},
"outputs": [
Expand All @@ -33,7 +33,7 @@
"Profile<uuid='57ccbf7d9e2b41b39edb2bfdaf725feb' name='default'>"
]
},
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -2368,14 +2368,70 @@
"id": "fe376995",
"metadata": {},
"source": [
"We can see that the `result.txt` file is retrieved from the remote computer and stored in the local repository."
"We can see that the `result.txt` file is retrieved from the remote computer and stored in the local repository.\n",
"\n",
"## Exit Code\n",
"\n",
"The `PythonJob` task includes a built-in output socket, `exit_code`, which serves as a mechanism for error handling and status reporting during task execution. This `exit_code` is an integer value where `0` indicates a successful completion, and any non-zero value signals that an error occurred.\n",
"\n",
"### How it Works:\n",
"When the function returns a dictionary with an `exit_code` key, the system automatically parses and uses this code to indicate the task's status. In the case of an error, the non-zero `exit_code` value helps identify the specific problem.\n",
"\n",
"\n",
"### Benefits of `exit_code`:\n",
"\n",
"1. **Error Reporting:** \n",
" If the task encounters an error, the `exit_code` can communicate the reason. This is helpful during process inspection to determine why a task failed.\n",
"\n",
"2. **Error Handling and Recovery:** \n",
" You can utilize `exit_code` to add specific error handlers for particular exit codes. This allows you to modify the task's parameters and restart it.\n",
"\n",
"\n",
"Below is an example Python function that uses `exit_code` to handle potential errors:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a96cbbcb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WorkGraph process created, PK: 146751\n",
"exit status: 1\n",
"exit message: Sum is negative\n"
]
}
],
"source": [
"from aiida_workgraph import WorkGraph, task\n",
"\n",
"@task.pythonjob(outputs=[{\"name\": \"sum\"}])\n",
"def add(x: int, y: int) -> int:\n",
" sum = x + y\n",
" if sum < 0:\n",
" exit_code = {\"status\": 1, \"message\": \"Sum is negative\"}\n",
" return {\"sum\": sum, \"exit_code\": exit_code}\n",
" return {\"sum\": sum}\n",
"\n",
"wg = WorkGraph(\"test_PythonJob\")\n",
"wg.add_task(add, name=\"add\", x=1, y=-2)\n",
"wg.submit(wait=True)\n",
"\n",
"print(\"exit status: \", wg.tasks[\"add\"].node.exit_status)\n",
"print(\"exit message: \", wg.tasks[\"add\"].node.exit_message)"
]
},
{
"cell_type": "markdown",
"id": "8d4d935b",
"metadata": {},
"source": [
"In this example, the task failed with `exit_code = 1` due to the condition `Sum is negative`, which is also reflected in the state message.\n",
"\n",
"## Define your data serializer\n",
"Workgraph search data serializer from the `aiida.data` entry point by the module name and class name (e.g., `ase.atoms.Atoms`). \n",
"\n",
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
"workgraph.test_greater" = "aiida_workgraph.tasks.test:TestGreater"
"workgraph.test_sum_diff" = "aiida_workgraph.tasks.test:TestSumDiff"
"workgraph.test_arithmetic_multiply_add" = "aiida_workgraph.tasks.test:TestArithmeticMultiplyAdd"
"workgraph.pythonjob" = "aiida_workgraph.tasks.pythonjob:PythonJob"

[project.entry-points."aiida_workgraph.property"]
"workgraph.any" = "aiida_workgraph.properties.builtins:PropertyAny"
Expand Down
Loading

0 comments on commit 0e281d0

Please sign in to comment.