Skip to content

Commit

Permalink
Resume worktree when an available is finished, add set_from_protocol (#8
Browse files Browse the repository at this point in the history
)

- Execution order: In WorkChain, only when all awaitables are finished, does the process resume, and run the next step. In WorkTree, we want to launch a node when its incoming nodes are ready, thus as long as an awaitable (a sub-process) is finished, we resume the process, and check if any nodes are ready to run.
- Protocol: add set_from_protocol for node, only for the ProtocolMixin method.
  • Loading branch information
superstar54 authored Dec 1, 2023
1 parent bb7ccec commit 24fadb9
Show file tree
Hide file tree
Showing 24 changed files with 2,839 additions and 89 deletions.
2 changes: 1 addition & 1 deletion .github/config/code-dos.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
label: dos-7.2
label: qe-7.2-dos
description: Quantum ESPRESSO dos.x
default_calc_job_plugin: quantumespresso.dos
computer: localhost
Expand Down
2 changes: 1 addition & 1 deletion .github/config/code-projwfc.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
label: projwfc-7.2
label: qe-7.2-projwfc
description: Quantum ESPRESSO projwfc.x
default_calc_job_plugin: quantumespresso.projwfc
computer: localhost
Expand Down
2 changes: 1 addition & 1 deletion .github/config/code-pw.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
label: pw-7.2
label: qe-7.2-pw
description: Quantum ESPRESSO pw.x
default_calc_job_plugin: quantumespresso.pw
computer: localhost
Expand Down
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Here is a detailed comparison between the ``WorkTree`` with two AiiDA built-in w
| ------------------------ | ---------------------- | ---------------------- | ---------------------- |
| Use Case | Short-running jobs | Long-running jobs | Long-running jobs |
| Checkpointing | ``No`` | Yes | Yes |
| Execution order | ``Sequential`` | ``Sequential`` | Directed Acyclic Graph |
| Non-blocking | ``No`` | Yes | Yes |
| Implementation | Easy | ``Difficult`` | Easy |
| Dynamic | ``No`` | ``No`` | Yes |
Expand Down Expand Up @@ -89,8 +90,5 @@ The node graph from the worktree process:
- For the moment, I did not create a `WorkTreeNode` for the `WorkTree` process. I used the `WorkChainNode`, because AiiDA hard codes the `WorkChainNode` for the command (report), graph etc.


## Bugs
- the `report` does not work.

## License
[MIT](http://opensource.org/licenses/MIT)
3 changes: 3 additions & 0 deletions aiida_worktree/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def add_input_recursive(inputs, port, prefix=None):
else:
port_name = f"{prefix}.{port.name}"
if isinstance(port, PortNamespace):
# TODO the default value is {} could cause problem, because the address of the dict is the same,
# so if you change the value of one port, the value of all the ports of other nodes will be changed
# consider to use None as default value
inputs.append(
["General", port_name, {"property": ["General", {"default": {}}]}]
)
Expand Down
63 changes: 42 additions & 21 deletions aiida_worktree/engine/worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def load_instance_state(self, saved_state, load_context):
self.set_logger(self.node.logger)

if self._awaitables:
# this is a new runner, so we need to re-register the callbacks
self.ctx._awaitable_actions = []
self._action_awaitables()

def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]:
Expand Down Expand Up @@ -196,7 +198,6 @@ def _resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None:
:param awaitable: the awaitable to resolve
"""

ctx, key = self._resolve_nested_context(awaitable.key)

if awaitable.action == AwaitableAction.ASSIGN:
Expand Down Expand Up @@ -260,8 +261,10 @@ def _do_step(self) -> t.Any:
If any awaitables were created, the process will enter in the Wait state,
otherwise it will go to Continue.
"""

self._awaitables = []
# we will not remove the awaitables here,
# we resume the worktree in the callback function even
# there are some awaitables left
# self._awaitables = []
result: t.Any = None

