Skip to content

Commit

Permalink
refactor scheduler internals to factor logic into functions
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Jun 4, 2024
1 parent d0f9818 commit daedb94
Showing 1 changed file with 126 additions and 100 deletions.
226 changes: 126 additions & 100 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
import sys
from dataclasses import dataclass, field
from types import coroutine
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -434,106 +435,11 @@ def _run(self, input: Input) -> Output:
logger.debug("running %s", coroutine)

assert coroutine.id not in state.suspended
coroutine_yield = run_coroutine(state, coroutine, pending_calls)

coroutine_yield = None
coroutine_result: Optional[CoroutineResult] = None
try:
coroutine_yield = coroutine.run()
except TailCall as tc:
coroutine_result = CoroutineResult(
coroutine_id=coroutine.id, call=tc.call, status=tc.status
)
except StopIteration as e:
coroutine_result = CoroutineResult(
coroutine_id=coroutine.id, value=e.value
)
except Exception as e:
logger.debug(
f"@dispatch.function: '{coroutine}' raised an exception", exc_info=e
)
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e)

# Handle coroutines that return or raise.
if coroutine_result is not None:
if coroutine_result.call is not None:
logger.debug(
"%s reset to %s", coroutine, coroutine_result.call.function
)
elif coroutine_result.error is not None:
logger.debug("%s raised %s", coroutine, coroutine_result.error)
else:
logger.debug("%s returned %s", coroutine, coroutine_result.value)

# If this is the main coroutine, we're done.
if coroutine.parent_id is None:
for suspended in state.suspended.values():
suspended.coroutine.close()
if coroutine_result.error is not None:
return Output.error(
Error.from_exception(coroutine_result.error)
)
if coroutine_result.call is not None:
return Output.tail_call(
tail_call=coroutine_result.call,
status=coroutine_result.status,
)
return Output.value(coroutine_result.value)

# Otherwise, notify the parent of the result.
try:
parent = state.suspended[coroutine.parent_id]
future = parent.result
assert future is not None
except (KeyError, AssertionError):
logger.warning("discarding %s", coroutine_result)
else:
future.add_result(coroutine_result)
if future.ready() and parent.id in state.suspended:
state.ready.insert(0, parent)
del state.suspended[parent.id]
logger.debug("parent %s is now ready", parent)
continue

# Handle coroutines that yield.
logger.debug("%s yielded %s", coroutine, coroutine_yield)
if isinstance(coroutine_yield, Call):
call = coroutine_yield
call_id = state.next_call_id
state.next_call_id += 1
call.correlation_id = correlation_id(coroutine.id, call_id)
logger.debug(
"enqueuing call %d (%s) for %s",
call_id,
call.function,
coroutine,
)
pending_calls.append(call)
coroutine.result = CallFuture()
state.suspended[coroutine.id] = coroutine
state.prev_callers.append(coroutine)
state.outstanding_calls += 1

elif isinstance(coroutine_yield, AllDirective):
children = spawn_children(state, coroutine, coroutine_yield.awaitables)

child_ids = [child.id for child in children]
coroutine.result = AllFuture(order=child_ids, waiting=set(child_ids))
state.suspended[coroutine.id] = coroutine

elif isinstance(coroutine_yield, AnyDirective):
children = spawn_children(state, coroutine, coroutine_yield.awaitables)

child_ids = [child.id for child in children]
coroutine.result = AnyFuture(order=child_ids, waiting=set(child_ids))
state.suspended[coroutine.id] = coroutine

elif isinstance(coroutine_yield, RaceDirective):
children = spawn_children(state, coroutine, coroutine_yield.awaitables)

coroutine.result = RaceFuture(waiting={child.id for child in children})
state.suspended[coroutine.id] = coroutine

else:
if coroutine_yield is not None:
if isinstance(coroutine_yield, Output):
return coroutine_yield
raise RuntimeError(
f"coroutine unexpectedly yielded '{coroutine_yield}'"
)
Expand Down Expand Up @@ -566,14 +472,135 @@ def _run(self, input: Input) -> Output:
)


def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]):
coroutine_yield = None
coroutine_result: Optional[CoroutineResult] = None
try:
coroutine_yield = coroutine.run()
except TailCall as tc:
coroutine_result = CoroutineResult(
coroutine_id=coroutine.id, call=tc.call, status=tc.status
)
except StopIteration as e:
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, value=e.value)
except Exception as e:
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e)
logger.debug(
f"@dispatch.function: '{coroutine}' raised an exception", exc_info=e
)

