From 2f3ad50ef22be809c36ec9410c73546881e046cb Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Sat, 3 Feb 2024 11:13:23 +1000 Subject: [PATCH] Test the durable package --- src/dispatch/experimental/durable/durable.py | 26 ++-- .../experimental/durable/generator.py | 32 +++-- src/dispatch/experimental/durable/registry.py | 23 +++- .../experimental/durable/test_generator.py | 117 ++++++++++++++++++ 4 files changed, 178 insertions(+), 20 deletions(-) diff --git a/src/dispatch/experimental/durable/durable.py b/src/dispatch/experimental/durable/durable.py index fe9d0dc7..546149a9 100644 --- a/src/dispatch/experimental/durable/durable.py +++ b/src/dispatch/experimental/durable/durable.py @@ -3,18 +3,13 @@ from .registry import register_function -def durable(fn): - """A decorator that makes generators serializable.""" - return DurableFunction(fn) - - class DurableFunction: """A wrapper for a generator function that wraps its generator instances with a DurableGenerator. Attributes: - fn (FunctionType): A generator function. - key (str): A key that uniquely identifies the function. + fn: A generator function. + key: A key that uniquely identifies the function. """ def __init__(self, fn: FunctionType): @@ -23,8 +18,17 @@ def __init__(self, fn: FunctionType): def __call__(self, *args, **kwargs): result = self.fn(*args, **kwargs) - if isinstance(result, GeneratorType): - return DurableGenerator(result, self.key, args, kwargs) + if not isinstance(result, GeneratorType): + raise NotImplementedError( + "only synchronous generator functions are supported" + ) + return DurableGenerator(result, self.key, args, kwargs) + - # TODO: support native coroutines - raise NotImplementedError +def durable(fn) -> DurableFunction: + """Returns a "durable" function that creates serializable generators. + + Args: + fn: A generator function. + """ + return DurableFunction(fn) diff --git a/src/dispatch/experimental/durable/generator.py b/src/dispatch/experimental/durable/generator.py index 9caf7945..d98d2a08 100644 --- a/src/dispatch/experimental/durable/generator.py +++ b/src/dispatch/experimental/durable/generator.py @@ -1,5 +1,5 @@ from types import GeneratorType, TracebackType, CodeType, FrameType -from typing import Generator, TypeVar +from typing import Generator, TypeVar, Any from .registry import lookup_function from . import frame as ext @@ -10,13 +10,29 @@ class DurableGenerator(Generator[_YieldT, _SendT, _ReturnT]): - """A generator that can be pickled.""" - - def __init__(self, gen: GeneratorType, key, args, kwargs): - self.generator = gen - - # Capture the information necessary to be able to create a - # new instance of the generator. + """A wrapper for a generator that makes it serializable (can be pickled). + Instances behave like the generators they wrap. + + Attributes: + generator: The wrapped generator. + key: A unique identifier for the function that created this generator. + args: Positional arguments to the function that created this generator. + kwargs: Keyword arguments to the function that created this generator. + """ + + generator: GeneratorType + key: str + args: list[Any] + kwargs: dict[str, Any] + + def __init__( + self, + generator: GeneratorType, + key: str, + args: list[Any], + kwargs: dict[str, Any], + ): + self.generator = generator self.key = key self.args = args self.kwargs = kwargs diff --git a/src/dispatch/experimental/durable/registry.py b/src/dispatch/experimental/durable/registry.py index a0f70206..01ef60a3 100644 --- a/src/dispatch/experimental/durable/registry.py +++ b/src/dispatch/experimental/durable/registry.py @@ -5,6 +5,17 @@ def register_function(fn: FunctionType) -> str: + """Register a generator function. + + Args: + fn: The function to register. + + Returns: + str: Unique identifier for the function. + + Raises: + ValueError: The function conflicts with another registered function. + """ # We need to be able to refer to the function in the serialized # representation, and the key needs to be stable across interpreter # invocations. Use the code object's fully-qualified name for now. @@ -16,9 +27,19 @@ def register_function(fn: FunctionType) -> str: raise ValueError(f"durable function already registered with key {key}") _REGISTRY[key] = fn - return key def lookup_function(key: str) -> FunctionType: + """Lookup a previously registered function. + + Args: + key: Unique identifier for the function. + + Returns: + FunctionType: The associated function. + + Raises: + KeyError: A function has not been registered with this key. + """ return _REGISTRY[key] diff --git a/tests/dispatch/experimental/durable/test_generator.py b/tests/dispatch/experimental/durable/test_generator.py index c8d3d35f..7e63ebe9 100644 --- a/tests/dispatch/experimental/durable/test_generator.py +++ b/tests/dispatch/experimental/durable/test_generator.py @@ -1,5 +1,7 @@ import unittest import pickle +import types +import warnings from dispatch.experimental.durable import durable @@ -57,3 +59,118 @@ def test_nested(self): for j in range(i, len(expect)): assert next(g) == expect[j] assert next(g2) == expect[j] + + def test_export_gi_fields(self): + g = nested_generators(1) + underlying = g.generator + + self.assertIsInstance(g.gi_frame, types.FrameType) + self.assertIs(g.gi_frame, underlying.gi_frame) + + self.assertIsInstance(g.gi_code, types.CodeType) + self.assertIs(g.gi_code, underlying.gi_code) + + def check(): + self.assertEqual(g.gi_running, underlying.gi_running) + self.assertEqual(g.gi_suspended, underlying.gi_suspended) + self.assertIs(g.gi_yieldfrom, underlying.gi_yieldfrom) + + check() + for _ in g: + check() + check() + + def test_name_conflict(self): + @durable + def durable_generator(): + yield 1 + + with self.assertRaises(ValueError): + + @durable + def durable_generator(): + yield 2 + + def test_two_way(self): + @durable + def two_way(a): + b = yield a * 10 + c = yield b * 10 + return (yield c * 10) + + input = 1 + sends = [2, 3, 4] + yields = [10, 20, 30] + output = 4 + + g = two_way(1) + + actual_yields = [] + actual_return = None + + try: + i = 0 + send = None + while True: + next_value = g.send(send) + actual_yields.append(next_value) + send = sends[i] + i += 1 + except StopIteration as e: + actual_return = e.value + + self.assertEqual(actual_yields, yields) + self.assertEqual(actual_return, output) + + def test_throw(self): + warnings.filterwarnings("ignore", category=DeprecationWarning) # FIXME + + ok = False + + @durable + def check_throw(): + try: + yield + except RuntimeError: + nonlocal ok + ok = True + + g = check_throw() + next(g) + try: + g.throw(RuntimeError) + except StopIteration: + pass + self.assertTrue(ok) + + def test_close(self): + ok = False + + @durable + def check_close(): + try: + yield + except GeneratorExit: + nonlocal ok + ok = True + raise + + g = check_close() + next(g) + g.close() + self.assertTrue(ok) + + def test_not_a_synchronous_generator(self): + @durable + def regular(): + pass + + @durable + async def async_generator(): + yield + + with self.assertRaises(NotImplementedError): + regular() + + with self.assertRaises(NotImplementedError): + async_generator()