Skip to content

Commit

Permalink
Test the durable package
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Feb 3, 2024
1 parent d7ff563 commit 2f3ad50
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 20 deletions.
26 changes: 15 additions & 11 deletions src/dispatch/experimental/durable/durable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,13 @@
from .registry import register_function


def durable(fn):
"""A decorator that makes generators serializable."""
return DurableFunction(fn)


class DurableFunction:
"""A wrapper for a generator function that wraps its generator instances
with a DurableGenerator.
Attributes:
fn (FunctionType): A generator function.
key (str): A key that uniquely identifies the function.
fn: A generator function.
key: A key that uniquely identifies the function.
"""

def __init__(self, fn: FunctionType):
Expand All @@ -23,8 +18,17 @@ def __init__(self, fn: FunctionType):

def __call__(self, *args, **kwargs):
result = self.fn(*args, **kwargs)
if isinstance(result, GeneratorType):
return DurableGenerator(result, self.key, args, kwargs)
if not isinstance(result, GeneratorType):
raise NotImplementedError(
"only synchronous generator functions are supported"
)
return DurableGenerator(result, self.key, args, kwargs)


# TODO: support native coroutines
raise NotImplementedError
def durable(fn) -> DurableFunction:
"""Returns a "durable" function that creates serializable generators.
Args:
fn: A generator function.
"""
return DurableFunction(fn)
32 changes: 24 additions & 8 deletions src/dispatch/experimental/durable/generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from types import GeneratorType, TracebackType, CodeType, FrameType
from typing import Generator, TypeVar
from typing import Generator, TypeVar, Any
from .registry import lookup_function
from . import frame as ext

Expand All @@ -10,13 +10,29 @@


class DurableGenerator(Generator[_YieldT, _SendT, _ReturnT]):
"""A generator that can be pickled."""

def __init__(self, gen: GeneratorType, key, args, kwargs):
self.generator = gen

# Capture the information necessary to be able to create a
# new instance of the generator.
"""A wrapper for a generator that makes it serializable (can be pickled).
Instances behave like the generators they wrap.
Attributes:
generator: The wrapped generator.
key: A unique identifier for the function that created this generator.
args: Positional arguments to the function that created this generator.
kwargs: Keyword arguments to the function that created this generator.
"""

generator: GeneratorType
key: str
args: list[Any]
kwargs: dict[str, Any]

def __init__(
self,
generator: GeneratorType,
key: str,
args: list[Any],
kwargs: dict[str, Any],
):
self.generator = generator
self.key = key
self.args = args
self.kwargs = kwargs
Expand Down
23 changes: 22 additions & 1 deletion src/dispatch/experimental/durable/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@


def register_function(fn: FunctionType) -> str:
"""Register a generator function.
Args:
fn: The function to register.
Returns:
str: Unique identifier for the function.
Raises:
ValueError: The function conflicts with another registered function.
"""
# We need to be able to refer to the function in the serialized
# representation, and the key needs to be stable across interpreter
# invocations. Use the code object's fully-qualified name for now.
Expand All @@ -16,9 +27,19 @@ def register_function(fn: FunctionType) -> str:
raise ValueError(f"durable function already registered with key {key}")

_REGISTRY[key] = fn

return key


def lookup_function(key: str) -> FunctionType:
"""Lookup a previously registered function.
Args:
key: Unique identifier for the function.
Returns:
FunctionType: The associated function.
Raises:
KeyError: A function has not been registered with this key.
"""
return _REGISTRY[key]
117 changes: 117 additions & 0 deletions tests/dispatch/experimental/durable/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest
import pickle
import types
import warnings
from dispatch.experimental.durable import durable


Expand Down Expand Up @@ -57,3 +59,118 @@ def test_nested(self):
for j in range(i, len(expect)):
assert next(g) == expect[j]
assert next(g2) == expect[j]

def test_export_gi_fields(self):
g = nested_generators(1)
underlying = g.generator

self.assertIsInstance(g.gi_frame, types.FrameType)
self.assertIs(g.gi_frame, underlying.gi_frame)

self.assertIsInstance(g.gi_code, types.CodeType)
self.assertIs(g.gi_code, underlying.gi_code)

def check():
self.assertEqual(g.gi_running, underlying.gi_running)
self.assertEqual(g.gi_suspended, underlying.gi_suspended)
self.assertIs(g.gi_yieldfrom, underlying.gi_yieldfrom)

check()
for _ in g:
check()
check()

def test_name_conflict(self):
@durable
def durable_generator():
yield 1

with self.assertRaises(ValueError):

@durable
def durable_generator():
yield 2

def test_two_way(self):
@durable
def two_way(a):
b = yield a * 10
c = yield b * 10
return (yield c * 10)

input = 1
sends = [2, 3, 4]
yields = [10, 20, 30]
output = 4

g = two_way(1)

actual_yields = []
actual_return = None

try:
i = 0
send = None
while True:
next_value = g.send(send)
actual_yields.append(next_value)
send = sends[i]
i += 1
except StopIteration as e:
actual_return = e.value

self.assertEqual(actual_yields, yields)
self.assertEqual(actual_return, output)

def test_throw(self):
warnings.filterwarnings("ignore", category=DeprecationWarning) # FIXME

ok = False

@durable
def check_throw():
try:
yield
except RuntimeError:
nonlocal ok
ok = True

g = check_throw()
next(g)
try:
g.throw(RuntimeError)
except StopIteration:
pass
self.assertTrue(ok)

def test_close(self):
ok = False

@durable
def check_close():
try:
yield
except GeneratorExit:
nonlocal ok
ok = True
raise

g = check_close()
next(g)
g.close()
self.assertTrue(ok)

def test_not_a_synchronous_generator(self):
@durable
def regular():
pass

@durable
async def async_generator():
yield

with self.assertRaises(NotImplementedError):
regular()

with self.assertRaises(NotImplementedError):
async_generator()

0 comments on commit 2f3ad50

Please sign in to comment.