if coroutine_result is not None:
return set_coroutine_result(state, coroutine, coroutine_result)
logger.debug("%s yielded %s", coroutine, coroutine_yield)

if isinstance(coroutine_yield, Call):
return set_coroutine_call(state, coroutine, coroutine_yield, pending_calls)

if isinstance(coroutine_yield, AllDirective):
return set_coroutine_all(state, coroutine, coroutine_yield.awaitables)

if isinstance(coroutine_yield, AnyDirective):
return set_coroutine_any(state, coroutine, coroutine_yield.awaitables)

if isinstance(coroutine_yield, RaceDirective):
return set_coroutine_race(state, coroutine, coroutine_yield.awaitables)

return coroutine_yield


def set_coroutine_result(
state: State, coroutine: Coroutine, coroutine_result: CoroutineResult
):
if coroutine_result.call is not None:
logger.debug("%s reset to %s", coroutine, coroutine_result.call.function)
elif coroutine_result.error is not None:
logger.debug("%s raised %s", coroutine, coroutine_result.error)
else:
logger.debug("%s returned %s", coroutine, coroutine_result.value)

# If this is the main coroutine, we're done.
if coroutine.parent_id is None:
for suspended in state.suspended.values():
suspended.coroutine.close()
if coroutine_result.error is not None:
return Output.error(Error.from_exception(coroutine_result.error))
if coroutine_result.call is not None:
return Output.tail_call(
tail_call=coroutine_result.call, status=coroutine_result.status
)
return Output.value(coroutine_result.value)

# Otherwise, notify the parent of the result.
try:
parent = state.suspended[coroutine.parent_id]
future = parent.result
assert future is not None
except (KeyError, AssertionError):
logger.warning("discarding %s", coroutine_result)
else:
future.add_result(coroutine_result)
if future.ready() and parent.id in state.suspended:
state.ready.insert(0, parent)
del state.suspended[parent.id]
logger.debug("parent %s is now ready", parent)
return


def set_coroutine_call(
state: State, coroutine: Coroutine, call: Call, pending_calls: List[Call]
):
call_id = state.next_call_id
state.next_call_id += 1
call.correlation_id = correlation_id(coroutine.id, call_id)
logger.debug("enqueuing call %d (%s) for %s", call_id, call.function, coroutine)
pending_calls.append(call)
coroutine.result = CallFuture()
state.suspended[coroutine.id] = coroutine
state.prev_callers.append(coroutine)
state.outstanding_calls += 1
return


def set_coroutine_all(
state: State, coroutine: Coroutine, awaitables: Tuple[Awaitable[Any], ...]
):
children = spawn_children(state, coroutine, awaitables)
child_ids = [child.id for child in children]
coroutine.result = AllFuture(order=child_ids, waiting=set(child_ids))
state.suspended[coroutine.id] = coroutine
return


def set_coroutine_any(
state: State, coroutine: Coroutine, awaitables: Tuple[Awaitable[Any], ...]
):
children = spawn_children(state, coroutine, awaitables)
child_ids = [child.id for child in children]
coroutine.result = AnyFuture(order=child_ids, waiting=set(child_ids))
state.suspended[coroutine.id] = coroutine
return


def set_coroutine_race(
state: State, coroutine: Coroutine, awaitables: Tuple[Awaitable[Any], ...]
):
children = spawn_children(state, coroutine, awaitables)
coroutine.result = RaceFuture(waiting={child.id for child in children})
state.suspended[coroutine.id] = coroutine
return


def spawn_children(
state: State, coroutine: Coroutine, awaitables: Tuple[Awaitable[Any], ...]
) -> List[Coroutine]:
children = []

for awaitable in awaitables:
g = awaitable.__await__()

if not isinstance(g, DurableGenerator):
raise TypeError("awaitable is not a @dispatch.function")

child_id = state.next_coroutine_id
state.next_coroutine_id += 1
child = Coroutine(id=child_id, parent_id=coroutine.id, coroutine=g)
Expand All @@ -582,7 +609,6 @@ def spawn_children(

# Prepend children to get a depth-first traversal of coroutines.
state.ready = children + state.ready

return children


Expand Down

0 comments on commit daedb94

Please sign in to comment.