try:
Expand Down Expand Up @@ -327,11 +330,15 @@ def _action_awaitables(self) -> None:
call it when the target is completed
"""
for awaitable in self._awaitables:
# if the waitable already has a callback, skip
if awaitable.pk in self.ctx._awaitable_actions:
continue
if awaitable.target == AwaitableTarget.PROCESS:
callback = functools.partial(
self.call_soon, self._on_awaitable_finished, awaitable
)
self.runner.call_on_process_finish(awaitable.pk, callback)
self.ctx._awaitable_actions.append(awaitable.pk)
else:
assert f"invalid awaitable target '{awaitable.target}'"

Expand Down Expand Up @@ -363,13 +370,22 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None:

self._resolve_awaitable(awaitable, value)

if self.state == ProcessState.WAITING and not self._awaitables:
# node finished, update the node state and result
# udpate the node state
self.update_node_state(awaitable.key)
# try to resume the worktree, if the worktree is already resumed
# by other awaitable, this will not work
try:
self.resume()
except Exception as e:
print(e)

def setup(self):
from node_graph.analysis import ConnectivityAnalysis
from aiida_worktree.utils import build_node_link

# track if the awaitable callback is added to the runner
self.ctx._awaitable_actions = []
self.ctx.new_data = dict()
self.ctx.input_nodes = dict()
if "input_file" in self.inputs:
Expand Down Expand Up @@ -473,26 +489,28 @@ def run_worktree(self):
print("node_to_run:", node_to_run)
self.run_nodes(node_to_run)

def update_node_state(self, name):
"""Update ndoe state if node is a Awaitable."""
node = self.ctx.nodes[name]
if (
node["metadata"]["node_type"]
in [
"calcfunction",
"workfunction",
"calcjob",
"workchain",
"worktree",
]
and node["state"] == "RUNNING"
):
self.set_node_result(node)

def is_worktree_finished(self):
"""Check if the worktree is finished.
For `while` worktree, we need check its conditions"""
is_finished = True
# print("is_worktree_finished:")
for name, node in self.ctx.nodes.items():
print(name, node["state"])
# if calc process, and has a process, check process state
if (
node["metadata"]["node_type"]
in [
"calcfunction",
"workfunction",
"calcjob",
"workchain",
"worktree",
]
and node["state"] == "RUNNING"
):
self.set_node_result(node)
# self.update_node_state(name)
if node["state"] in ["RUNNING", "CREATED", "READY"]:
is_finished = False
if is_finished:
Expand Down Expand Up @@ -625,10 +643,11 @@ def run_nodes(self, names):
process = self.submit(executor, *args, **kwargs)
node["process"] = process
self.ctx.nodes[name]["state"] = "RUNNING"
self.to_context(process=process)
self.to_context(**{name: process})
elif node["metadata"]["node_type"] in ["worktree"]:
# process = run_get_node(executor, *args, **kwargs)
from aiida_worktree.utils import merge_properties
from aiida.orm.utils.serialize import serialize

print("node type: worktree.")
wt = self.run_executor(executor, args, kwargs, var_args, var_kwargs)
Expand All @@ -641,10 +660,12 @@ def run_nodes(self, names):
all = {"nt": ntdata, "metadata": {"call_link_label": name}}
print("submit worktree: ")
process = self.submit(self.__class__, **all)
# save the ntdata to the process extras, so that we can load the worktree
process.base.extras.set("nt", serialize(ntdata))
node["process"] = process
# self.ctx.nodes[name]["group_outputs"] = executor.group_outputs
self.ctx.nodes[name]["state"] = "RUNNING"
return self.to_context(process=process)
self.to_context(**{name: process})
elif node["metadata"]["node_type"] in ["Normal"]:
print("node type: Normal.")
# normal function does not have a process
Expand Down
17 changes: 13 additions & 4 deletions aiida_worktree/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,23 @@ def __init__(self, **kwargs):
Initialize a Node instance.
"""
super().__init__(**kwargs)
self.to_ctx = []
self.wait = []
self.to_ctx = None
self.wait = None
self.process = None

def to_dict(self):
ndata = super().to_dict()
ndata["to_ctx"] = self.to_ctx
ndata["wait"] = self.wait
ndata["to_ctx"] = [] if self.to_ctx is None else self.to_ctx
ndata["wait"] = [] if self.wait is None else self.wait
ndata["process"] = self.process.uuid if self.process else None

