Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ctx #1

Merged
merged 3 commits into from
Jul 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ Here is a detailed comparison between the ``WorkTree`` with two AiiDA built-in w
| Implementation | Easy | ``Difficult`` | Easy |
| Dynamic | ``No`` | ``No`` | Yes |
| Ready to Use | Yes | ``No``,Need PYTHONPATH | Yes |
| Flow Control | All | `if`, `while` | `if` |
| Subprocesses Handling | ``No`` | Launches & waits | Launches & waits |
| Flow Control | All | `if`, `while` | `if`, `while` |
| Termination | ``Hard exit`` | ExitCode | ExitCode |
| Capabilities | Calls calcs and works | Calls any process | Calls any process |
| Data Passing | Direct passing | Context dictionary | Link |
Expand Down
89 changes: 81 additions & 8 deletions aiida_worktree/engine/worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def setup(self):

# ntdata = jsonref.JsonRef.replace_refs(tntdata, loader = JsonYamlLoader())
build_node_link(ntdata)
self.init_ctx(ntdata["ctx"])
self.ctx.nodes = ntdata["nodes"]
self.ctx.links = ntdata["links"]
self.ctx.ctrl_links = ntdata["ctrl_links"]
Expand All @@ -400,9 +401,23 @@ def setup(self):
self.ctx.connectivity = nc.build_connectivity()
self.ctx.msgs = []
self.node.set_process_label(f"WorkTree: {self.ctx.worktree['name']}")
# while worktree
if self.ctx.worktree["is_while"]:
should_run = self.check_while_conditions()
if not should_run:
self.set_node_state(self.ctx.nodes.keys(), "SKIPPED")

def init_ctx(self, datas):
for key, value in datas.items():
self.ctx[key] = value

def launch_worktree(self):
print("launch_worktree: ")
self.report("Lanch worktree.")
if len(self.ctx.worktree["starts"]) > 0:
self.run_nodes(self.ctx.worktree["starts"])
self.ctx.worktree["starts"] = []
return
node_to_run = []
for name, node in self.ctx.nodes.items():
# update node state
Expand All @@ -416,7 +431,9 @@ def launch_worktree(self):
self.run_nodes(node_to_run)

def is_worktree_finished(self):
flag = True
"""Check if the worktree is finished.
For `while` worktree, we need check its conditions"""
is_finished = True
# print("is_worktree_finished:")
for name, node in self.ctx.nodes.items():
print(name, node["state"])
Expand Down Expand Up @@ -447,6 +464,7 @@ def is_worktree_finished(self):
node["results"] = node["process"].outputs
# self.ctx.new_data[name] = node["results"]
self.ctx.nodes[name]["state"] = "FINISHED"
self.node_to_ctx(name)
print(f"Node: {name} finished.")
elif state == "EXCEPTED":
node["state"] = state
Expand All @@ -459,8 +477,27 @@ def is_worktree_finished(self):
)
print(f"Node: {name} failed.")
if node["state"] in ["RUNNING", "CREATED", "READY"]:
flag = False
return flag
is_finished = False
if is_finished:
if self.ctx.worktree["is_while"]:
should_run = self.check_while_conditions()
is_finished = not should_run
return is_finished

def check_while_conditions(self):
print("Is a while worktree")
condition_nodes = [c[0] for c in self.ctx.worktree["conditions"]]
self.run_nodes(condition_nodes)
conditions = [
self.ctx.nodes[c[0]]["results"][c[1]]
for c in self.ctx.worktree["conditions"]
]
print("conditions: ", conditions)
should_run = False not in conditions
if should_run:
self.reset()
self.set_node_state(condition_nodes, "SKIPPED")
return should_run

