Skip to content

Commit

Permalink
Use node.calcfunction decorator (#15)
Browse files Browse the repository at this point in the history
Add node.calcfunction and node.workfunction decorators.
  • Loading branch information
superstar54 authored Dec 8, 2023
1 parent dd04076 commit 06a27c5
Show file tree
Hide file tree
Showing 34 changed files with 4,298 additions and 1,172 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:

strategy:
matrix:
python-version: ['3.10']
python-version: ['3.9', '3.10', '3.11']

services:
postgres:
Expand Down
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/flake8
rev: '6.0.0'
hooks:
- id: flake8
args: ['--max-line-length=121', '--ignore=F821, F722, E203, W503']
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
Expand Down
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,12 @@ from aiida_worktree import node
from aiida.engine import calcfunction

# define add calcfunction node
@node()
@calcfunction
@node.calcfunction()
def add(x, y):
return x + y

# define multiply calcfunction node
@node()
@calcfunction
@node.calcfunction()
def multiply(x, y):
return x*y

Expand Down
3 changes: 3 additions & 0 deletions aiida_worktree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@
from .node import Node
from .decorator import node, build_node


__version__ = "0.0.4"

__all__ = ["WorkTree", "Node", "node", "build_node"]
3 changes: 3 additions & 0 deletions aiida_worktree/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@
The commands need to be imported here for them to be registered with the top-level command group.
"""
from aiida_worktree.cli import cmd_tree


__all__ = ["cmd_tree"]
35 changes: 21 additions & 14 deletions aiida_worktree/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def add_input_recursive(inputs, port, prefix=None):

def build_node(ndata):
"""Register a node from a AiiDA component."""
from aiida.engine import calcfunction
from node_graph.decorator import create_node
from aiida_worktree.node import Node
import cloudpickle as pickle
Expand Down Expand Up @@ -250,20 +249,28 @@ class NodeDecoratorCollection:

__call__: Any = node # Alias '@node' to '@node.node'.

@staticmethod
def calcfunction(**kwargs):
def decorator(func):
# First, apply the calcfunction decorator
calcfunc_decorated = calcfunction(func)
# Then, apply node decorator
node_decorated = node(**kwargs)(calcfunc_decorated)

node = NodeDecoratorCollection()
return node_decorated

return decorator

@staticmethod
def workfunction(**kwargs):
def decorator(func):
# First, apply the workfunction decorator
calcfunc_decorated = workfunction(func)
node_decorated = node(**kwargs)(calcfunc_decorated)

if __name__ == "__main__":
from aiida.engine import calcfunction
from aiida_worktree.decorator import node
return node_decorated

@node(
identifier="MyAdd",
outputs=[["General", "result"]],
executor_type="calcfunction",
)
@calcfunction
def myadd(x, y):
return x + y
return decorator

print(myadd.node)

node = NodeDecoratorCollection()
16 changes: 5 additions & 11 deletions aiida_worktree/engine/worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


from aiida.engine.processes.exit_code import ExitCode
from aiida.engine.processes.process import Process, ProcessState
from aiida.engine.processes.process import Process

from aiida.engine.processes.workchains.awaitable import (
Awaitable,
Expand All @@ -30,9 +30,6 @@
from aiida.engine.processes.workchains.workchain import Protect, WorkChainSpec
from aiida.engine import run_get_node

from os.path import splitext
import yaml


if t.TYPE_CHECKING:
from aiida.engine.runners import Runner # pylint: disable=unused-import
Expand Down Expand Up @@ -275,8 +272,10 @@ def _do_step(self) -> t.Any:

# If the worktree is finished or the result is an ExitCode, we exit by returning
if finished:
result = self.finalize()
return result
if isinstance(result, ExitCode):
return result
else:
return self.finalize()

if self._awaitables:
return Wait(self._do_step, "Waiting before next step")
Expand Down Expand Up @@ -684,8 +683,6 @@ def run_nodes(self, names):
self.to_context(**{name: process})
elif node["metadata"]["node_type"] in ["worktree"]:
# process = run_get_node(executor, *args, **kwargs)
from aiida_worktree.utils import merge_properties
from aiida.orm.utils.serialize import serialize

print("node type: worktree.")
wt = self.run_executor(executor, args, kwargs, var_args, var_kwargs)
Expand Down Expand Up @@ -859,9 +856,6 @@ def check_parent_state(self, name):
"SKIPPED",
"FAILED",
]:
# print(
# f" {name}: Input node {link['from_node']}, {self.ctx.nodes[link['from_node']]['state']} ."
# )
ready = False
return (
ready,
Expand Down
4 changes: 1 addition & 3 deletions aiida_worktree/executors/qe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import time
from aiida import engine, orm, load_profile
from aiida.orm import load_code
from aiida import load_profile

load_profile()

Expand Down
1 change: 0 additions & 1 deletion aiida_worktree/nodes/builtin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from aiida_worktree.node import Node
from aiida_worktree.executors.builtin import GatherWorkChain


class AiiDAGather(Node):
Expand Down
1 change: 0 additions & 1 deletion aiida_worktree/nodes/qe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from aiida_worktree.node import Node
from aiida import orm


class AiiDAKpoint(Node):
Expand Down
1 change: 0 additions & 1 deletion aiida_worktree/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def save(self):
"""
self.build_node_link()
self.build_connectivity()
print("exist_in_db: {}".format(self.exist_in_db()))
if self.exist_in_db() or self.restart_process is not None:
new_nodes, modified_nodes, update_metadata = self.check_diff(
self.restart_process
Expand Down
2 changes: 0 additions & 2 deletions aiida_worktree/worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def run(self):
"""
from aiida_worktree.engine.worktree import WorkTree as WorkTreeEngine
from aiida_worktree.utils import merge_properties
from aiida.orm.utils.serialize import serialize
from aiida.manage import manager

# One can not run again if the process is alreay created. otherwise, a new process node will
Expand Down Expand Up @@ -253,7 +252,6 @@ def play_nodes(self, nodes):

def reset(self):
"""Reset the worktree."""
from aiida.engine.processes import control

self.process = None
for node in self.nodes:
Expand Down
6 changes: 2 additions & 4 deletions docs/source/blog.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,12 @@
"from aiida.engine import calcfunction\n",
"\n",
"# define add node\n",
"@node()\n",
"@calcfunction\n",
"@node.calcfunction()\n",
"def add(x, y):\n",
" return x + y\n",
"\n",
"# define multiply node\n",
"@node()\n",
"@calcfunction\n",
"@node.calcfunction()\n",
"def multiply(x, y):\n",
" return x*y\n",
"\n"
Expand Down
3 changes: 1 addition & 2 deletions docs/source/concept/node.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ Decorate any Python function using the `node` decorator. To use the power of Aii
from aiida.engine import calcfunction
# define add calcfunction node
@node()
@calcfunction
@node.calcfunction()
def add(x, y):
return x + y
Expand Down
248 changes: 243 additions & 5 deletions docs/source/howto/continue_finished_worktree.ipynb

Large diffs are not rendered by default.

73 changes: 70 additions & 3 deletions docs/source/howto/ctx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@
"from aiida.orm import Int\n",
"\n",
"# define add node\n",
"@node()\n",
"@calcfunction\n",
"@node.calcfunction()\n",
"def add(x, y):\n",
" return x + y\n",
"\n",
Expand Down Expand Up @@ -188,7 +187,75 @@
"outputs": [
{
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.43.0 (0)\n -->\n<!-- Title: %3 Pages: 1 -->\n<svg width=\"224pt\" height=\"419pt\"\n viewBox=\"0.00 0.00 224.00 419.48\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 415.48)\">\n<title>%3</title>\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-415.48 220,-415.48 220,4 -4,4\"/>\n<!-- N4231 -->\n<g id=\"node1\" class=\"node\">\n<title>N4231</title>\n<polygon fill=\"#e38851\" stroke=\"red\" stroke-width=\"6\" points=\"216,-291.74 0,-291.74 0,-238.74 216,-238.74 216,-291.74\"/>\n<text text-anchor=\"middle\" x=\"108\" y=\"-276.54\" font-family=\"Times,serif\" font-size=\"14.00\">WorkTree: test_worktree_ctx (4231)</text>\n<text text-anchor=\"middle\" x=\"108\" y=\"-261.54\" font-family=\"Times,serif\" font-size=\"14.00\">State: finished</text>\n<text text-anchor=\"middle\" x=\"108\" y=\"-246.54\" font-family=\"Times,serif\" font-size=\"14.00\">Exit Code: 0</text>\n</g>\n<!-- N4233 -->\n<g id=\"node3\" class=\"node\">\n<title>N4233</title>\n<polygon fill=\"#de707f\" fill-opacity=\"0.466667\" stroke=\"black\" stroke-width=\"0\" points=\"155.5,-172.74 60.5,-172.74 60.5,-119.74 155.5,-119.74 155.5,-172.74\"/>\n<text text-anchor=\"middle\" x=\"108\" y=\"-157.54\" font-family=\"Times,serif\" font-size=\"14.00\">add (4233)</text>\n<text text-anchor=\"middle\" x=\"108\" y=\"-142.54\" font-family=\"Times,serif\" font-size=\"14.00\">State: finished</text>\n<text text-anchor=\"middle\" x=\"108\" y=\"-127.54\" font-family=\"Times,serif\" font-size=\"14.00\">Exit Code: 0</text>\n</g>\n<!-- N4231&#45;&gt;N4233 -->\n<g id=\"edge3\" class=\"edge\">\n<title>N4231&#45;&gt;N4233</title>\n<path fill=\"none\" stroke=\"#000000\" stroke-dasharray=\"1,5\" d=\"M108,-238.56C108,-222.26 108,-200.92 108,-182.97\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"111.5,-182.86 108,-172.86 104.5,-182.86 111.5,-182.86\"/>\n<text text-anchor=\"middle\" x=\"147.5\" y=\"-209.54\" font-family=\"Times,serif\" font-size=\"14.00\">CALL_CALC</text>\n<text text-anchor=\"middle\" x=\"147.5\" y=\"-194.54\" font-family=\"Times,serif\" font-size=\"14.00\">add1</text>\n</g>\n<!-- N4230 -->\n<g id=\"node2\" class=\"node\">\n<title>N4230</title>\n<ellipse fill=\"#8cd499\" stroke=\"black\" stroke-width=\"0\" cx=\"108\" cy=\"-384.61\" rx=\"50.41\" ry=\"26.74\"/>\n<text text-anchor=\"middle\" x=\"108\" y=\"-388.41\" font-family=\"Times,serif\" font-size=\"14.00\">Int (4230)</text>\n<text text-anchor=\"middle\" x=\"108\" y=\"-373.41\" font-family=\"Times,serif\" font-size=\"14.00\">value: 2</text>\n</g>\n<!-- N4230&#45;&gt;N4231 -->\n<g id=\"edge1\" class=\"edge\">\n<title>N4230&#45;&gt;N4231</title>\n<path fill=\"none\" stroke=\"#000000\" stroke-dasharray=\"5,2\" d=\"M108,-357.56C108,-341.19 108,-319.85 108,-301.91\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"111.5,-301.81 108,-291.81 104.5,-301.81 111.5,-301.81\"/>\n<text text-anchor=\"middle\" x=\"152.5\" y=\"-328.54\" font-family=\"Times,serif\" font-size=\"14.00\">INPUT_WORK</text>\n<text text-anchor=\"middle\" x=\"152.5\" y=\"-313.54\" font-family=\"Times,serif\" font-size=\"14.00\">nt__ctx__x</text>\n</g>\n<!-- N4234 -->\n<g id=\"node4\" class=\"node\">\n<title>N4234</title>\n<ellipse fill=\"#8cd499\" stroke=\"black\" stroke-width=\"0\" cx=\"108\" cy=\"-26.87\" rx=\"50.41\" ry=\"26.74\"/>\n<text text-anchor=\"middle\" x=\"108\" y=\"-30.67\" font-family=\"Times,serif\" font-size=\"14.00\">Int (4234)</text>\n<text text-anchor=\"middle\" x=\"108\" y=\"-15.67\" font-family=\"Times,serif\" font-size=\"14.00\">value: 5</text>\n</g>\n<!-- N4233&#45;&gt;N4234 -->\n<g id=\"edge2\" class=\"edge\">\n<title>N4233&#45;&gt;N4234</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M108,-119.48C108,-103.24 108,-82.01 108,-64.09\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"111.5,-63.97 108,-53.97 104.5,-63.97 111.5,-63.97\"/>\n<text text-anchor=\"middle\" x=\"134\" y=\"-90.54\" font-family=\"Times,serif\" font-size=\"14.00\">CREATE</text>\n<text text-anchor=\"middle\" x=\"134\" y=\"-75.54\" font-family=\"Times,serif\" font-size=\"14.00\">result</text>\n</g>\n</g>\n</svg>\n",
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.43.0 (0)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"224pt\" height=\"419pt\"\n",
" viewBox=\"0.00 0.00 224.00 419.48\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 415.48)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-415.48 220,-415.48 220,4 -4,4\"/>\n",
"<!-- N4231 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>N4231</title>\n",
"<polygon fill=\"#e38851\" stroke=\"red\" stroke-width=\"6\" points=\"216,-291.74 0,-291.74 0,-238.74 216,-238.74 216,-291.74\"/>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-276.54\" font-family=\"Times,serif\" font-size=\"14.00\">WorkTree: test_worktree_ctx (4231)</text>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-261.54\" font-family=\"Times,serif\" font-size=\"14.00\">State: finished</text>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-246.54\" font-family=\"Times,serif\" font-size=\"14.00\">Exit Code: 0</text>\n",
"</g>\n",
"<!-- N4233 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>N4233</title>\n",
"<polygon fill=\"#de707f\" fill-opacity=\"0.466667\" stroke=\"black\" stroke-width=\"0\" points=\"155.5,-172.74 60.5,-172.74 60.5,-119.74 155.5,-119.74 155.5,-172.74\"/>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-157.54\" font-family=\"Times,serif\" font-size=\"14.00\">add (4233)</text>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-142.54\" font-family=\"Times,serif\" font-size=\"14.00\">State: finished</text>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-127.54\" font-family=\"Times,serif\" font-size=\"14.00\">Exit Code: 0</text>\n",
"</g>\n",
"<!-- N4231&#45;&gt;N4233 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>N4231&#45;&gt;N4233</title>\n",
"<path fill=\"none\" stroke=\"#000000\" stroke-dasharray=\"1,5\" d=\"M108,-238.56C108,-222.26 108,-200.92 108,-182.97\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"111.5,-182.86 108,-172.86 104.5,-182.86 111.5,-182.86\"/>\n",
"<text text-anchor=\"middle\" x=\"147.5\" y=\"-209.54\" font-family=\"Times,serif\" font-size=\"14.00\">CALL_CALC</text>\n",
"<text text-anchor=\"middle\" x=\"147.5\" y=\"-194.54\" font-family=\"Times,serif\" font-size=\"14.00\">add1</text>\n",
"</g>\n",
"<!-- N4230 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>N4230</title>\n",
"<ellipse fill=\"#8cd499\" stroke=\"black\" stroke-width=\"0\" cx=\"108\" cy=\"-384.61\" rx=\"50.41\" ry=\"26.74\"/>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-388.41\" font-family=\"Times,serif\" font-size=\"14.00\">Int (4230)</text>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-373.41\" font-family=\"Times,serif\" font-size=\"14.00\">value: 2</text>\n",
"</g>\n",
"<!-- N4230&#45;&gt;N4231 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>N4230&#45;&gt;N4231</title>\n",
"<path fill=\"none\" stroke=\"#000000\" stroke-dasharray=\"5,2\" d=\"M108,-357.56C108,-341.19 108,-319.85 108,-301.91\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"111.5,-301.81 108,-291.81 104.5,-301.81 111.5,-301.81\"/>\n",
"<text text-anchor=\"middle\" x=\"152.5\" y=\"-328.54\" font-family=\"Times,serif\" font-size=\"14.00\">INPUT_WORK</text>\n",
"<text text-anchor=\"middle\" x=\"152.5\" y=\"-313.54\" font-family=\"Times,serif\" font-size=\"14.00\">nt__ctx__x</text>\n",
"</g>\n",
"<!-- N4234 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>N4234</title>\n",
"<ellipse fill=\"#8cd499\" stroke=\"black\" stroke-width=\"0\" cx=\"108\" cy=\"-26.87\" rx=\"50.41\" ry=\"26.74\"/>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-30.67\" font-family=\"Times,serif\" font-size=\"14.00\">Int (4234)</text>\n",
"<text text-anchor=\"middle\" x=\"108\" y=\"-15.67\" font-family=\"Times,serif\" font-size=\"14.00\">value: 5</text>\n",
"</g>\n",
"<!-- N4233&#45;&gt;N4234 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>N4233&#45;&gt;N4234</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M108,-119.48C108,-103.24 108,-82.01 108,-64.09\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"111.5,-63.97 108,-53.97 104.5,-63.97 111.5,-63.97\"/>\n",
"<text text-anchor=\"middle\" x=\"134\" y=\"-90.54\" font-family=\"Times,serif\" font-size=\"14.00\">CREATE</text>\n",
"<text text-anchor=\"middle\" x=\"134\" y=\"-75.54\" font-family=\"Times,serif\" font-size=\"14.00\">result</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f0f74ab60e0>"
]
Expand Down
490 changes: 485 additions & 5 deletions docs/source/howto/for.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 06a27c5

Please sign in to comment.