Skip to content

Commit

Permalink
Merge pull request #79 from stealthrocket/remove-coroutine-decorator
Browse files Browse the repository at this point in the history
Remove coroutine decorator
  • Loading branch information
achille-roussel authored Feb 20, 2024
2 parents 4624d5f + 2955162 commit b3a0ca3
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 49 deletions.
60 changes: 57 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ This package implements the Dispatch SDK for Python.
- [Configuration](#configuration)
- [Integration with FastAPI](#integration-with-fastapi)
- [Local testing with ngrok](#local-testing-with-ngrok)
- [Durable coroutines for Python](#durable-coroutines-for-python)
- [Examples](#examples)
- [Contributing](#contributing)

## What is Dispatch?
Expand All @@ -45,7 +47,7 @@ The SDK allows Python applications to declare *Stateful Functions* that the
Dispatch scheduler can orchestrate. This is the bare minimum structure used
to declare stateful functions:
```python
@dispatch.function()
@dispatch.function
def action(msg):
...
```
Expand Down Expand Up @@ -94,7 +96,7 @@ import requests
app = FastAPI()
dispatch = Dispatch(app)

@dispatch.function()
@dispatch.function
def publish(url, payload):
r = requests.post(url, data=payload)
r.raise_for_status()
Expand Down Expand Up @@ -144,7 +146,59 @@ different value, but in this example it would be:
export DISPATCH_ENDPOINT_URL="https://f441-2600-1700-2802-e01f-6861-dbc9-d551-ecfb.ngrok-free.app"
```

### Examples
### Durable coroutines for Python

The `@dispatch.function` decorator can also be applied to Python coroutines
(a.k.a. *async* functions), in which case each await point on another
stateful function becomes a durability step in the execution: if the awaited
operation fails, it is automatically retried and the parent function is paused
until the result becomes available, or a permanent error is raised.

```python
@dispatch.function
async def pipeline(msg):
# Each await point is a durability step, the functions can be run across the
# fleet of service instances and retried as needed without losing track of
# progress through the function execution.
msg = await transform1(msg)
msg = await transform2(msg)
await publish(msg)

@dispatch.function
async def publish(msg):
# Each dispatch function runs concurrently to the others, even if it does
# blocking operations like this POST request, it does not prevent other
# concurrent operations from carrying on in the program.
r = requests.post("https://somewhere.com/", data=msg)
r.raise_for_status()

@dispatch.function
async def transform1(msg):
...

@dispatch.function
async def transform2(msg):
...
```

This model is composable and can be used to create fan-out/fan-in control flows.
`gather` can be used to wait on multiple concurrent calls to stateful functions,
for example:

```python
from dispatch import gather

@dispatch.function
async def process(msgs):
concurrent_calls = [transform(msg) for msg in msgs]
return await gather(*concurrent_calls)

@dispatch.function
async def transform(msg):
...
```

## Examples

Check out the [examples](examples/) directory for code samples to help you get
started with the SDK.
Expand Down
11 changes: 6 additions & 5 deletions examples/auto_retry/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def third_party_api_call(x):
return "SUCCESS"


# Use the `dispatch.function` decorator to mark a function as durable.
@dispatch.function()
# Use the `dispatch.function` decorator to declare a stateful function.
@dispatch.function
def some_logic():
print("Executing some logic")
x = rng.randint(0, 5)
Expand All @@ -56,8 +56,9 @@ def some_logic():
# This is a normal FastAPI route that handles regular traffic.
@app.get("/")
def root():
# Use the `dispatch` method to call the durable function. This call is
# non-blocking and returns immediately.
# Use the `dispatch` method to call the stateful function. This call is
# returns immediately after scheduling the function call, which happens in
# the background.
some_logic.dispatch()
# Sending an unrelated response immediately.
# Sending a response now that the HTTP handler has completed.
return "OK"
11 changes: 6 additions & 5 deletions examples/getting_started/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@
dispatch = Dispatch(app)


# Use the `dispatch.function` decorator to mark a function as durable.
@dispatch.function()
# Use the `dispatch.function` decorator declare a stateful function.
@dispatch.function
def publish(url, payload):
r = requests.post(url, data=payload)
r.raise_for_status()
Expand All @@ -77,8 +77,9 @@ def publish(url, payload):
# This is a normal FastAPI route that handles regular traffic.
@app.get("/")
def root():
# Use the `dispatch` method to call the durable function. This call is
# non-blocking and returns immediately.
# Use the `dispatch` method to call the stateful function. This call is
# returns immediately after scheduling the function call, which happens in
# the background.
publish.dispatch("https://httpstat.us/200", {"hello": "world"})
# Sending an unrelated response immediately.
# Sending a response now that the HTTP handler has completed.
return "OK"
4 changes: 3 additions & 1 deletion src/dispatch/experimental/durable/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def durable(fn: Callable) -> Callable:
elif isinstance(fn, FunctionType):
return DurableFunction(fn)
else:
raise TypeError("unsupported callable")
raise TypeError(
f"cannot create a durable function from value of type {fn.__qualname__}"
)


class Serializable:
Expand Down
68 changes: 38 additions & 30 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import functools
import inspect
import logging
from functools import wraps
from types import FunctionType
from typing import Any, Callable, Dict, TypeAlias

Expand All @@ -23,12 +23,34 @@
"""


# https://stackoverflow.com/questions/653368/how-to-create-a-decorator-that-can-be-used-either-with-or-without-parameters
def decorator(f):
"""This decorator is intended to declare decorators that can be used with
or without parameters. If the decorated function is called with a single
callable argument, it is assumed to be a function and the decorator is
applied to it. Otherwise, the decorator is called with the arguments
provided and the result is returned.
"""

@wraps(f)
def method(self, *args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return f(self, args[0])

def wrapper(func):
return f(self, func, *args, **kwargs)

return wrapper

return method


class Function:
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK.
"""

__slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func", "call")
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func")

def __init__(
self,
Expand All @@ -42,11 +64,12 @@ def __init__(
self._client = client
self._name = name
self._primitive_func = primitive_func
self._func = func

# FIXME: is there a way to decorate the function at the definition
# without making it a class method?
self.call = durable(self._call_async)
if inspect.iscoroutinefunction(func):
self._func = durable(self._call_async)
else:
self._func = func

def __call__(self, *args, **kwargs):
return self._func(*args, **kwargs)
Expand Down Expand Up @@ -90,7 +113,7 @@ def _primitive_dispatch(self, input: Any = None) -> DispatchID:
return dispatch_id

async def _call_async(self, *args, **kwargs) -> Any:
"""Asynchronously call the function from a @dispatch.coroutine."""
"""Asynchronously call the function from a @dispatch.function."""
return await dispatch.coroutine.call(
self.build_call(*args, **kwargs, correlation_id=None)
)
Expand Down Expand Up @@ -142,39 +165,27 @@ def __init__(self, endpoint: str, client: Client | None):
self._endpoint = endpoint
self._client = client

def function(self) -> Callable[[FunctionType], Function]:
@decorator
def function(self, func: Callable) -> Function:
"""Returns a decorator that registers functions."""
return self._register_function(func)

# Note: the indirection here means that we can add parameters
# to the decorator later without breaking existing apps.
return self._register_function

def coroutine(self) -> Callable[[FunctionType], Function | FunctionType]:
"""Returns a decorator that registers coroutines."""

# Note: the indirection here means that we can add parameters
# to the decorator later without breaking existing apps.
return self._register_coroutine

def primitive_function(self) -> Callable[[PrimitiveFunctionType], Function]:
@decorator
def primitive_function(self, func: Callable) -> Function:
"""Returns a decorator that registers primitive functions."""

# Note: the indirection here means that we can add parameters
# to the decorator later without breaking existing apps.
return self._register_primitive_function
return self._register_primitive_function(func)

def _register_function(self, func: Callable) -> Function:
if inspect.iscoroutinefunction(func):
raise TypeError(
"async functions must be registered via @dispatch.coroutine"
)
return self._register_coroutine(func)

logger.info("registering function: %s", func.__qualname__)

# Register the function with the experimental.durable package, in case
# it's referenced from a @dispatch.coroutine.
func = durable(func)

@wraps(func)
def primitive_func(input: Input) -> Output:
try:
try:
Expand All @@ -196,14 +207,11 @@ def primitive_func(input: Input) -> Output:
return self._register(func, primitive_func)

def _register_coroutine(self, func: Callable) -> Function:
if not inspect.iscoroutinefunction(func):
raise TypeError(f"{func.__qualname__} must be an async function")

logger.info("registering coroutine: %s", func.__qualname__)

func = durable(func)

@functools.wraps(func)
@wraps(func)
def primitive_func(input: Input) -> Output:
return OneShotScheduler(func).run(input)

Expand Down
6 changes: 3 additions & 3 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _init_state(self, input: Input) -> State:

main = self.entry_point(*args, **kwargs)
if not isinstance(main, DurableCoroutine):
raise ValueError("entry point is not a @dispatch.coroutine")
raise ValueError("entry point is not a @dispatch.function")

return State(
version=sys.version,
Expand Down Expand Up @@ -255,7 +255,7 @@ def _run(self, input: Input) -> Output:
)
except Exception as e:
logger.exception(
f"@dispatch.coroutine: '{coroutine}' raised an exception"
f"@dispatch.function: '{coroutine}' raised an exception"
)
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e)

Expand Down Expand Up @@ -317,7 +317,7 @@ def _run(self, input: Input) -> Output:
g = awaitable.__await__()
if not isinstance(g, DurableGenerator):
raise ValueError(
"gather awaitable is not a @dispatch.coroutine"
"gather awaitable is not a @dispatch.function"
)
child_id = state.next_coroutine_id
state.next_coroutine_id += 1
Expand Down
4 changes: 2 additions & 2 deletions tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def execute(self):

def test_simple_end_to_end(self):
# The FastAPI server.
@self.dispatch.function()
@self.dispatch.function
def my_function(name: str) -> str:
return f"Hello world: {name}"

Expand All @@ -73,7 +73,7 @@ def my_function(name: str) -> str:
self.assertEqual(any_unpickle(resp.exit.result.output), "Hello world: 52")

def test_simple_missing_signature(self):
@self.dispatch.function()
@self.dispatch.function
def my_function(name: str) -> str:
return f"Hello world: {name}"

Expand Down

0 comments on commit b3a0ca3

Please sign in to comment.