Recursive WorkGraph with Graph Builder can only be run for the amount of redefinition of it #333

Open
agoscinski opened this issue Sep 20, 2024 · 3 comments · May be fixed by #336

@agoscinski
Contributor

agoscinski commented Sep 20, 2024

Because @giovannipizzi mentioned a project that needs recursive workflows, I wanted to put together an example of one. It kind of works, but there are some problems with referring to the graph builder from within itself:

from aiida_workgraph import task, WorkGraph
from aiida import load_profile
load_profile()

@task.calcfunction
def identity(x):
    return x.clone()

@task.calcfunction
def multiply(x, y):
    return x*y

@task.graph_builder(outputs = [{"name": "result", "from": "multiply.result"}])
def recursive(n):
    wg = WorkGraph()
    if n == 1:
        mytask = wg.add_task(identity, name="multiply", x=n)
        return wg 
    t = wg.add_task(recursive, name="recursive", n=n-1)
    wg.add_task(multiply, name="multiply", x=n, y=t.outputs["result"])
    return wg

I have to re-run the definition of recursive once for every level of recursion I want to use:

wg = WorkGraph()
mytask = wg.add_task(recursive, n=3) # I have to run the above cell defining recursive 3 times so this works
wg.run()
@agoscinski
Contributor Author

Otherwise, the error message is:

  File "/Users/alexgo/code/aiida-workgraph/.pixi/envs/docs-env/lib/python3.12/site-packages/plumpy/process_states.py", line 228, in execute
    result = self.run_fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alexgo/code/aiida-workgraph/aiida_workgraph/engine/workgraph.py", line 285, in run
    return self._do_step()
           ^^^^^^^^^^^^^^^
  File "/Users/alexgo/code/aiida-workgraph/aiida_workgraph/engine/workgraph.py", line 300, in _do_step
    self.continue_workgraph()
  File "/Users/alexgo/code/aiida-workgraph/aiida_workgraph/engine/workgraph.py", line 705, in continue_workgraph
    self.run_tasks(task_to_run)
  File "/Users/alexgo/code/aiida-workgraph/aiida_workgraph/engine/workgraph.py", line 1109, in run_tasks
    wg = self.run_executor(executor, [], kwargs, var_args, var_kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alexgo/code/aiida-workgraph/aiida_workgraph/engine/workgraph.py", line 1450, in run_executor
    return executor(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/lh/d5j2y3816xg0qffzv9bqlx2c0000gn/T/ipykernel_70299/3807790678.py", line 7, in recursive
    t = wg.add_task(recursive, name="recursive", n=n-1)
                    ^^^^^^^^^
NameError: name 'recursive' is not defined

@superstar54
Member

superstar54 commented Sep 20, 2024

Nice try! The error occurs because the name recursive is not defined in the scope in which the graph_builder function is executed.
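The reported behaviour (one extra level of recursion for each re-run of the defining cell) would be consistent with the builder being captured together with a snapshot of the names visible at decoration time. That is an assumption, not something confirmed in this thread, and the sketch below does not use aiida-workgraph at all. The names snapshot_at_decoration, CELL, run_cell and try_call are hypothetical and exist only for this illustration.

```python
import types


def snapshot_at_decoration(namespace):
    """Hypothetical decorator factory, not part of aiida-workgraph.

    The returned decorator rebuilds the function so that its globals are a
    copy of `namespace` taken at decoration time.
    """
    def decorator(func):
        frozen = dict(namespace)  # names bound later are not visible in the copy
        return types.FunctionType(func.__code__, frozen, func.__name__)
    return decorator


# Stand-in for the notebook cell that defines the recursive builder.
CELL = '''
@snapshot
def recursive(n):
    if n == 1:
        return 1
    return n * recursive(n - 1)
'''


def run_cell(times):
    """Execute CELL `times` times in one shared namespace, like re-running a cell."""
    ns = {}
    ns["snapshot"] = snapshot_at_decoration(ns)
    for _ in range(times):
        exec(CELL, ns)
    return ns["recursive"]


def try_call(func, n):
    """Return func(n), or a short message if a name lookup fails."""
    try:
        return func(n)
    except NameError as exc:
        return f"NameError: {exc}"


for runs in (1, 2, 3):
    print(runs, try_call(run_cell(runs), 3))
```

Under this model, each version of recursive only sees the version that existed when it was decorated, so a call with n=3 fails with a NameError after one or two runs of the cell and returns 6 only after the third.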

To create a recursive WorkGraph, you need to define recursive in an importable package and then import it inside the recursive function itself.

As a quick example, you can add this function to the aiida_workgraph.tasks.test file:

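@task.graph_builder(outputs = [{"name": "result", "from": "multiply.result"}])
def recursive(n):
    # Import inside the function body so that the name recursive can be
    # resolved when the graph builder is executed.
    from aiida_workgraph import WorkGraph
    from aiida_workgraph.tasks.test import recursive
    wg = WorkGraph()
    if n == 1:
        # Base case: the task is still named "multiply" so that the
        # "multiply.result" output mapping applies here too.
        wg.add_task(identity, name="multiply", x=n)
        return wg
    t = wg.add_task(recursive, name="recursive", n=n-1)
    wg.add_task(multiply, name="multiply", x=n, y=t.outputs["result"])
    return wg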
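The idea behind this suggestion can be sketched in plain Python, without AiiDA: if the function looks itself up through the import system inside its own body, it no longer depends on the name being present in whatever namespace it is executed in. The module name demo_recursive_pkg and the variable detached are hypothetical, and the in-memory module only stands in for a real installed package.

```python
import builtins
import sys
import types

# Hypothetical module name used only for this demonstration.
MODULE_NAME = "demo_recursive_pkg"

SOURCE = f'''
def recursive(n):
    # Resolve the function through the import system, not through globals.
    from {MODULE_NAME} import recursive
    if n == 1:
        return 1
    return n * recursive(n - 1)
'''

# Register an in-memory module so that the import above succeeds.
module = types.ModuleType(MODULE_NAME)
exec(SOURCE, module.__dict__)
sys.modules[MODULE_NAME] = module

# Rebuild the function with globals that contain no `recursive` at all,
# mimicking execution in a scope where the name was never defined.
detached = types.FunctionType(
    module.recursive.__code__, {"__builtins__": builtins}, "recursive"
)

print(detached(4))
```

Even though the globals of detached do not contain recursive, the call succeeds and prints 24, because every level of the recursion obtains the function from the registered module.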

@superstar54
Member

superstar54 commented Sep 20, 2024

Here is another working version; however, it is not clear to me why it works.

from aiida_workgraph import task, WorkGraph
from aiida import load_profile
load_profile()

@task.calcfunction
def identity(x):
    return x.clone()

@task.calcfunction
def multiply(x, y):
    return x*y

def recursive(n):
    wg = WorkGraph()
    if n == 1:
        mytask = wg.add_task(identity, name="multiply", x=n)
        return wg 
    Task = task.graph_builder(outputs = [{"name": "result", "from": "multiply.result"}])(recursive)
    t = wg.add_task(Task, name=f"recursive_{n-1}", n=n-1)
    wg.add_task(multiply, name="multiply", x=n, y=t.outputs["result"])
    return wg

Task = task.graph_builder(outputs = [{"name": "result", "from": "multiply.result"}])(recursive)
wg = WorkGraph()
mytask = wg.add_task(Task, n=3, name="recursive_3")
wg.run()

superstar54 linked a pull request on Sep 22, 2024 that will close this issue