def run_nodes(self, names):
"""Run node
Expand Down Expand Up @@ -501,6 +538,7 @@ def run_nodes(self, names):
node["results"] = {node["outputs"][0]["name"]: results}
self.ctx.input_nodes[name] = results
self.ctx.nodes[name]["state"] = "FINISHED"
self.node_to_ctx(name)
# ValueError: attempted to add an input link after the process node was already stored.
# self.node.base.links.add_incoming(results, "INPUT_WORK", name)
elif node["metadata"]["node_type"] == "data":
Expand All @@ -510,6 +548,7 @@ def run_nodes(self, names):
node["process"] = results
self.ctx.new_data[name] = results
self.ctx.nodes[name]["state"] = "FINISHED"
self.node_to_ctx(name)
elif node["metadata"]["node_type"] in ["calcfunction", "workfunction"]:
print("node type: calcfunction/workfunction.")
kwargs.setdefault("metadata", {})
Expand All @@ -523,6 +562,7 @@ def run_nodes(self, names):
# print("results: ", results)
node["process"] = process
self.ctx.nodes[name]["state"] = "FINISHED"
self.node_to_ctx(name)
except Exception as e:
print(e)
self.report(e)
Expand Down Expand Up @@ -577,6 +617,7 @@ def run_nodes(self, names):
node["results"][node["outputs"][0]["name"]] = results
self.ctx.input_nodes[name] = results
self.ctx.nodes[name]["state"] = "FINISHED"
self.node_to_ctx(name)
# print("result from node: ", node["results"])
else:
print("node type: unknown.")
Expand All @@ -593,7 +634,9 @@ def get_inputs(self, node):
for input in node["inputs"]:
# print(f"input: {input['name']}")
if len(input["links"]) == 0:
inputs[input["name"]] = properties[input["name"]]["value"]
inputs[input["name"]] = self.update_ctx_variable(
properties[input["name"]]["value"]
)
elif len(input["links"]) == 1:
link = input["links"][0]
if self.ctx.nodes[link["from_node"]]["results"] is None:
Expand All @@ -613,19 +656,40 @@ def get_inputs(self, node):
link["from_socket"]
]
inputs[input["name"]] = value

for name in node["metadata"].get("args", []):
if name in inputs:
args.append(inputs[name])
else:
args.append(properties[name]["value"])
value = self.update_ctx_variable(properties[name]["value"])
args.append(value)
for name in node["metadata"].get("kwargs", []):
if name in inputs:
kwargs[name] = inputs[name]
else:
kwargs[name] = properties[name]["value"]
value = self.update_ctx_variable(properties[name]["value"])
kwargs[name] = value
return args, kwargs

def update_ctx_variable(self, value):
# replace context variables
"""Get value from context."""
if (
isinstance(value, str)
and value.strip().startswith("{{")
and value.strip().endswith("}}")
):
name = value[2:-2].strip()
if name not in self.ctx:
raise ValueError(f"Context variable {name} not found.")
return self.ctx[name]
else:
return value

def node_to_ctx(self, name):
items = self.ctx.nodes[name]["to_ctx"]
for item in items:
self.ctx[item[1]] = self.ctx.nodes[name]["results"][item[0]]

def check_node_state(self, name):
"""Check node states.

Expand Down Expand Up @@ -686,6 +750,9 @@ def check_parent_state(self, name):
# node = outgoing.get_node_by_label(output[0])
# outputs[output[2]] = getattr(node.outputs, output[1])
# return outputs
def reset(self):
print("Reset")
self.set_node_state(self.ctx.nodes.keys(), "CREATED")

