Skip to content

Commit

Permalink
Fix parallel runner (#21)
Browse files Browse the repository at this point in the history
* fix cache bug

* update

* update doc

* add parallel runner

* fix parallel runner
  • Loading branch information
goodwanghan authored Jun 17, 2020
1 parent 4f9d090 commit 3a3bffb
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 42 deletions.
2 changes: 1 addition & 1 deletion adagio/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.7"
__version__ = "0.1.8"
39 changes: 10 additions & 29 deletions adagio/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from threading import Event, RLock
from traceback import StackSummary, extract_stack
Expand Down Expand Up @@ -151,48 +150,30 @@ def preprocess(self, wf: "_Workflow") -> List["_Task"]:
wf._register(temp)
if self._concurrency <= 1:
return temp
tempdict = {x.execution_id: x for x in temp}
down: Dict[str, Set[str]] = defaultdict(set)
up: Dict[str, Set[str]] = {}
q: List[str] = []
result: List["_Task"] = []
for t in temp:
u = set(x.execution_id for x in t.upstream) # noqa: C401
c = t.execution_id
up[c] = u
for x in u:
down[x].add(c)
if len(u) == 0:
q.append(c)
while len(q) > 0:
key = q.pop(0)
result.append(tempdict[key])
for d in down[key]:
up[d].remove(key)
if len(up[d]) == 0:
q.append(d)
return result
return [t for t in temp if len(t.upstream) == 0]

def run_tasks(self, tasks: List["_Task"]) -> None:
if self._concurrency <= 1:
for t in tasks:
self.run_single(t)
else:
with cf.ThreadPoolExecutor(max_workers=self._concurrency) as e:
jobs = []
for task in tasks:
jobs.append(e.submit(self.run_single, task))
return
with cf.ThreadPoolExecutor(max_workers=self._concurrency) as e:
jobs = [e.submit(self.run_single, task) for task in tasks]
while jobs:
for f in cf.as_completed(jobs):
jobs.remove(f)
try:
f.result()
for task in f.result().downstream:
jobs.append(e.submit(self.run_single, task))
except Exception:
self.context.abort()
raise

def run_single(self, task: "_Task") -> None:
def run_single(self, task: "_Task") -> "_Task":
task.update_by_cache()
task.run()
task.reraise()
return task


class SequentialExecutionEngine(ParallelExecutionEngine):
Expand Down
26 changes: 14 additions & 12 deletions tests/test_core_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from adagio.specs import InputSpec, OutputSpec, WorkflowSpec, _NodeSpec
from pytest import raises
from triad.exceptions import InvalidOperationError
from timeit import timeit


def test_task_skip():
Expand Down Expand Up @@ -335,16 +336,18 @@ def test_workflow_run_parallel():
s.add("f", wait_task1, "c")
hooks = MockHooks(None)
ctx = WorkflowContext(hooks=hooks)
ctx._engine = ParallelExecutionEngine(2, ctx)
with raises(NotImplementedError):
ctx.run(s, {})
ctx._engine = ParallelExecutionEngine(10, ctx)

def run():
with raises(NotImplementedError):
ctx.run(s, {})

t = timeit(run, number=1)
assert t < 0.2 # only a and b are executed

expected = {'a': 1}
for k, v in expected.items():
assert v == hooks.res[k]
# theoretically c is not determined
assert "d" in hooks.skipped
assert "e" in hooks.skipped
assert "f" in hooks.skipped
assert "b" in hooks.failed

# order of execution
Expand All @@ -358,11 +361,10 @@ def test_workflow_run_parallel():
hooks = MockHooks(None)
ctx = WorkflowContext(hooks=hooks)
ctx._engine = ParallelExecutionEngine(2, ctx)
ctx.run(s, {})
res = list(hooks.res.keys())
assert {"a", "b"} == set(res[0:2])
assert {"c", "d"} == set(res[2:4])
assert {"e", "f"} == set(res[4:6])
t = timeit(lambda: ctx.run(s, {}), number=1)
assert t < 0.4
assert 3 == hooks.res["e"]
assert 3 == hooks.res["f"]


def test_workflow_run_with_exception():
Expand Down

0 comments on commit 3a3bffb

Please sign in to comment.