Skip to content

Commit

Permalink
fix loading pythonjob task:
Browse files Browse the repository at this point in the history
1) serialize and deserialize the properties when saving and loading the task
2) update the outputs of the task when loading the task.
  • Loading branch information
superstar54 committed Sep 10, 2024
1 parent b56624a commit aa38666
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 27 deletions.
14 changes: 13 additions & 1 deletion aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,24 @@ def get_task(self, name: str):
task = PythonJob.from_dict(self.ctx._tasks[name])
else:
task = Task.from_dict(self.ctx._tasks[name])
# update task results
for output in task.outputs:
output.value = get_nested_dict(
self.ctx._tasks[name]["results"],
output.name,
default=output.value,
)
return task

def update_task(self, task: Task):
"""Update task in the context.
This is used in error handlers to update the task parameters."""
self.ctx._tasks[task.name]["properties"] = task.properties_to_dict()
from aiida_workgraph.utils import serialize_pythonjob_properties

tdata = task.to_dict()
if task.identifier.upper() == "PYTHONJOB":
serialize_pythonjob_properties(tdata)
self.ctx._tasks[task.name]["properties"] = tdata["properties"]
self.reset_task(task.name)

def get_task_state_info(self, name: str, key: str) -> str:
Expand Down
6 changes: 5 additions & 1 deletion aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,15 @@ def from_dict(cls, data: Dict[str, Any], task_pool: Optional[Any] = None) -> "Ta
Returns:
Node: An instance of Node initialized with the provided data."""
from aiida_workgraph.tasks import task_pool
from aiida.orm.utils.serialize import deserialize_unsafe

task = super().from_dict(data, node_pool=task_pool)
task.context_mapping = data.get("context_mapping", {})
task.waiting_on.add(data.get("wait", []))
task.process = data.get("process", None)
process = data.get("process", None)
if process and isinstance(process, str):
process = deserialize_unsafe(process)
task.process = process

return task

Expand Down
46 changes: 28 additions & 18 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,12 @@ def get_pythonjob_data(tdata: Dict[str, Any]) -> Dict[str, Any]:
if name == "computer":
break
if name in tdata["properties"]:
tdata["properties"][name]["value"] = tdata["properties"][name][
"value"
].value
value = tdata["properties"][name]["value"]
if isinstance(value, orm.Data):
value = value.value
elif value != {}:
raise ValueError(f"There something wrong with the input {name}")
tdata["properties"][name]["value"] = value
return tdata


Expand Down Expand Up @@ -382,6 +385,27 @@ def get_or_create_code(
return code


def serialize_pythonjob_properties(task: Dict[str, Any]) -> None:
    """Serialize the keyword-argument properties of a PythonJob task in place.

    The keyword arguments of a PythonJob are the entries of ``task["inputs"]``
    listed before the ``_wait`` input. For each of them, the matching
    ``task["properties"][name]["value"]`` is replaced by
    ``general_serializer(value)``, unless the value is ``None`` or an empty
    dict, which are left untouched.

    Args:
        task: Task data with an ``"inputs"`` list (each item having a
            ``"name"``) and a ``"properties"`` mapping whose entries hold a
            ``"value"``. Modified in place.

    Raises:
        KeyError: If an input listed before ``_wait`` has no entry in
            ``task["properties"]``.
    """
    from aiida_workgraph.orm.serializer import general_serializer

    # Gather the names first, so nothing is serialized if the input list
    # turns out to be malformed part-way through.
    kwarg_names = []
    for entry in task["inputs"]:
        if entry["name"] == "_wait":
            break
        kwarg_names.append(entry["name"])

    for name in kwarg_names:
        prop = task["properties"][name]
        value = prop["value"]
        if value is None:
            continue
        # The isinstance guard keeps the equality test restricted to dicts.
        if isinstance(value, dict) and value == {}:
            continue
        prop["value"] = general_serializer(value)


def serialize_properties(wgdata):
"""Serialize the properties.
Because we use yaml (aiida's serialize) to serialize the data and
Expand All @@ -390,27 +414,13 @@ def serialize_properties(wgdata):
So, if a function is used as input, we need to serialize the function.
For PythonJob, serialize the function inputs."""
from aiida_workgraph.orm.serializer import general_serializer
from aiida_workgraph.orm.function_data import PickledLocalFunction
import inspect

for _, task in wgdata["tasks"].items():
if task["metadata"]["node_type"].upper() == "PYTHONJOB":
# get the names kwargs for the PythonJob, which are the inputs before _wait
input_kwargs = []
for input in task["inputs"]:
if input["name"] == "_wait":
break
input_kwargs.append(input["name"])
for name in input_kwargs:
prop = task["properties"][name]
# if value is not None, not {}
if not (
prop["value"] is None
or isinstance(prop["value"], dict)
and prop["value"] == {}
):
prop["value"] = general_serializer(prop["value"])
serialize_pythonjob_properties(task)
else:
for _, prop in task["properties"].items():
if inspect.isfunction(prop["value"]):
Expand Down
15 changes: 8 additions & 7 deletions tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,12 +490,13 @@ def add(x: str, y: str) -> str:

def test_exit_code(fixture_localhost, python_executable_path):
"""Test function with exit code."""
from numpy import array

@task.pythonjob(outputs=[{"name": "sum"}])
def add(x: int, y: int) -> int:
def add(x: array, y: array) -> array:
sum = x + y
if sum < 0:
exit_code = {"status": 410, "message": "Sum is negative"}
if (sum < 0).any():
exit_code = {"status": 410, "message": "Some elements are negative"}
return {"sum": sum, "exit_code": exit_code}
return {"sum": sum}

Expand All @@ -515,8 +516,8 @@ def handle_negative_sum(self, task_name: str):
wg.add_task(
add,
name="add1",
x=1,
y=-2,
x=array([1, 1]),
y=array([1, -2]),
computer="localhost",
code_label=python_executable_path,
)
Expand All @@ -531,8 +532,8 @@ def handle_negative_sum(self, task_name: str):
assert wg.process.base.links.get_outgoing().all()[0].node.exit_status == 410
assert (
wg.process.base.links.get_outgoing().all()[0].node.exit_message
== "Sum is negative"
== "Some elements are negative"
)
# the final task should have exit status 0
assert wg.tasks["add1"].node.exit_status == 0
assert wg.tasks["add1"].outputs["sum"].value.value == 3
assert (wg.tasks["add1"].outputs["sum"].value.value == array([2, 3])).all()

0 comments on commit aa38666

Please sign in to comment.