def set_node_state(self, names, value):
"""Set node state"""
Expand All @@ -695,11 +762,17 @@ def set_node_state(self, names, value):
def finalize(self):
""""""
# expose group outputs
print("finalize")
group_outputs = {}
print("group outputs: ", self.ctx.worktree["metadata"]["group_outputs"])
for output in self.ctx.worktree["metadata"]["group_outputs"]:
print("output: ", output)
group_outputs[output[2]] = self.ctx.nodes[output[0]]["results"][output[1]]
if output[0] == "ctx":
group_outputs[output[2]] = self.ctx[output[1]]
else:
group_outputs[output[2]] = self.ctx.nodes[output[0]]["results"][
output[1]
]
self.out("group_outputs", group_outputs)
self.out("new_data", self.ctx.new_data)
self.report("Finalize")
Expand Down
4 changes: 3 additions & 1 deletion aiida_worktree/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .builtin import AiiDAGather
from .builtin import AiiDAGather, AiiDAToCtx, AiiDAFromCtx
from .test import (
AiiDAInt,
AiiDAFloat,
Expand All @@ -24,6 +24,8 @@

node_list = [
AiiDAGather,
AiiDAToCtx,
AiiDAFromCtx,
AiiDAInt,
AiiDAFloat,
AiiDAString,
Expand Down
47 changes: 46 additions & 1 deletion aiida_worktree/nodes/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,50 @@ def get_executor(self):
}


class AiiDAToCtx(Node):
"""AiiDAToCtx"""

identifier = "ToCtx"
name = "ToCtx"
node_type = "Control"
catalog = "AiiDA"
args = ["key", "value"]

def create_sockets(self):
self.inputs.clear()
self.outputs.clear()
self.inputs.new("General", "key")
self.inputs.new("General", "value")
self.outputs.new("General", "result")

def get_executor(self):
return {
"path": "builtins",
"name": "setattr",
}


class AiiDAFromCtx(Node):
"""AiiDAFromCtx"""

identifier = "FromCtx"
name = "FromCtx"
node_type = "Control"
catalog = "AiiDA"
args = ["key"]

def create_sockets(self):
self.inputs.clear()
self.outputs.clear()
self.inputs.new("General", "key")
self.outputs.new("General", "result")

def get_executor(self):
return {
"path": "builtins",
"name": "getattr",
}


if __name__ == "__main__":
print(gather_node)
print()
32 changes: 31 additions & 1 deletion aiida_worktree/properties/built_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ def set_value(self, value):
self._value = value
if self.update is not None:
self.update()
elif (
isinstance(value, str)
and value.rstrip().startswith("{{")
and value.endswith("}}")
):
self._value = value
else:
raise Exception("{} is not a integer.".format(value))

Expand All @@ -45,6 +51,12 @@ def set_value(self, value):
self._value = value
if self.update is not None:
self.update()
elif (
isinstance(value, str)
and value.rstrip().startswith("{{")
and value.endswith("}}")
):
self._value = value
else:
raise Exception("{} is not a float.".format(value))

Expand All @@ -68,6 +80,12 @@ def set_value(self, value):
self._value = value
if self.update is not None:
self.update()
elif (
isinstance(value, str)
and value.rstrip().startswith("{{")
and value.endswith("}}")
):
self._value = value
else:
raise Exception("{} is not a bool.".format(value))

Expand All @@ -90,6 +108,12 @@ def set_value(self, value):
self._value = value
if self.update is not None:
self.update()
elif (
isinstance(value, str)
and value.rstrip().startswith("{{")
and value.endswith("}}")
):
self._value = value
else:
raise Exception("{} is not a string.".format(value))

Expand All @@ -108,10 +132,16 @@ def set_value(self, value):
self._value = orm.Dict(value)
if self.update is not None:
self.update()
if isinstance(value, orm.Dict):
elif isinstance(value, orm.Dict):
self._value = value
if self.update is not None:
self.update()
elif (
isinstance(value, str)
and value.rstrip().startswith("{{")
and value.endswith("}}")
):
self._value = value
else:
raise Exception("{} is not a dict.".format(value))

Expand Down
14 changes: 14 additions & 0 deletions aiida_worktree/worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def __init__(self, name="WorkTree", **kwargs):
**kwargs: Additional keyword arguments to be passed to the NodeTree class.
"""
super().__init__(name, **kwargs)
self.ctx = {}
self.starts = []
self.is_while = False
self.conditions = []

def run(self):
"""
Expand Down Expand Up @@ -54,6 +58,16 @@ def submit(self, wait=False, timeout=60):
if wait:
self.wait(timeout=timeout)

def to_dict(self):
ntdata = super().to_dict()
for node in self.nodes:
ntdata["nodes"][node.name]["to_ctx"] = getattr(node, "to_ctx", [])
ntdata["ctx"] = self.ctx
ntdata["starts"] = self.starts
ntdata["is_while"] = self.is_while
ntdata["conditions"] = self.conditions
return ntdata

def wait(self, timeout=50):
"""
Periodically checks and waits for the AiiDA worktree process to finish until a given timeout.
Expand Down
Loading
Loading