return ndata

def set_from_protocol(self, *args, **kwargs):
"""For node support protocol, set the node from protocol data."""
from aiida_worktree.utils import get_executor, get_dict_from_builder

executor = get_executor(self.get_executor())[0]
builder = executor.get_builder_from_protocol(*args, **kwargs)
data = get_dict_from_builder(builder)
self.set(data)
2 changes: 0 additions & 2 deletions aiida_worktree/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
AiiDAAdd,
AiiDAGreater,
AiiDASumDiff,
AiiDAArithmeticAdd,
AiiDAArithmeticMultiplyAdd,
)
from .qe import (
Expand All @@ -36,7 +35,6 @@
AiiDAAdd,
AiiDAGreater,
AiiDASumDiff,
AiiDAArithmeticAdd,
AiiDAArithmeticMultiplyAdd,
AiiDAKpoint,
AiiDAPWPseudo,
Expand Down
28 changes: 0 additions & 28 deletions aiida_worktree/nodes/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,34 +246,6 @@ def get_executor(self):
}


class AiiDAArithmeticAdd(Node):

identifier: str = "AiiDAArithmeticAdd"
name = "AiiDAArithmeticAdd"
node_type = "calcjob"
catalog = "Test"
kwargs = ["code", "x", "y"]

def create_properties(self):
pass

def create_sockets(self):
self.inputs.clear()
self.outputs.clear()
self.inputs.new("General", "code")
inp = self.inputs.new("AiiDAInt", "x")
inp.add_property("AiiDAInt", "x", default=0.0)
inp = self.inputs.new("AiiDAInt", "y")
inp.add_property("AiiDAInt", "y", default=0.0)
self.outputs.new("AiiDAInt", "sum")

def get_executor(self):
return {
"path": "aiida.calculations.arithmetic.add",
"name": "ArithmeticAddCalculation",
}


class AiiDAArithmeticMultiplyAdd(Node):

identifier: str = "AiiDAArithmeticMultiplyAdd"
Expand Down
26 changes: 24 additions & 2 deletions aiida_worktree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,15 @@ def get_nested_dict(d, name):

def update_nested_dict(d, key, value):
"""
d = {"base": {"pw": {"parameters": 1}}
d = {}
key = "base.pw.parameters"
value = 2
will give:
d = {"base": {"pw": {"parameters": 2}}
"""
keys = key.split(".")
current = d
current = {} if current is None else current
for k in keys[:-1]:
current = current.setdefault(k, {})
current[keys[-1]] = value
Expand All @@ -78,7 +81,16 @@ def update_nested_dict_with_special_keys(d):


def merge_properties(ntdata):
"""Merge properties."""
"""Merge sub properties to the root properties.
{
"base.pw.parameters": 2,
"base.pw.code": 1,
}
after merge:
{"base": {"pw": {"parameters": 2,
"code": 1}}
So that no "." in the key name.
"""
for name, node in ntdata["nodes"].items():
for key, prop in node["properties"].items():
if "." in key and prop["value"] not in [None, {}]:
Expand Down Expand Up @@ -128,6 +140,16 @@ def build_node_link(ntdata):
from_socket["links"].append(link)


def get_dict_from_builder(builder):
"""Transform builder to pure dict."""
from aiida.engine.processes.builder import ProcessBuilderNamespace

if isinstance(builder, ProcessBuilderNamespace):
return {k: get_dict_from_builder(v) for k, v in builder.items()}
else:
return builder


if __name__ == "__main__":
d = {
"base": {
Expand Down
2 changes: 2 additions & 0 deletions aiida_worktree/worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def run(self):
the process and then calls the update method to update the state of the process.
"""
from aiida_worktree.engine.worktree import WorkTree
from aiida_worktree.utils import merge_properties
from aiida.orm.utils.serialize import serialize

ntdata = self.to_dict()
merge_properties(ntdata)
all = {"nt": ntdata}
_result, self.process = aiida.engine.run_get_node(WorkTree, **all)
self.process.base.extras.set("nt", serialize(ntdata))
Expand Down
Loading

0 comments on commit 24fadb9

Please sign in to comment.