From 568ae0cf6199937adf368d7d1dfdc3ca6acbefdb Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 31 Jan 2024 11:34:47 +1000 Subject: [PATCH 1/8] Add experimental multicolor package --- .gitignore | 4 +- src/dispatch/experimental/__init__.py | 0 .../experimental/multicolor/__init__.py | 4 + .../experimental/multicolor/compile.py | 291 +++++++ .../experimental/multicolor/desugar.py | 624 +++++++++++++++ .../experimental/multicolor/generator.py | 48 ++ src/dispatch/experimental/multicolor/parse.py | 48 ++ .../experimental/multicolor/template.py | 35 + .../experimental/multicolor/yields.py | 33 + tests/dispatch/__init__.py | 0 tests/dispatch/experimental/__init__.py | 0 .../experimental/multicolor/__init__.py | 0 .../experimental/multicolor/test_compile.py | 230 ++++++ .../experimental/multicolor/test_desugar.py | 749 ++++++++++++++++++ .../experimental/multicolor/test_generator.py | 63 ++ 15 files changed, 2128 insertions(+), 1 deletion(-) create mode 100644 src/dispatch/experimental/__init__.py create mode 100644 src/dispatch/experimental/multicolor/__init__.py create mode 100644 src/dispatch/experimental/multicolor/compile.py create mode 100644 src/dispatch/experimental/multicolor/desugar.py create mode 100644 src/dispatch/experimental/multicolor/generator.py create mode 100644 src/dispatch/experimental/multicolor/parse.py create mode 100644 src/dispatch/experimental/multicolor/template.py create mode 100644 src/dispatch/experimental/multicolor/yields.py create mode 100644 tests/dispatch/__init__.py create mode 100644 tests/dispatch/experimental/__init__.py create mode 100644 tests/dispatch/experimental/multicolor/__init__.py create mode 100644 tests/dispatch/experimental/multicolor/test_compile.py create mode 100644 tests/dispatch/experimental/multicolor/test_desugar.py create mode 100644 tests/dispatch/experimental/multicolor/test_generator.py diff --git a/.gitignore b/.gitignore index f416ef4e..8cce4a54 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ *.pyc +*.so +*.egg-info __pycache__ -.proto \ No newline at end of file +.proto diff --git a/src/dispatch/experimental/__init__.py b/src/dispatch/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dispatch/experimental/multicolor/__init__.py b/src/dispatch/experimental/multicolor/__init__.py new file mode 100644 index 00000000..296a0ab9 --- /dev/null +++ b/src/dispatch/experimental/multicolor/__init__.py @@ -0,0 +1,4 @@ +from .compile import compile_function +from .yields import yields, CustomYield, GeneratorYield + +__all__ = ["compile_function", "yields", "CustomYield", "GeneratorYield"] diff --git a/src/dispatch/experimental/multicolor/compile.py b/src/dispatch/experimental/multicolor/compile.py new file mode 100644 index 00000000..ea8e8369 --- /dev/null +++ b/src/dispatch/experimental/multicolor/compile.py @@ -0,0 +1,291 @@ +import ast +import inspect +import os +from enum import Enum +from types import FunctionType, GeneratorType +from typing import cast +from .desugar import desugar_function +from .generator import is_generator, empty_generator +from .parse import parse_function, NoSourceError, repair_indentation +from .template import rewrite_template +from .yields import CustomYield, GeneratorYield + +TRACE = os.getenv("MULTICOLOR_TRACE", False) + + +def compile_function( + fn: FunctionType, decorator: FunctionType | None = None, cache_key: str = "default" +) -> FunctionType: + """Compile a regular function into a generator that yields data passed + to functions marked with the 
@multicolor.yields decorator. Decorated + functions can be called from anywhere in the call stack, and functions + in between do not have to be generators or async functions (coroutines). + + Example: + + @multicolor.yields(type="sleep") + def sleep(seconds): ... + + def parent(): + sleep(3) # yield point! + + def grandparent(): + parent() + + compiled_grandparent = multicolor.compile_function(grandparent) + generator = compiled_grandparent() + for item in generator: + print(item) # multicolor.CustomYield(type="sleep", args=[3]) + + Two-way data flow works as expected. At a yield point, generator.send(value) + can be used to send data back to the yield point and to resume execution. + The data sent back will be the return value of the function decorated with + @multicolor.yields. + + @multicolor.yields(type="add") + def add(a: int, b: int) -> int: + return a + b # default/synchronous implementation + + def scheduler(generator): + try: + send = None + while True: + item = generator.send(send) + match item: + case multicolor.CustomYield(type="add"): + a, b = item.args + print(f"adding {a} + {b}") + send = a + b + except StopIteration as e: + return e.value # return value + + def adder(a: int, b: int) -> int: + return add(a, b) + + compiled_adder = multicolor.compile_function(adder) + generator = compiled_adder(1, 2) + result = scheduler(generator) + print(result) # 3 + + The @multicolor.yields decorator does not change the implementation of + the function it decorates. If the function is run without being + compiled, the default implementation will be used instead: + + print(adder(1, 2)) # 3 + + The default implementation could also raise an error, to ensure that + the function is only ever called from a compiled function. + """ + compiled_fn, _ = compile_internal(fn, decorator, cache_key) + return compiled_fn + + +class FunctionColor(Enum): + REGULAR_FUNCTION = 0 + GENERATOR_FUNCTION = 1 + + +def compile_internal( + fn: FunctionType, decorator: FunctionType | None, cache_key: str +) -> tuple[FunctionType, FunctionColor]: + if hasattr(fn, "_multicolor_yield_type"): + raise ValueError("cannot compile a yield point directly") + + # Check if the function has already been compiled. + if hasattr(fn, "_multicolor_cache"): + try: + return fn._multicolor_cache[cache_key] + except KeyError: + pass + + # Parse an abstract syntax tree from the function source. + try: + root, fn_def = parse_function(fn) + except NoSourceError as e: + try: + # This can occur when compiling a nested function definition + # that was created by the desugaring pass. + if inspect.getsourcefile(fn) == "": + return fn, FunctionColor.GENERATOR_FUNCTION + except TypeError: + raise e + + # Determine what type of function we're working with. + color = FunctionColor.REGULAR_FUNCTION + if is_generator(fn_def): + color = FunctionColor.GENERATOR_FUNCTION + + if TRACE: + print("\n-------------------------------------------------") + print("[MULTICOLOR] COMPILING:") + print(repair_indentation(inspect.getsource(fn)).rstrip()) + + # De-sugar the AST to simplify subsequent transformations. + desugar_function(fn_def) + + if TRACE: + print("\n[MULTICOLOR] DESUGARED:") + print(ast.unparse(root)) + + # Handle generators by wrapping the values they yield. + generator_transformer = GeneratorTransformer() + root = generator_transformer.visit(root) + + # Replace explicit function calls with a gadget that resembles yield from. 
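+    # (The GeneratorYield wrappers added above let the call gadget tell a
+    # callee's own yielded values apart from CustomYields surfacing from
+    # yield points further down the call stack.)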
+ call_transformer = CallTransformer() + root = call_transformer.visit(root) + + # If the function never yields it won't be considered a generator. + # Patch the function if necessary to yield from an empty generator, which + # turns it into a generator. + if not is_generator(fn_def): + empty = ast.Name(id="_multicolor_empty_generator", ctx=ast.Load()) + g = ast.Call(func=empty, args=[], keywords=[]) + fn_def.body.insert(0, ast.Expr(ast.YieldFrom(value=g))) + + name = fn_def.name + "__multicolor_" + cache_key + fn_def.name = name + + # Patch AST nodes that were inserted without location info. + ast.fix_missing_locations(root) + + if TRACE: + print("\n[MULTICOLOR] RESULT:") + print(ast.unparse(root)) + + # Make necessary objects/classes/functions available to the + # transformed function. + namespace = fn.__globals__ + namespace["_multicolor_empty_generator"] = empty_generator + namespace["_multicolor_no_source_error"] = NoSourceError + namespace["_multicolor_custom_yield"] = CustomYield + namespace["_multicolor_generator_yield"] = GeneratorYield + namespace["_multicolor_compile"] = compile_internal + namespace["_multicolor_generator_type"] = GeneratorType + namespace["_multicolor_decorator"] = decorator + namespace["_multicolor_cache_key"] = cache_key + namespace["_multicolor_generator_color"] = FunctionColor.GENERATOR_FUNCTION + + # Re-compile. + code = compile(root, filename="", mode="exec") + exec(code, namespace) + compiled_fn = namespace[name] + + # Apply the custom decorator, if applicable. + if decorator is not None: + compiled_fn = decorator(compiled_fn) + + # Cache the compiled function. + if hasattr(fn, "_multicolor_cache"): + cache = cast( + dict[str, tuple[FunctionType, FunctionColor]], fn._multicolor_cache + ) + else: + cache = {} + setattr(fn, "_multicolor_cache", cache) + cache[cache_key] = (compiled_fn, color) + + return compiled_fn, color + + +class GeneratorTransformer(ast.NodeTransformer): + """Wrap ast.Yield values in a GeneratorYield container.""" + + def visit_Yield(self, node: ast.Yield) -> ast.Yield: + value = node.value + if node.value is None: + value = ast.Constant(value=None) + + wrapped_value = ast.Call( + func=ast.Name(id="_multicolor_generator_yield", ctx=ast.Load()), + args=[], + keywords=[ast.keyword(arg="value", value=value)], + ) + return ast.Yield(value=wrapped_value) + + +class CallTransformer(ast.NodeTransformer): + """Replace explicit function calls with a gadget that recursively compiles + functions into generators and then replaces the function call with a + yield from. + + The transformations are only valid for ASTs that have passed through the + desugaring pass; only ast.Expr(value=ast.Call(...)) and + ast.Assign(targets=..., value=ast.Call(..)) nodes are transformed here. 
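+
+    In outline, an assignment such as `result = fn(1, 2)` is replaced with a
+    block that yields a CustomYield (assigning the value sent back) when `fn`
+    is decorated with @yields; otherwise it compiles `fn` on the fly, drives
+    the compiled generator while re-yielding nested yield points, and assigns
+    the result; if no source is available for `fn`, it falls back to calling
+    `fn` directly.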
+ """ + + def visit_Assign(self, node: ast.Assign) -> ast.stmt: + if not isinstance(node.value, ast.Call): + return node + assign_stmt = ast.Assign(targets=node.targets) + return self._build_call_gadget(node.value, assign_stmt) + + def visit_Expr(self, node: ast.Expr) -> ast.stmt: + if not isinstance(node.value, ast.Call): + return node + return self._build_call_gadget(node.value) + + def _build_call_gadget( + self, fn_call: ast.Call, assign: ast.Assign | None = None + ) -> ast.stmt: + fn = fn_call.func + args = ast.List(elts=fn_call.args, ctx=ast.Load()) + if fn_call.keywords: + kwargs: ast.expr = ast.Call( + func=ast.Name(id="dict", ctx=ast.Load()), + args=[], + keywords=fn_call.keywords, + ) + else: + kwargs = ast.Constant(value=None) + + compiled_fn = ast.Name(id="_multicolor_compiled_fn", ctx=ast.Store()) + compiled_fn_call = ast.Call( + func=ast.Name(id="_multicolor_compiled_fn", ctx=ast.Load()), + args=fn_call.args, + keywords=fn_call.keywords, + ) + + if assign: + assign.value = ast.Name(id="_multicolor_result", ctx=ast.Load()) + assign_result: ast.stmt = assign + else: + assign_result = ast.Pass() + + result = rewrite_template( + """ +if hasattr(__fn__, "_multicolor_yield_type"): + _multicolor_result = yield _multicolor_custom_yield(type=__fn__._multicolor_yield_type, args=__args__, kwargs=__kwargs__) + __assign_result__ +else: + _multicolor_result = None + try: + __compiled_fn__, _multicolor_color = _multicolor_compile(__fn__, _multicolor_decorator, _multicolor_cache_key) + except _multicolor_no_source_error: + _multicolor_result = __fn_call__ + else: + _multicolor_generator = __compiled_fn_call__ + if _multicolor_color == _multicolor_generator_color: + _multicolor_result = [] + for _multicolor_yield in _multicolor_generator: + if isinstance(_multicolor_yield, _multicolor_generator_yield): + _multicolor_result.append(_multicolor_yield.value) + else: + yield _multicolor_yield + else: + _multicolor_result = yield from _multicolor_generator + finally: + __assign_result__ + """, + expressions=dict( + __fn__=fn, + __fn_call__=fn_call, + __args__=args, + __kwargs__=kwargs, + __compiled_fn__=compiled_fn, + __compiled_fn_call__=compiled_fn_call, + ), + statements=dict(__assign_result__=assign_result), + ) + + return result[0] diff --git a/src/dispatch/experimental/multicolor/desugar.py b/src/dispatch/experimental/multicolor/desugar.py new file mode 100644 index 00000000..ab1d62d3 --- /dev/null +++ b/src/dispatch/experimental/multicolor/desugar.py @@ -0,0 +1,624 @@ +import ast +import sys +from typing import cast + + +def desugar_function(node: ast.FunctionDef) -> ast.FunctionDef: + """Desugar a function to simplify subsequent AST transformations.""" + node.body = Desugar().desugar(node.body) + ast.fix_missing_locations(node) + return node + + +class Desugar: + """The desugar pass simplifies AST transformations that must replace an + expression (e.g. a function call) with a statement (e.g. an if statement + or for loop) in a function definition. + + The pass recursively simplifies control flow and compound expressions + in a function definition such that: + - expressions that are children of statements either have no children, or + only have children of type ast.Name and/or ast.Constant + - those parent expressions are either part of an ast.Expr(value=expr) + statement or an ast.Assign(value=expr) statement + + The pass does not recurse into lambda expressions, or nested function or + class definitions. 
+ """ + + def __init__(self): + self.name_count = 0 + + def desugar(self, stmts: list[ast.stmt]) -> list[ast.stmt]: + return self._desugar_stmts(stmts) + + def _desugar_stmt(self, stmt: ast.stmt) -> tuple[ast.stmt, list[ast.stmt]]: + deps: list[ast.stmt] = [] + match stmt: + # Pass + case ast.Pass(): + pass + + # Break + case ast.Break(): + pass + + # Continue + case ast.Continue(): + pass + + # Import(alias* names) + case ast.Import(): + pass + + # ImportFrom(identifier? module, alias* names, int? level) + case ast.ImportFrom(): + pass + + # Nonlocal(identifier* names) + case ast.Nonlocal(): + pass + + # Global(identifier* names) + case ast.Global(): + pass + + # Return(expr? value) + case ast.Return(): + if stmt.value is not None: + stmt.value, deps = self._desugar_expr(stmt.value) + + # Expr(expr value) + case ast.Expr(): + stmt.value, deps = self._desugar_expr(stmt.value, expr_stmt=True) + + # Assert(expr test, expr? msg) + case ast.Assert(): + stmt.test, deps = self._desugar_expr(stmt.test) + if stmt.msg is not None: + stmt.msg, msg_deps = self._desugar_expr(stmt.msg) + deps.extend(msg_deps) + + # Assign(expr* targets, expr value, string? type_comment) + case ast.Assign(): + stmt.targets, deps = self._desugar_exprs(stmt.targets) + stmt.value, value_deps = self._desugar_expr(stmt.value) + deps.extend(value_deps) + + # AugAssign(expr target, operator op, expr value) + case ast.AugAssign(): + target = cast( + ast.expr, stmt.target + ) # ast.Name | ast.Attribute | ast.Subscript + target, deps = self._desugar_expr(target) + stmt.target = cast(ast.Name | ast.Attribute | ast.Subscript, target) + stmt.value, value_deps = self._desugar_expr(stmt.value) + deps.extend(value_deps) + + # AnnAssign(expr target, expr annotation, expr? value, int simple) + case ast.AnnAssign(): + target = cast( + ast.expr, stmt.target + ) # ast.Name | ast.Attribute | ast.Subscript + target, deps = self._desugar_expr(target) + stmt.target = cast(ast.Name | ast.Attribute | ast.Subscript, target) + stmt.annotation, annotation_deps = self._desugar_expr(stmt.annotation) + deps.extend(annotation_deps) + if stmt.value is not None: + stmt.value, value_deps = self._desugar_expr(stmt.value) + deps.extend(value_deps) + + # Delete(expr* targets) + case ast.Delete(): + stmt.targets, deps = self._desugar_exprs(stmt.targets) + + # Raise(expr? exc, expr? cause) + case ast.Raise(): + if stmt.exc is not None: + stmt.exc, exc_deps = self._desugar_expr(stmt.exc) + deps.extend(exc_deps) + if stmt.cause is not None: + stmt.cause, cause_deps = self._desugar_expr(stmt.cause) + deps.extend(cause_deps) + + # If(expr test, stmt* body, stmt* orelse) + case ast.If(): + stmt.test, deps = self._desugar_expr(stmt.test) + stmt.body = self._desugar_stmts(stmt.body) + stmt.orelse = self._desugar_stmts(stmt.orelse) + + # While(expr test, stmt* body, stmt* orelse) + case ast.While(): + stmt.test, deps = self._desugar_expr(stmt.test) + stmt.body = self._desugar_stmts(stmt.body) + stmt.orelse = self._desugar_stmts(stmt.orelse) + + # For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) + case ast.For(): + stmt.target, deps = self._desugar_expr(stmt.target) + stmt.iter, iter_deps = self._desugar_expr(stmt.iter) + deps.extend(iter_deps) + stmt.body = self._desugar_stmts(stmt.body) + stmt.orelse = self._desugar_stmts(stmt.orelse) + + # AsyncFor(expr target, expr iter, stmt* body, stmt* orelse, string? 
type_comment) + case ast.AsyncFor(): + stmt.target, deps = self._desugar_expr(stmt.target) + stmt.iter, iter_deps = self._desugar_expr(stmt.iter) + deps.extend(iter_deps) + stmt.body = self._desugar_stmts(stmt.body) + stmt.orelse = self._desugar_stmts(stmt.orelse) + + # Try(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody) + case ast.Try(): + stmt.body = self._desugar_stmts(stmt.body) + stmt.handlers, deps = self._desugar_except_handlers(stmt.handlers) + stmt.orelse = self._desugar_stmts(stmt.orelse) + stmt.finalbody = self._desugar_stmts(stmt.finalbody) + + # Match(expr subject, match_case* cases) + case ast.Match(): + stmt.subject, deps = self._desugar_expr(stmt.subject) + stmt.cases, match_case_deps = self._desugar_match_cases(stmt.cases) + deps.extend(match_case_deps) + + # With(withitem* items, stmt* body, string? type_comment) + case ast.With(): + while len(stmt.items) > 1: + last = stmt.items.pop() + stmt.body = [ast.With(items=[last], body=stmt.body)] + + stmt.items, deps = self._desugar_withitems(stmt.items) + stmt.body = self._desugar_stmts(stmt.body) + + # AsyncWith(withitem* items, stmt* body, string? type_comment) + case ast.AsyncWith(): + while len(stmt.items) > 1: + last = stmt.items.pop() + stmt.body = [ast.AsyncWith(items=[last], body=stmt.body)] + + stmt.items, deps = self._desugar_withitems(stmt.items) + stmt.body = self._desugar_stmts(stmt.body) + + # FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list, expr? returns, string? type_comment) + case ast.FunctionDef(): + pass # do not recurse + + # AsyncFunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list, expr? returns, string? type_comment) + case ast.AsyncFunctionDef(): + pass # do not recurse + + # ClassDef(identifier name, expr* bases, keyword* keywords, stmt* body, expr* decorator_list) + case ast.ClassDef(): + pass # do not recurse + + case _: + # Handle nodes added after Python 3.10. + if sys.version_info >= (3, 11) and isinstance(stmt, ast.TryStar): + # TryStar(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody) + stmt.body = self._desugar_stmts(stmt.body) + stmt.handlers, deps = self._desugar_except_handlers(stmt.handlers) + stmt.orelse = self._desugar_stmts(stmt.orelse) + stmt.finalbody = self._desugar_stmts(stmt.finalbody) + else: + raise NotImplementedError(f"desugar {stmt}") + + return stmt, deps + + def _desugar_expr( + self, expr: ast.expr, expr_stmt=False + ) -> tuple[ast.expr, list[ast.stmt]]: + # These cases have no nested expressions or statements. Return + # early so that no superfluous temporaries are generated. + if isinstance(expr, ast.Name): + # Name(identifier id, expr_context ctx) + return expr, [] + elif isinstance(expr, ast.Constant): + # Constant(constant value, string? 
kind) + return expr, [] + elif isinstance(expr, ast.Attribute) and isinstance(expr.value, ast.Name): + # Attribute(expr value, identifier attr, expr_context ctx) + return expr, [] + + deps: list[ast.stmt] = [] + wrapper = None + create_temporary = not expr_stmt + is_store = False + match expr: + # Call(expr func, expr* args, keyword* keywords) + case ast.Call(): + expr.func, deps = self._desugar_expr(expr.func) + expr.args, args_deps = self._desugar_exprs(expr.args) + deps.extend(args_deps) + expr.keywords, keywords_deps = self._desugar_keywords(expr.keywords) + deps.extend(keywords_deps) + + # BinOp(expr left, operator op, expr right) + case ast.BinOp(): + expr.left, deps = self._desugar_expr(expr.left) + expr.right, right_deps = self._desugar_expr(expr.right) + deps.extend(right_deps) + + # UnaryOp(unaryop op, expr operand) + case ast.UnaryOp(): + expr.operand, deps = self._desugar_expr(expr.operand) + + # BoolOp(boolop op, expr* values) + case ast.BoolOp(): + expr.values, deps = self._desugar_exprs(expr.values) + + # Tuple(expr* elts, expr_context ctx) + case ast.Tuple(): + expr.elts, deps = self._desugar_exprs(expr.elts) + is_store = isinstance(expr.ctx, ast.Store) + + # List(expr* elts, expr_context ctx) + case ast.List(): + expr.elts, deps = self._desugar_exprs(expr.elts) + is_store = isinstance(expr.ctx, ast.Store) + + # Set(expr* elts) + case ast.Set(): + expr.elts, deps = self._desugar_exprs(expr.elts) + + # Dict(expr* keys, expr* values) + case ast.Dict(): + for i, key in enumerate(expr.keys): + if key is not None: + key, key_deps = self._desugar_expr(key) + deps.extend(key_deps) + expr.keys[i] = key + expr.values, values_deps = self._desugar_exprs(expr.values) + deps.extend(values_deps) + + # Starred(expr value, expr_context ctx) + case ast.Starred(): + expr.value, deps = self._desugar_expr(expr.value) + is_store = isinstance(expr.ctx, ast.Store) + create_temporary = False + + # Compare(expr left, cmpop* ops, expr* comparators) + case ast.Compare(): + expr.left, deps = self._desugar_expr(expr.left) + expr.comparators, comparators_deps = self._desugar_exprs( + expr.comparators + ) + deps.extend(comparators_deps) + + # NamedExpr(expr target, expr value) + case ast.NamedExpr(): + target = cast(ast.expr, expr.target) # ast.Name + target, deps = self._desugar_expr(target) + expr.target = cast(ast.Name, target) + expr.value, value_deps = self._desugar_expr(expr.value) + deps.extend(value_deps) + + # We need to preserve the assignment so that the target is accessible + # from subsequent expressions/statements. ast.NamedExpr isn't valid as + # a standalone a statement, so we need to convert to ast.Assign. + deps.append(ast.Assign(targets=[expr.target], value=expr.value)) + expr = expr.target + + # Lambda(arguments args, expr body) + case ast.Lambda(): + pass # do not recurse + + # Await(expr value) + case ast.Await(): + expr.value, deps = self._desugar_expr(expr.value) + + # Yield(expr? value) + case ast.Yield(): + if expr.value is not None: + expr.value, deps = self._desugar_expr(expr.value) + + # YieldFrom(expr value) + case ast.YieldFrom(): + expr.value, deps = self._desugar_expr(expr.value) + + # JoinedStr(expr* values) + case ast.JoinedStr(): + expr.values, deps = self._desugar_exprs(expr.values) + + # FormattedValue(expr value, int conversion, expr? 
format_spec) + case ast.FormattedValue(): + expr.value, deps = self._desugar_expr(expr.value) + if expr.format_spec is not None: + expr.format_spec, format_spec_deps = self._desugar_expr( + expr.format_spec + ) + deps.extend(format_spec_deps) + + conversion = expr.conversion + format_spec = expr.format_spec + expr = expr.value + create_temporary = False + + def wrapper(value): + return ast.FormattedValue( + value=value, conversion=conversion, format_spec=format_spec + ) + + # Attribute(expr value, identifier attr, expr_context ctx) + case ast.Attribute(): + expr.value, deps = self._desugar_expr(expr.value) + is_store = isinstance(expr.ctx, ast.Store) + + # Subscript(expr value, expr slice, expr_context ctx) + case ast.Subscript(): + expr.value, deps = self._desugar_expr(expr.value) + expr.slice, slice_deps = self._desugar_expr(expr.slice) + deps.extend(slice_deps) + is_store = isinstance(expr.ctx, ast.Store) + + # Slice(expr? lower, expr? upper, expr? step) + case ast.Slice(): + if expr.lower is not None: + expr.lower, lower_deps = self._desugar_expr(expr.lower) + deps.extend(lower_deps) + if expr.upper is not None: + expr.upper, upper_deps = self._desugar_expr(expr.upper) + deps.extend(upper_deps) + if expr.step is not None: + expr.step, step_deps = self._desugar_expr(expr.step) + deps.extend(step_deps) + is_store = True + + # IfExp(expr test, expr body, expr orelse) + case ast.IfExp(): + tmp = self._new_name() + if_stmt, deps = self._desugar_stmt( + ast.If( + test=expr.test, + body=[ + ast.Assign( + targets=[ast.Name(id=tmp, ctx=ast.Store())], + value=expr.body, + ) + ], + orelse=[ + ast.Assign( + targets=[ast.Name(id=tmp, ctx=ast.Store())], + value=expr.orelse, + ) + ], + ) + ) + deps.append(if_stmt) + expr = ast.Name(id=tmp, ctx=ast.Load()) + create_temporary = False + + # ListComp(expr elt, comprehension* generators) + case ast.ListComp(): + tmp = self._new_name() + + deps = [ + ast.Assign( + targets=[ast.Name(id=tmp, ctx=ast.Store())], + value=ast.List(elts=[], ctx=ast.Load()), + ) + ] + + inner_statement: ast.stmt = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=tmp, ctx=ast.Load()), + attr="append", + ctx=ast.Load(), + ), + args=[expr.elt], + keywords=[], + ) + ) + + deps += self._desugar_comprehensions(expr.generators, inner_statement) + expr = ast.Name(id=tmp, ctx=ast.Load()) + create_temporary = False + + # SetComp(expr elt, comprehension* generators) + case ast.SetComp(): + tmp = self._new_name() + + deps = [ + ast.Assign( + targets=[ast.Name(id=tmp, ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="set", ctx=ast.Load()), + args=[], + keywords=[], + ), + ) + ] + + inner_statement = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=tmp, ctx=ast.Load()), + attr="add", + ctx=ast.Load(), + ), + args=[expr.elt], + keywords=[], + ) + ) + + deps += self._desugar_comprehensions(expr.generators, inner_statement) + expr = ast.Name(id=tmp, ctx=ast.Load()) + create_temporary = False + + # DictComp(expr key, expr value, comprehension* generators) + case ast.DictComp(): + tmp = self._new_name() + + deps = [ + ast.Assign( + targets=[ast.Name(id=tmp, ctx=ast.Store())], + value=ast.Dict(keys=[], values=[]), + ) + ] + + inner_statement = ast.Assign( + targets=[ + ast.Subscript( + value=ast.Name(id=tmp, ctx=ast.Store()), + slice=expr.key, + ctx=ast.Store(), + ) + ], + value=expr.value, + ) + + deps += self._desugar_comprehensions(expr.generators, inner_statement) + expr = ast.Name(id=tmp, ctx=ast.Load()) + create_temporary = False + + # 
GeneratorExp(expr elt, comprehension* generators) + case ast.GeneratorExp(): + tmp = self._new_name() + inner_statement = ast.Expr(value=ast.Yield(value=expr.elt)) + body = self._desugar_comprehensions(expr.generators, inner_statement) + deps = [ + ast.FunctionDef( + name=tmp, + args=ast.arguments( + args=[], + posonlyargs=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=body, + decorator_list=[], + ) + ] + expr = ast.Call( + func=ast.Name(id=tmp, ctx=ast.Load()), args=[], keywords=[] + ) + + case _: + raise NotImplementedError(f"desugar {expr}") + + if create_temporary and not is_store: + tmp = self._new_name() + deps.append( + ast.Assign(targets=[ast.Name(id=tmp, ctx=ast.Store())], value=expr) + ) + expr = ast.Name(id=tmp, ctx=ast.Load()) + + if wrapper is not None: + expr = wrapper(expr) + + return expr, deps + + def _desugar_stmts(self, stmts: list[ast.stmt]) -> list[ast.stmt]: + desugared = [] + for stmt in stmts: + stmt, deps = self._desugar_stmt(stmt) + desugared.extend(deps) + desugared.append(stmt) + return desugared + + def _desugar_exprs( + self, exprs: list[ast.expr] + ) -> tuple[list[ast.expr], list[ast.stmt]]: + desugared = [] + deps = [] + for expr in exprs: + expr, expr_deps = self._desugar_expr(expr) + deps.extend(expr_deps) + desugared.append(expr) + return desugared, deps + + def _desugar_keywords( + self, keywords: list[ast.keyword] + ) -> tuple[list[ast.keyword], list[ast.stmt]]: + # keyword(identifier? arg, expr value) + desugared = [] + deps = [] + for keyword in keywords: + keyword.value, keyword_deps = self._desugar_expr(keyword.value) + deps.extend(keyword_deps) + desugared.append(keyword) + return desugared, deps + + def _desugar_except_handlers( + self, handlers: list[ast.ExceptHandler] + ) -> tuple[list[ast.ExceptHandler], list[ast.stmt]]: + # excepthandler = ExceptHandler(expr? type, identifier? name, stmt* body) + desugared = [] + deps: list[ast.stmt] = [] + for handler in handlers: + if handler.type is not None: + # FIXME: exception type exprs need special handling. Each handler's + # type expr is evaluated one at a time until there's a match. The + # remaining handler's type exprs are not evaluated. + # handler.type, type_deps = self._desugar_expr(handler.type) + # deps.extend(type_deps) + pass + handler.body = self._desugar_stmts(handler.body) + desugared.append(handler) + return desugared, deps + + def _desugar_match_cases( + self, cases: list[ast.match_case] + ) -> tuple[list[ast.match_case], list[ast.stmt]]: + # match_case(pattern pattern, expr? guard, stmt* body) + desugared: list[ast.match_case] = [] + deps: list[ast.stmt] = [] + for case in cases: + if case.guard is not None: + # FIXME: match guards need special handling; they shouldn't be evaluated + # unless the pattern matches. + # case.guard, guard_deps = self._desugar_expr(case.guard) + # deps.extend(guard_deps) + pass + case.body = self._desugar_stmts(case.body) + desugared.append(case) + # You're supposed to be able to pass the AST root to this function + # to have it repair (fill in missing) line numbers and such. It + # seems there's a bug where it doesn't recurse into match cases. + # Work around the issue by manually fixing the match case here. + ast.fix_missing_locations(case) + return desugared, deps + + def _desugar_withitems( + self, withitems: list[ast.withitem] + ) -> tuple[list[ast.withitem], list[ast.stmt]]: + # withitem(expr context_expr, expr? 
optional_vars) + desugared = [] + deps = [] + for withitem in withitems: + withitem.context_expr, context_expr_deps = self._desugar_expr( + withitem.context_expr + ) + deps.extend(context_expr_deps) + if withitem.optional_vars is not None: + withitem.optional_vars, optional_vars_deps = self._desugar_expr( + withitem.optional_vars + ) + deps.extend(optional_vars_deps) + desugared.append(withitem) + return desugared, deps + + def _desugar_comprehensions( + self, comprehensions: list[ast.comprehension], inner_statement: ast.stmt + ) -> list[ast.stmt]: + # comprehension(expr target, expr iter, expr* ifs, int is_async) + stmt = inner_statement + while comprehensions: + last_for = comprehensions.pop() + while last_for.ifs: + test = last_for.ifs.pop() + stmt = ast.If(test=test, body=[stmt], orelse=[]) + cls = ast.AsyncFor if last_for.is_async else ast.For + stmt = cls( + target=last_for.target, iter=last_for.iter, body=[stmt], orelse=[] + ) + + stmt, deps = self._desugar_stmt(stmt) + return deps + [stmt] + + def _new_name(self) -> str: + name = f"_v{self.name_count}" + self.name_count += 1 + return name diff --git a/src/dispatch/experimental/multicolor/generator.py b/src/dispatch/experimental/multicolor/generator.py new file mode 100644 index 00000000..1fba02fc --- /dev/null +++ b/src/dispatch/experimental/multicolor/generator.py @@ -0,0 +1,48 @@ +import ast + + +def is_generator(fn_def: ast.FunctionDef) -> bool: + """Check if a function definition defines a generator.""" + yield_counter = YieldCounter() + yield_counter.visit(fn_def) + return yield_counter.count > 0 + + +class YieldCounter(ast.NodeVisitor): + """Walks an ast.FunctionDef to count yield and yield from statements. + + Yields from nested function/class definitions are not counted. + + The resulting count can be used to determine if the input function is + a generator or not.""" + + def __init__(self): + self.count = 0 + self.depth = 0 + + def visit_Yield(self, node): + self.count += 1 + + def visit_YieldFrom(self, node): + self.count += 1 + + def visit_FunctionDef(self, node): + self._visit_nested(node) + + def visit_AsyncFunctionDef(self, node): + self._visit_nested(node) + + def visit_ClassDef(self, node): + self._visit_nested(node) + + def _visit_nested(self, node): + self.depth += 1 + if self.depth > 1: + return # do not recurse + self.generic_visit(node) + self.depth -= 1 + + +def empty_generator(): + if False: + yield diff --git a/src/dispatch/experimental/multicolor/parse.py b/src/dispatch/experimental/multicolor/parse.py new file mode 100644 index 00000000..a49ab75a --- /dev/null +++ b/src/dispatch/experimental/multicolor/parse.py @@ -0,0 +1,48 @@ +import ast +import inspect +from typing import cast +from types import FunctionType + + +def parse_function(fn: FunctionType) -> tuple[ast.Module, ast.FunctionDef]: + """Parse an AST from a function definition.""" + try: + src = inspect.getsource(fn) + except TypeError as e: + # The source is not always available. For example, the function + # may be defined in a C extension, or may be a builtin function. 
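+        # (inspect.getsource raises TypeError for such objects; the OSError
+        # branch below covers code whose source file cannot be located, e.g.
+        # functions created dynamically with exec/compile.)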
+ raise NoSourceError from e + except OSError as e: + raise NoSourceError from e + + try: + module = ast.parse(src) + except IndentationError: + src = repair_indentation(src) + module = ast.parse(src) + + fn_def = cast(ast.FunctionDef, module.body[0]) + return module, fn_def + + +class NoSourceError(RuntimeError): + """Error that occurs when a function AST is not available because + the (Python) source code is not available.""" + + +def repair_indentation(src: str) -> str: + """Repair (remove excess) indentation from the source of a function + definition that's nested within a class or function.""" + lines = src.split("\n") + head = lines[0] + indent_len = len(head) - len(head.lstrip()) + indent = head[:indent_len] + for i in range(len(lines)): + if len(lines[i]) == 0: + continue + if not lines[i].startswith(indent): + raise IndentationError( + f"inconsistent indentation '{head}' vs. '{lines[i]}'" + ) + lines[i] = lines[i][indent_len:] + return "\n".join(lines) diff --git a/src/dispatch/experimental/multicolor/template.py b/src/dispatch/experimental/multicolor/template.py new file mode 100644 index 00000000..190a792c --- /dev/null +++ b/src/dispatch/experimental/multicolor/template.py @@ -0,0 +1,35 @@ +import ast + + +def rewrite_template( + template: str, expressions: dict[str, ast.expr], statements: dict[str, ast.stmt] +) -> list[ast.stmt]: + """Create an AST by parsing a template string and then replacing + temporary variables with the specified AST nodes.""" + root = ast.parse(template) + root = NameTransformer(expressions=expressions, statements=statements).visit(root) + return root.body + + +class NameTransformer(ast.NodeTransformer): + """Replace ast.Name nodes in an AST.""" + + def __init__( + self, expressions: dict[str, ast.expr], statements: dict[str, ast.stmt] + ): + self.expressions = expressions + self.statements = statements + + def visit_Name(self, node): + try: + return self.expressions[node.id] + except KeyError: + return node + + def visit_Expr(self, node): + if not isinstance(node.value, ast.Name): + return node + try: + return self.statements[node.value.id] + except KeyError: + return node diff --git a/src/dispatch/experimental/multicolor/yields.py b/src/dispatch/experimental/multicolor/yields.py new file mode 100644 index 00000000..ce5fb505 --- /dev/null +++ b/src/dispatch/experimental/multicolor/yields.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from types import FunctionType +from typing import Any + + +def yields(type: Any): + """Mark a function as a custom yield point.""" + + def decorator(fn: FunctionType) -> FunctionType: + fn._multicolor_yield_type = type # type: ignore[attr-defined] + return fn + + return decorator + + +class YieldType: + """Base class for yield types.""" + + +@dataclass +class CustomYield(YieldType): + """A yield from a function marked with @yields.""" + + type: Any + args: list[Any] + kwargs: dict[str, Any] | None = None + + +@dataclass +class GeneratorYield(YieldType): + """A yield from a generator.""" + + value: Any = None diff --git a/tests/dispatch/__init__.py b/tests/dispatch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dispatch/experimental/__init__.py b/tests/dispatch/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dispatch/experimental/multicolor/__init__.py b/tests/dispatch/experimental/multicolor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dispatch/experimental/multicolor/test_compile.py 
b/tests/dispatch/experimental/multicolor/test_compile.py new file mode 100644 index 00000000..cbfa86c6 --- /dev/null +++ b/tests/dispatch/experimental/multicolor/test_compile.py @@ -0,0 +1,230 @@ +import time +import unittest +from types import FunctionType +from typing import Any +from dispatch.experimental.multicolor import ( + compile_function, + yields, + CustomYield, + GeneratorYield, +) +from enum import Enum + + +class YieldTypes(Enum): + SLEEP = 0 + ADD = 1 + STAR_ARGS_KWARGS = 2 + + +@yields(type=YieldTypes.SLEEP) +def sleep(seconds): + time.sleep(seconds) + + +@yields(type=YieldTypes.ADD) +def add(a: int, b: int) -> int: + return a + b + + +@yields(type=YieldTypes.STAR_ARGS_KWARGS) +def star_args_kwargs(*args, **kwargs): + pass + + +def adder(a: int, b: int) -> int: + return add(a, b) + + +def empty(): + pass + + +def identity(x): + return x + + +def identity_sleep(x): + sleep(x) + return x + + +def identity_sleep_yield(n): + for i in range(n): + sleep(i) + yield i + + +class TestCompile(unittest.TestCase): + def test_empty(self): + self.assert_yields(empty, args=[], yields=[], returns=None) + + def test_identity(self): + self.assert_yields(identity, args=[1], yields=[], returns=1) + + def test_identity_indirect(self): + def identity_indirect(x): + return identity(x) + + self.assert_yields(identity_indirect, args=[2], yields=[], returns=2) + + def test_identity_sleep(self): + yields = [CustomYield(type=YieldTypes.SLEEP, args=[1])] + self.assert_yields(identity_sleep, args=[1], yields=yields, returns=1) + + def test_identity_sleep_indirect(self): + def identity_sleep_indirect(x): + return identity_sleep(x) + + yields = [CustomYield(type=YieldTypes.SLEEP, args=[1])] + self.assert_yields(identity_sleep_indirect, args=[1], yields=yields, returns=1) + + def test_adder(self): + yields = [CustomYield(type=YieldTypes.ADD, args=[1, 2])] + self.assert_yields(adder, args=[1, 2], sends=[3], yields=yields, returns=3) + + def test_adder_indirect(self): + def adder_indirect(a, b): + return adder(a, b) + + yields = [CustomYield(type=YieldTypes.ADD, args=[1, 2])] + self.assert_yields( + adder_indirect, args=[1, 2], sends=[3], yields=yields, returns=3 + ) + + def test_star_args_kwargs_forward(self): + def star_args_kwargs_forward(*args, **kwargs): + star_args_kwargs(*args, **kwargs) + + yields = [ + CustomYield( + type=YieldTypes.STAR_ARGS_KWARGS, args=[1, 2], kwargs={"foo": "bar"} + ) + ] + self.assert_yields( + star_args_kwargs_forward, args=[1, 2], kwargs={"foo": "bar"}, yields=yields + ) + + def test_star_args_kwargs_explicit(self): + def star_args_kwargs_explicit(): + star_args_kwargs(1, 2, foo="bar") + + yields = [ + CustomYield( + type=YieldTypes.STAR_ARGS_KWARGS, args=[1, 2], kwargs={"foo": "bar"} + ) + ] + self.assert_yields(star_args_kwargs_explicit, yields=yields) + + def test_generator_yield(self): + def generator(): + yield 1 + sleep(2) + yield 3 + yield + return 4 + + yields = [ + GeneratorYield(value=1), + CustomYield(type=YieldTypes.SLEEP, args=[2]), + GeneratorYield(value=3), + GeneratorYield(), + ] + self.assert_yields(generator, yields=yields, returns=4) + + def test_generator_yield_send(self): + def generator(): + a = yield 1 + b = add(10, 20) + c = yield 3 + return a, b, c + + yields = [ + GeneratorYield(value=1), + CustomYield(type=YieldTypes.ADD, args=[10, 20]), + GeneratorYield(value=3), + ] + self.assert_yields( + generator, yields=yields, sends=[100, 30, 1000], returns=(100, 30, 1000) + ) + + def test_generator_range(self): + def generator(): + for i in range(3): + 
sleep(i) + yield i + + yields = [ + CustomYield(type=YieldTypes.SLEEP, args=[0]), + GeneratorYield(value=0), + CustomYield(type=YieldTypes.SLEEP, args=[1]), + GeneratorYield(value=1), + CustomYield(type=YieldTypes.SLEEP, args=[2]), + GeneratorYield(value=2), + ] + self.assert_yields(generator, yields=yields) + + def test_list_comprehensions(self): + def fn(): + return sum([identity_sleep(i) for i in range(3)]) + + yields = [ + CustomYield(type=YieldTypes.SLEEP, args=[0]), + CustomYield(type=YieldTypes.SLEEP, args=[1]), + CustomYield(type=YieldTypes.SLEEP, args=[2]), + ] + self.assert_yields(fn, yields=yields, returns=3) + + def test_list_comprehensions_2(self): + def fn(): + return sum([x for x in identity_sleep_yield(3)]) + + yields = [ + CustomYield(type=YieldTypes.SLEEP, args=[0]), + CustomYield(type=YieldTypes.SLEEP, args=[1]), + CustomYield(type=YieldTypes.SLEEP, args=[2]), + ] + self.assert_yields(fn, yields=yields, returns=3) + + def test_generator_comprehensions(self): + def fn(): + return sum(identity_sleep(i) for i in range(3)) + + yields = [ + CustomYield(type=YieldTypes.SLEEP, args=[0]), + CustomYield(type=YieldTypes.SLEEP, args=[1]), + CustomYield(type=YieldTypes.SLEEP, args=[2]), + ] + self.assert_yields(fn, yields=yields, returns=3) + + def assert_yields( + self, + fn: FunctionType, + yields: list[Any], + args: list[Any] | None = None, + kwargs: dict[str, Any] | None = None, + returns: Any = None, + sends: list[Any] | None = None, + ): + args = args if args is not None else [] + kwargs = kwargs if kwargs is not None else {} + + compiled_fn = compile_function(fn) + gen = compiled_fn(*args, **kwargs) + + actual_yields = [] + actual_returns = None + try: + i = 0 + while True: + if i == 0 or not sends: + value = gen.send(None) + else: + value = gen.send(sends[i - 1]) + actual_yields.append(value) + i += 1 + except StopIteration as e: + actual_returns = e.value + + self.assertListEqual(actual_yields, yields) + self.assertEqual(actual_returns, returns) diff --git a/tests/dispatch/experimental/multicolor/test_desugar.py b/tests/dispatch/experimental/multicolor/test_desugar.py new file mode 100644 index 00000000..f1b2bc27 --- /dev/null +++ b/tests/dispatch/experimental/multicolor/test_desugar.py @@ -0,0 +1,749 @@ +import ast +import unittest +from types import FunctionType +from dispatch.experimental.multicolor.parse import parse_function +from dispatch.experimental.multicolor.desugar import desugar_function + + +# Disable lint checks: +# ruff: noqa + + +class TestDesugar(unittest.TestCase): + def test_pass(self): + def fn(): + pass + + self.assert_desugar_is_noop(fn) + + def test_import(self): + def fn(): + import ast + + self.assert_desugar_is_noop(fn) + + def test_import_from(self): + def fn(): + from ast import parse + + self.assert_desugar_is_noop(fn) + + def test_global(self): + def fn(): + global ast + + self.assert_desugar_is_noop(fn) + + def test_expr_stmt(self): + def fn(): + identity(1) + + self.assert_desugar_is_noop(fn) + + def test_return_empty(self): + def fn(): + return + + self.assert_desugar_is_noop(fn) + + def test_return_call(self): + def before(): + return identity(1) + + def after(): + _v0 = identity(1) + return _v0 + + self.assert_desugared(before, after) + + def test_return_bin_op(self): + def before(): + return identity(1) + identity(2) + + def after(): + _v0 = identity(1) + _v1 = identity(2) + _v2 = _v0 + _v1 + return _v2 + + self.assert_desugared(before, after) + + def test_return_unary_op(self): + def before(): + return not identity(1) + + def after(): + 
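+            # The _vN names are the temporaries introduced by the desugar
+            # pass (see Desugar._new_name).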
_v0 = identity(1) + _v1 = not _v0 + return _v1 + + self.assert_desugared(before, after) + + def test_return_bool_op(self): + def before(): + return identity(1) and identity(2) + + def after(): + _v0 = identity(1) + _v1 = identity(2) + _v2 = _v0 and _v1 + return _v2 + + self.assert_desugared(before, after) + + def test_compound_literals(self): + def before(): + foo = [identity(1), identity(2), *identity(3)] + bar = {identity(4)} + baz = {identity(5): identity(6), **identity(7)} + + def after(): + _v0 = identity(1) + _v1 = identity(2) + _v2 = identity(3) + _v3 = [_v0, _v1, *_v2] + foo = _v3 + _v4 = identity(4) + _v5 = {_v4} + bar = _v5 + _v6 = identity(5) + _v7 = identity(6) + _v8 = identity(7) + _v9 = {_v6: _v7, **_v8} + baz = _v9 + + self.assert_desugared(before, after) + + def test_assert_bool(self): + def fn(): + assert True + assert True, "message" + + self.assert_desugar_is_noop(fn) + + def test_assert_call(self): + def before(): + assert identity(1) + assert identity(2), "message" + assert identity(3), identity(4) + + def after(): + _v0 = identity(1) + assert _v0 + _v1 = identity(2) + assert _v1, "message" + _v2 = identity(3) + _v3 = identity(4) + assert _v2, _v3 + + self.assert_desugared(before, after) + + def test_assign_name_constant(self): + def fn(): + foo = 1 + bar: int = 1 # type: ignore[annotation-unchecked] + foo += 1 + + self.assert_desugar_is_noop(fn) + + def test_assign_call(self): + def before(): + foo = identity(1) + bar: int = identity(2) # type: ignore[annotation-unchecked] + foo += identity(3) + + def after(): + _v0 = identity(1) + foo = _v0 + _v1 = identity(2) + bar: int = _v1 # type: ignore[annotation-unchecked] + _v2 = identity(3) + foo += _v2 + + self.assert_desugared(before, after) + + def test_assign_tuple(self): + def before(): + foo, bar = 1, 2 + + def after(): + _v0 = (1, 2) + foo, bar = _v0 + + self.assert_desugared(before, after) + + def test_assign_tuple_call(self): + def before(): + foo, bar = identity(1), identity(2) + + def after(): + _v0 = identity(1) + _v1 = identity(2) + _v2 = (_v0, _v1) + foo, bar = _v2 + + self.assert_desugared(before, after) + + def test_if_noops(self): + def fn(): + if True: + pass + if False: + pass + elif True: + pass + if True: + pass + elif True: + pass + else: + pass + + self.assert_desugar_is_noop(fn) + + def test_if(self): + def before(): + if identity(1) == 1: + return identity(2) + else: + return identity(3) + + def after(): + _v0 = identity(1) + _v1 = _v0 == 1 + if _v1: + _v2 = identity(2) + return _v2 + else: + _v3 = identity(3) + return _v3 + + self.assert_desugared(before, after) + + def test_nested_ifs(self): + def before(): + if identity(1) == 1: + return identity(2) + elif identity(3) == 3: + return identity(4) + else: + return identity(5) + + def after(): + _v0 = identity(1) + _v1 = _v0 == 1 + if _v1: + _v2 = identity(2) + return _v2 + else: + _v3 = identity(3) + _v4 = _v3 == 3 + if _v4: + _v5 = identity(4) + return _v5 + else: + _v6 = identity(5) + return _v6 + + self.assert_desugared(before, after) + + def test_named_expr(self): + def before(): + if (n := identity(1)) == 1: + return n + + def after(): + _v0 = identity(1) + n = _v0 + _v1 = n + _v2 = _v1 == 1 + if _v2: + return n + + self.assert_desugared(before, after) + + def test_while_noops(self): + def fn(): + while True: + break + while False: + continue + else: + return + + self.assert_desugar_is_noop(fn) + + def test_while(self): + def before(): + while identity(1) == identity(2): + return identity(3) + else: + return identity(4) + + def after(): + _v0 = 
identity(1) + _v1 = identity(2) + _v2 = _v0 == _v1 + while _v2: + _v3 = identity(3) + return _v3 + else: + _v4 = identity(4) + return _v4 + + self.assert_desugared(before, after) + + def test_for_noops(self): + def fn(x=[]): + for i in x: + pass + for i in x: + break + else: + return + + self.assert_desugar_is_noop(fn) + + def test_for(self): + def before(): + for i, x in enumerate(identity(1)): + return identity(2) + else: + return identity(3) + + def after(): + _v0 = identity(1) + _v1 = enumerate(_v0) + for i, x in _v1: + _v2 = identity(2) + return _v2 + else: + _v3 = identity(3) + return _v3 + + self.assert_desugared(before, after) + + def test_async_for_noops(self): + async def fn(x=[]): + async for i in x: + pass + async for i in x: + break + else: + return + + self.assert_desugar_is_noop(fn) + + def test_async_for(self): + async def before(): + async for i, x in identity(1): + return identity(2) + else: + return identity(3) + + async def after(): + _v0 = identity(1) + async for i, x in _v0: + _v1 = identity(2) + return _v1 + else: + _v2 = identity(3) + return _v2 + + self.assert_desugared(before, after) + + def test_try(self): + def before(): + try: + return identity(1) + except RuntimeError as e: + return identity(2) + else: + return identity(3) + finally: + return identity(4) + + def after(): + try: + _v0 = identity(1) + return _v0 + except RuntimeError as e: + _v1 = identity(2) + return _v1 + else: + _v2 = identity(3) + return _v2 + finally: + _v3 = identity(4) + return _v3 + + self.assert_desugared(before, after) + + def test_try_type_expr(self): + def before(): + try: + pass + except RuntimeError as a: + pass + except identity(1) as b: + pass + + def after(): + try: + pass + except RuntimeError as a: + pass + except identity(1) as b: # FIXME: desugar the type expr + pass + + self.assert_desugared(before, after) + + def test_match(self): + def before(): + match identity(1): + case ast.Expr(): + return identity(3) + case ast.Call() if identity(2): + pass + + def after(): + _v0 = identity(1) + match _v0: + case ast.Expr(): # this is a pattern, not an expression + _v1 = identity(3) + return _v1 + case ast.Call() if identity(2): # FIXME: desugar the guard + pass + + self.assert_desugared(before, after) + + def test_with(self): + def before(): + with identity(1) as x: + return identity(2) + + def after(): + _v0 = identity(1) + with _v0 as x: + _v1 = identity(2) + return _v1 + + self.assert_desugared(before, after) + + def test_nested_with(self): + def before(): + with identity(1) as x, identity(x) as y, identity(y) as z: + return identity(2) + + def after(): + _v0 = identity(1) + with _v0 as x: + _v1 = identity(x) + with _v1 as y: + _v2 = identity(y) + with _v2 as z: + _v3 = identity(2) + return _v3 + + self.assert_desugared(before, after) + + def test_async_with(self): + async def before(): + async with identity(1) as x: + return identity(2) + + async def after(): + _v0 = identity(1) + async with _v0 as x: + _v1 = identity(2) + return _v1 + + self.assert_desugared(before, after) + + def test_nested_async_with(self): + async def before(): + async with identity(1) as x, identity(x) as y, identity(y) as z: + return identity(2) + + async def after(): + _v0 = identity(1) + async with _v0 as x: + _v1 = identity(x) + async with _v1 as y: + _v2 = identity(y) + async with _v2 as z: + _v3 = identity(2) + return _v3 + + self.assert_desugared(before, after) + + def test_await(self): + async def before(): + return await identity(1) + + async def after(): + _v0 = identity(1) + _v1 = await _v0 + 
return _v1 + + self.assert_desugared(before, after) + + def test_yield(self): + def before(): + yield + yield identity(1) + return (yield identity(2)) + + def after(): + yield + _v0 = identity(1) + yield _v0 + _v1 = identity(2) + _v2 = yield _v1 + return _v2 + + self.assert_desugared(before, after) + + def test_yield_from(self): + def before(): + yield from identity(1) + return (yield from identity(2)) + + def after(): + _v0 = identity(1) + yield from _v0 + _v1 = identity(2) + _v2 = yield from _v1 + return _v2 + + self.assert_desugared(before, after) + + def test_f_strings(self): + def before(): + print(f"a {identity(1)} b {identity(2)} c") + + def after(): + _v0 = identity(1) + _v1 = identity(2) + _v2 = f"a {_v0} b {_v1} c" + print(_v2) + + self.assert_desugared(before, after) + + def test_attribute(self): + def before(a): + foo = a.b + + a.b = True + + foo = identity(1).foo + identity(2).foo = True + + def after(a): + foo = a.b + + a.b = True + + _v0 = identity(1) + _v1 = _v0.foo + foo = _v1 + + _v2 = identity(2) + _v2.foo = True + + self.assert_desugared(before, after) + + def test_subscript(self): + def before(a, b): + foo = a[b] + + a[b] = True + + foo = identity(1)[identity(2)] + + identity(3)[identity(4)] = True + + def after(a, b): + _v0 = a[b] + foo = _v0 + + a[b] = True + + _v1 = identity(1) + _v2 = identity(2) + _v3 = _v1[_v2] + foo = _v3 + + _v4 = identity(3) + _v5 = identity(4) + _v4[_v5] = True + + self.assert_desugared(before, after) + + def test_slice(self): + def before(a): + foo = a[identity(1) : identity(2) : identity(3)] + + def after(a): + _v0 = identity(1) + _v1 = identity(2) + _v2 = identity(3) + _v3 = a[_v0:_v1:_v2] + foo = _v3 + + self.assert_desugared(before, after) + + def test_store_ctx(self): + def fn(a): + [foo] = a + [*foo] = a + foo = a + foo.bar = a + foo, bar = a + foo[bar] = a + + self.assert_desugar_is_noop(fn) + + def test_if_expr(self): + def before(): + foo = identity(2) if identity(1) == 1 else identity(3) + + def after(): + _v1 = identity(1) + _v2 = _v1 == 1 + if _v2: + _v3 = identity(2) + _v0 = _v3 + else: + _v4 = identity(3) + _v0 = _v4 + foo = _v0 + + self.assert_desugared(before, after) + + def test_list_comprehensions(self): + def before(y): + foo = [ + identity(z) + for x in y + if x == 1 + if x != 2 + for z in identity(x) + if identity(z) == 3 + ] + + def after(y): + _v0 = [] + for x in y: + _v1 = x == 1 + if _v1: + _v2 = x != 2 + if _v2: + _v3 = identity(x) + for z in _v3: + _v4 = identity(z) + _v5 = _v4 == 3 + if _v5: + _v6 = identity(z) + _v0.append(_v6) + foo = _v0 + + self.assert_desugared(before, after) + + def test_set_comprehension(self): + def before(y): + foo = { + identity(z) + for x in y + if x == 1 + if x != 2 + for z in identity(x) + if identity(z) == 3 + } + + def after(y): + _v0 = set() + for x in y: + _v1 = x == 1 + if _v1: + _v2 = x != 2 + if _v2: + _v3 = identity(x) + for z in _v3: + _v4 = identity(z) + _v5 = _v4 == 3 + if _v5: + _v6 = identity(z) + _v0.add(_v6) + foo = _v0 + + self.assert_desugared(before, after) + + def test_dict_comprehension(self): + self.maxDiff = 10000 + + def before(y): + foo = { + identity(z): identity(x) + for x in y + if x == 1 + if x != 2 + for z in identity(x) + if identity(z) == 3 + } + + def after(y): + _v0 = {} + for x in y: + _v1 = x == 1 + if _v1: + _v2 = x != 2 + if _v2: + _v3 = identity(x) + for z in _v3: + _v4 = identity(z) + _v5 = _v4 == 3 + if _v5: + _v6 = identity(z) + _v7 = identity(x) + _v0[_v6] = _v7 + foo = _v0 + + self.assert_desugared(before, after) + + def 
test_generator_comprehension(self): + def before(y): + foo = (identity(x) for x in y if x == 1) + + def after(y): + def _v0(): + for x in y: + _v1 = x == 1 + if _v1: + _v2 = identity(x) + yield _v2 + + _v3 = _v0() + foo = _v3 + + self.assert_desugared(before, after) + + def assert_desugar_is_noop(self, fn): + self.assert_desugared(fn, fn) + + def assert_desugared(self, before: FunctionType, after: FunctionType): + _, before_def = parse_function(before) + _, after_def = parse_function(after) + + before_def.name = "function" + after_def.name = "function" + + desugar_function(before_def) + + expect = ast.unparse(after_def) + actual = ast.unparse(before_def) + self.assertEqual(expect, actual) + + +def identity(x): + return x diff --git a/tests/dispatch/experimental/multicolor/test_generator.py b/tests/dispatch/experimental/multicolor/test_generator.py new file mode 100644 index 00000000..ba803c2a --- /dev/null +++ b/tests/dispatch/experimental/multicolor/test_generator.py @@ -0,0 +1,63 @@ +import unittest +from dispatch.experimental.multicolor.parse import parse_function +from dispatch.experimental.multicolor.generator import YieldCounter, is_generator + + +class TestYieldCounter(unittest.TestCase): + def test_empty(self): + def empty(): + pass + + self.assert_yield_count(empty, 0) + + def test_yield(self): + def yields(): + yield 1 + if True: + yield 2 + else: + yield 3 + + self.assert_yield_count(yields, 3) + + def test_yield_from(self): + def yields(): + yield from yields() + + self.assert_yield_count(yields, 1) + + def test_nested_function(self): + def not_a_generator(): + def nested(): + yield 1 + + return 0 + + self.assert_yield_count(not_a_generator, 0) + + def test_nested_async_function(self): + def not_a_generator(): + async def nested(): + yield 1 + + return 0 + + self.assert_yield_count(not_a_generator, 0) + + def test_nested_class(self): + def not_a_generator(): + class foo: + def nested(self): + yield 1 + + return 0 + + self.assert_yield_count(not_a_generator, 0) + + def assert_yield_count(self, fn, count): + _, fn_def = parse_function(fn) + yield_counter = YieldCounter() + yield_counter.visit(fn_def) + self.assertEqual(yield_counter.count, count) + + self.assertEqual(is_generator(fn_def), count > 0) From 477f9e2020a8ee240836d4356246c1b92dae1938 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 31 Jan 2024 11:34:53 +1000 Subject: [PATCH 2/8] Add experimental durable package --- src/dispatch/experimental/durable/__init__.py | 3 + src/dispatch/experimental/durable/_frame.c | 320 ++++++++++++++++++ src/dispatch/experimental/durable/_frame.pyi | 15 + src/dispatch/experimental/durable/durable.py | 24 ++ .../experimental/durable/generator.py | 106 ++++++ src/dispatch/experimental/durable/registry.py | 24 ++ .../dispatch/experimental/durable/__init__.py | 0 .../experimental/durable/test_frame.py | 34 ++ .../experimental/durable/test_generator.py | 59 ++++ 9 files changed, 585 insertions(+) create mode 100644 src/dispatch/experimental/durable/__init__.py create mode 100644 src/dispatch/experimental/durable/_frame.c create mode 100644 src/dispatch/experimental/durable/_frame.pyi create mode 100644 src/dispatch/experimental/durable/durable.py create mode 100644 src/dispatch/experimental/durable/generator.py create mode 100644 src/dispatch/experimental/durable/registry.py create mode 100644 tests/dispatch/experimental/durable/__init__.py create mode 100644 tests/dispatch/experimental/durable/test_frame.py create mode 100644 tests/dispatch/experimental/durable/test_generator.py diff 
--git a/src/dispatch/experimental/durable/__init__.py b/src/dispatch/experimental/durable/__init__.py new file mode 100644 index 00000000..919c2ff4 --- /dev/null +++ b/src/dispatch/experimental/durable/__init__.py @@ -0,0 +1,3 @@ +from .durable import durable + +__all__ = ["durable"] diff --git a/src/dispatch/experimental/durable/_frame.c b/src/dispatch/experimental/durable/_frame.c new file mode 100644 index 00000000..22c47cac --- /dev/null +++ b/src/dispatch/experimental/durable/_frame.c @@ -0,0 +1,320 @@ +#include +#include +#include + +#define PY_SSIZE_T_CLEAN +#include + +#if PY_MAJOR_VERSION != 3 || (PY_MINOR_VERSION < 11 || PY_MINOR_VERSION > 13) +# error Python 3.11-3.13 is required +#endif + +// This is a redefinition of the private/opaque struct _PyInterpreterFrame: +// https://github.com/python/cpython/blob/3.12/Include/cpython/pyframe.h#L23 +// https://github.com/python/cpython/blob/3.12/Include/internal/pycore_frame.h#L51 +typedef struct InterpreterFrame { +#if PY_MINOR_VERSION == 11 + PyFunctionObject *f_func; +#elif PY_MINOR_VERSION >= 12 + PyCodeObject *f_code; // 3.13: PyObject *f_executable + struct _PyInterpreterFrame *previous; + PyObject *f_funcobj; +#endif + PyObject *f_globals; + PyObject *f_builtins; + PyObject *f_locals; +#if PY_MINOR_VERSION == 11 + PyCodeObject *f_code; + PyFrameObject *frame_obj; + struct _PyInterpreterFrame *previous; + _Py_CODEUNIT *prev_instr; + int stacktop; + bool is_entry; +#elif PY_MINOR_VERSION >= 12 + PyFrameObject *frame_obj; + _Py_CODEUNIT *prev_instr; // 3.13: _Py_CODEUNIT *instr_ptr + int stacktop; + uint16_t return_offset; +#endif + char owner; + PyObject *localsplus[1]; +} InterpreterFrame; + +// This is a redefinition of the private/opaque PyFrameObject: +// https://github.com/python/cpython/blob/3.12/Include/pytypedefs.h#L22 +// https://github.com/python/cpython/blob/3.12/Include/internal/pycore_frame.h#L16 +// The definition is the same for Python 3.11-3.13. +typedef struct FrameObject { + PyObject_HEAD + PyFrameObject *f_back; + struct _PyInterpreterFrame *f_frame; + PyObject *f_trace; + int f_lineno; + char f_trace_lines; + char f_trace_opcodes; + char f_fast_as_locals; + PyObject *_f_frame_data[1]; +} FrameObject; + +// This is a redefinition of frame state constants: +// https://github.com/python/cpython/blob/3.12/Include/internal/pycore_frame.h#L34 +// The definition is the same for Python 3.11 and 3.12. +// XXX: note that these constants change in 3.13! 
+typedef enum _framestate { +#if PY_MINOR_VERSION == 13 + FRAME_CREATED = -3, + FRAME_SUSPENDED = -2, + FRAME_SUSPENDED_YIELD_FROM = -1, +#else + FRAME_CREATED = -2, + FRAME_SUSPENDED = -1, +#endif + FRAME_EXECUTING = 0, + FRAME_COMPLETED = 1, + FRAME_CLEARED = 4 +} FrameState; + +// For reference, PyGenObject is defined as follows after expanding top-most macro: +// https://github.com/python/cpython/blob/3.12/Include/cpython/genobject.h +/* +typedef struct { + PyObject_HEAD +#if PY_MINOR_VERSION == 11 + PyCodeObject *gi_code; +#endif + PyObject *gi_weakreflist; + PyObject *gi_name; + PyObject *gi_qualname; + _PyErr_StackItem gi_exc_state; + PyObject *gi_origin_or_finalizer; + char gi_hooks_inited; + char gi_closed; + char gi_running_async; + int8_t gi_frame_state; + PyObject *gi_iframe[1]; +} PyGenObject; +*/ + +static InterpreterFrame *get_interpreter_frame(PyObject *obj) { + struct _PyInterpreterFrame *frame = NULL; + if (PyGen_Check(obj)) { + PyGenObject *gen_obj = (PyGenObject *)obj; + frame = (struct _PyInterpreterFrame *)(gen_obj->gi_iframe); + } else if (PyFrame_Check(obj)) { + PyFrameObject *frame_obj = (PyFrameObject *)obj; + frame = ((FrameObject *)frame_obj)->f_frame; + } else { + PyErr_SetString(PyExc_TypeError, "Object is not a generator or frame"); + return NULL; + } + assert(frame); + return (InterpreterFrame *)frame; +} + +static InterpreterFrame *get_interpreter_frame_from_args(PyObject *args) { + PyObject *obj; + if (!PyArg_ParseTuple(args, "O", &obj)) { + return NULL; + } + return get_interpreter_frame(obj); +} + +static PyGenObject *get_generator_from_args(PyObject *args) { + PyObject *gen_arg; + if (!PyArg_ParseTuple(args, "O", &gen_arg)) { + return NULL; + } + if (!PyGen_Check(gen_arg)) { + PyErr_SetString(PyExc_TypeError, "Input object is not a generator"); + return NULL; + } + return (PyGenObject *)gen_arg; +} + +static PyObject *get_generator_frame_state(PyObject *self, PyObject *args) { + PyGenObject *gen = get_generator_from_args(args); + if (!gen) { + return NULL; + } + return PyLong_FromLong((long)gen->gi_frame_state); +} + +static PyObject *get_frame_ip(PyObject *self, PyObject *args) { + // Note that this method is redundant. You can access the instruction pointer via g.gi_frame.f_lasti. + InterpreterFrame *frame = get_interpreter_frame_from_args(args); + if (!frame) { + return NULL; + } + assert(frame->f_code); + assert(frame->prev_instr); + // See _PyInterpreterFrame_LASTI + // https://github.com/python/cpython/blob/3.12/Include/internal/pycore_frame.h#L77 + intptr_t ip = (intptr_t)frame->prev_instr - (intptr_t)_PyCode_CODE(frame->f_code); + return PyLong_FromLong((long)ip); +} + +static PyObject *get_frame_sp(PyObject *self, PyObject *args) { + InterpreterFrame *frame = get_interpreter_frame_from_args(args); + if (!frame) { + return NULL; + } + assert(frame->stacktop >= 0); + int sp = frame->stacktop; + return PyLong_FromLong((long)sp); +} + +static PyObject *get_frame_stack_at(PyObject *self, PyObject *args) { + PyObject *frame_obj; + int index; + if (!PyArg_ParseTuple(args, "Oi", &frame_obj, &index)) { + return NULL; + } + InterpreterFrame *frame = get_interpreter_frame(frame_obj); + if (!frame) { + return NULL; + } + assert(frame->stacktop >= 0); + + int limit = frame->f_code->co_stacksize + frame->f_code->co_nlocalsplus; + if (index < 0 || index >= limit) { + PyErr_SetString(PyExc_IndexError, "Index out of bounds"); + return NULL; + } + + // NULL in C != None in Python. 
We need to preserve the fact that some items + // on the stack are NULL (not yet available). + PyObject *is_null = Py_False; + PyObject *obj = frame->localsplus[index]; + if (!obj) { + is_null = Py_True; + obj = Py_None; + } + return PyTuple_Pack(2, is_null, obj); +} + +static PyObject *set_frame_ip(PyObject *self, PyObject *args) { + PyObject *frame_obj; + int ip; + if (!PyArg_ParseTuple(args, "Oi", &frame_obj, &ip)) { + return NULL; + } + InterpreterFrame *frame = get_interpreter_frame(frame_obj); + if (!frame) { + return NULL; + } + assert(frame->f_code); + assert(frame->prev_instr); + // See _PyInterpreterFrame_LASTI + // https://github.com/python/cpython/blob/3.12/Include/internal/pycore_frame.h#L77 + frame->prev_instr = (_Py_CODEUNIT *)((intptr_t)_PyCode_CODE(frame->f_code) + (intptr_t)ip); + Py_RETURN_NONE; +} + +static PyObject *set_frame_sp(PyObject *self, PyObject *args) { + PyObject *frame_obj; + int sp; + if (!PyArg_ParseTuple(args, "Oi", &frame_obj, &sp)) { + return NULL; + } + InterpreterFrame *frame = get_interpreter_frame(frame_obj); + if (!frame) { + return NULL; + } + assert(frame->stacktop >= 0); + + int limit = frame->f_code->co_stacksize + frame->f_code->co_nlocalsplus; + if (sp < 0 || sp >= limit) { + PyErr_SetString(PyExc_IndexError, "Stack pointer out of bounds"); + return NULL; + } + + if (sp > frame->stacktop) { + for (int i = frame->stacktop; i < sp; i++) { + frame->localsplus[i] = NULL; + } + } + + frame->stacktop = sp; + Py_RETURN_NONE; +} + +static PyObject *set_generator_frame_state(PyObject *self, PyObject *args) { + PyObject *gen_arg; + int ip; + if (!PyArg_ParseTuple(args, "Oi", &gen_arg, &ip)) { + return NULL; + } + if (!PyGen_Check(gen_arg)) { + PyErr_SetString(PyExc_TypeError, "Input object is not a generator"); + return NULL; + } + PyGenObject *gen = (PyGenObject *)gen_arg; + // Disallow changing the frame state if the generator is complete + // or has been closed, with the assumption that various parts + // have now been torn down. The generator should be recreated before + // the frame state is changed. + if (gen->gi_frame_state >= FRAME_COMPLETED) { + PyErr_SetString(PyExc_RuntimeError, "Cannot set frame state if generator is complete"); + return NULL; + } + // TODO: check the value is one of the known constants? 
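+    // The state written below is expected to be one of the FrameState constants
+    // defined above (e.g. FRAME_CREATED or FRAME_SUSPENDED); it is currently
+    // stored as-is after narrowing to int8_t, without validation.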
+ gen->gi_frame_state = (int8_t)ip; + Py_RETURN_NONE; +} + +static PyObject *set_frame_stack_at(PyObject *self, PyObject *args) { + PyObject *frame_obj; + int index; + PyObject *unset; + PyObject *obj; + if (!PyArg_ParseTuple(args, "OiOO", &frame_obj, &index, &unset, &obj)) { + return NULL; + } + if (!PyBool_Check(unset)) { + PyErr_SetString(PyExc_TypeError, "Expected a boolean indicating whether to unset the stack object"); + return NULL; + } + InterpreterFrame *frame = get_interpreter_frame(frame_obj); + if (!frame) { + return NULL; + } + assert(frame->stacktop >= 0); + + int limit = frame->f_code->co_stacksize + frame->f_code->co_nlocalsplus; + if (index < 0 || index >= limit) { + PyErr_SetString(PyExc_IndexError, "Index out of bounds"); + return NULL; + } + + PyObject *prev = frame->localsplus[index]; + if (Py_IsTrue(unset)) { + frame->localsplus[index] = NULL; + } else { + Py_INCREF(obj); + frame->localsplus[index] = obj; + } + + if (index < frame->stacktop) { + Py_XDECREF(prev); + } + + Py_RETURN_NONE; +} + +static PyMethodDef methods[] = { + {"get_frame_ip", get_frame_ip, METH_VARARGS, "Get instruction pointer from a frame or generator."}, + {"set_frame_ip", set_frame_ip, METH_VARARGS, "Set instruction pointer in a frame or generator."}, + {"get_frame_sp", get_frame_sp, METH_VARARGS, "Get stack pointer from a frame or generator."}, + {"set_frame_sp", set_frame_sp, METH_VARARGS, "Set stack pointer in a frame or generator."}, + {"get_frame_stack_at", get_frame_stack_at, METH_VARARGS, "Get an object from a frame or generator's stack, as an (is_null, obj) tuple."}, + {"set_frame_stack_at", set_frame_stack_at, METH_VARARGS, "Set or unset an object on the stack of a frame or generator."}, + {"get_generator_frame_state", get_generator_frame_state, METH_VARARGS, "Get frame state from a generator."}, + {"set_generator_frame_state", set_generator_frame_state, METH_VARARGS, "Set frame state of a generator."}, + {NULL, NULL, 0, NULL} +}; + +static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, "_frame", NULL, -1, methods}; + +PyMODINIT_FUNC PyInit__frame(void) { + return PyModule_Create(&module); +} diff --git a/src/dispatch/experimental/durable/_frame.pyi b/src/dispatch/experimental/durable/_frame.pyi new file mode 100644 index 00000000..4bd01e2d --- /dev/null +++ b/src/dispatch/experimental/durable/_frame.pyi @@ -0,0 +1,15 @@ +from types import FrameType +from typing import Any, Tuple, Generator + +def get_frame_ip(frame: FrameType | Generator) -> int: ... +def set_frame_ip(frame: FrameType | Generator, ip: int): ... +def get_frame_sp(frame: FrameType | Generator) -> int: ... +def set_frame_sp(frame: FrameType | Generator, sp: int): ... +def get_frame_stack_at( + frame: FrameType | Generator, index: int +) -> Tuple[bool, Any]: ... +def set_frame_stack_at( + frame: FrameType | Generator, index: int, unset: bool, value: Any +): ... +def get_generator_frame_state(frame: FrameType | Generator) -> int: ... +def set_generator_frame_state(frame: FrameType | Generator, state: int): ... 
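+
+# A rough sketch of how these accessors compose (DurableGenerator in generator.py
+# and tests/dispatch/experimental/durable/test_frame.py exercise this sequence):
+# given a generator `g` suspended at a yield point, its execution state can be
+# copied onto a fresh instance `g2` created from the same generator function:
+#
+#   set_generator_frame_state(g2, get_generator_frame_state(g))
+#   set_frame_ip(g2, get_frame_ip(g))
+#   set_frame_sp(g2, get_frame_sp(g))
+#   for i in range(get_frame_sp(g)):
+#       is_null, obj = get_frame_stack_at(g, i)
+#       set_frame_stack_at(g2, i, is_null, obj)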
diff --git a/src/dispatch/experimental/durable/durable.py b/src/dispatch/experimental/durable/durable.py new file mode 100644 index 00000000..aaf8049f --- /dev/null +++ b/src/dispatch/experimental/durable/durable.py @@ -0,0 +1,24 @@ +from types import FunctionType, GeneratorType +from .generator import DurableGenerator +from .registry import register_function + + +def durable(fn): + """A decorator for a generator that makes it pickle-able.""" + return DurableFunction(fn) + + +class DurableFunction: + """A durable generator function that can be pickled.""" + + def __init__(self, fn: FunctionType): + self.fn = fn + self.key = register_function(fn) + + def __call__(self, *args, **kwargs): + result = self.fn(*args, **kwargs) + if isinstance(result, GeneratorType): + return DurableGenerator(result, self.key, args, kwargs) + + # TODO: support native coroutines + raise NotImplementedError diff --git a/src/dispatch/experimental/durable/generator.py b/src/dispatch/experimental/durable/generator.py new file mode 100644 index 00000000..326d07ed --- /dev/null +++ b/src/dispatch/experimental/durable/generator.py @@ -0,0 +1,106 @@ +from types import GeneratorType, TracebackType, CodeType, FrameType +from typing import Generator, TypeVar +from .registry import lookup_function +from . import _frame as ext + + +_YieldT = TypeVar("_YieldT", covariant=True) +_SendT = TypeVar("_SendT", contravariant=True) +_ReturnT = TypeVar("_ReturnT", covariant=True) + + +class DurableGenerator(Generator[_YieldT, _SendT, _ReturnT]): + """A generator that can be pickled.""" + + def __init__(self, gen: GeneratorType, key, args, kwargs): + self.generator = gen + + # Capture the information necessary to be able to create a + # new instance of the generator. + self.key = key + self.args = args + self.kwargs = kwargs + + def __iter__(self) -> Generator[_YieldT, _SendT, _ReturnT]: + return self + + def __next__(self) -> _YieldT: + return next(self.generator) + + def send(self, send: _SendT) -> _YieldT: + return self.generator.send(send) + + def throw(self, typ, val=None, tb: TracebackType | None = None) -> _YieldT: + return self.generator.throw(typ, val, tb) + + def close(self) -> None: + self.generator.close() + + @property + def gi_running(self) -> bool: + return self.generator.gi_running + + @property + def gi_suspended(self) -> bool: + return self.generator.gi_suspended + + @property + def gi_code(self) -> CodeType: + return self.generator.gi_code + + @property + def gi_frame(self) -> FrameType: + return self.generator.gi_frame + + @property + def gi_yieldfrom(self) -> GeneratorType | None: + return self.generator.gi_yieldfrom + + def __getstate__(self): + # Capture the details necessary to recreate the generator. + frame = self.generator.gi_frame + state = { + "function": { + "key": self.key, + "args": self.args, + "kwargs": self.kwargs, + }, + "generator": { + "frame_state": ext.get_generator_frame_state(self.generator), + }, + "frame": { + "ip": ext.get_frame_ip(frame), # aka. frame.f_lasti + "sp": ext.get_frame_sp(frame), + "stack": [ + ext.get_frame_stack_at(frame, i) + for i in range(ext.get_frame_sp(frame)) + ], + }, + } + return state + + def __setstate__(self, state): + function_state = state["function"] + generator_state = state["generator"] + frame_state = state["frame"] + + # Recreate the generator by looking up the constructor + # and calling it with the same args/kwargs. 
+ self.key, self.args, self.kwargs = ( + function_state["key"], + function_state["args"], + function_state["kwargs"], + ) + generator_fn = lookup_function(self.key) + self.generator = generator_fn(*self.args, **self.kwargs) + + # Restore the frame state (stack + stack pointer + instruction pointer). + frame = self.generator.gi_frame + ext.set_frame_ip(frame, frame_state["ip"]) + ext.set_frame_sp(frame, frame_state["sp"]) + for i, (is_null, obj) in enumerate(frame_state["stack"]): + ext.set_frame_stack_at(frame, i, is_null, obj) + + # Restore the generator state (the frame state field tracks whether the + # frame is newly created, or whether it was previously suspended). + ext.set_generator_frame_state(self.generator, generator_state["frame_state"]) diff --git a/src/dispatch/experimental/durable/registry.py b/src/dispatch/experimental/durable/registry.py new file mode 100644 index 00000000..a0f70206 --- /dev/null +++ b/src/dispatch/experimental/durable/registry.py @@ -0,0 +1,24 @@ +from types import FunctionType + + +_REGISTRY: dict[str, FunctionType] = {} + + +def register_function(fn: FunctionType) -> str: + # We need to be able to refer to the function in the serialized + # representation, and the key needs to be stable across interpreter + # invocations. Use the code object's fully-qualified name for now. + # If there are name clashes, the location of the function + # (co_filename + co_firstlineno) and/or a hash of the bytecode + # (co_code) could be used as well or instead. + key = fn.__code__.co_qualname + if key in _REGISTRY: + raise ValueError(f"durable function already registered with key {key}") + + _REGISTRY[key] = fn + + return key + + +def lookup_function(key: str) -> FunctionType: + return _REGISTRY[key] diff --git a/tests/dispatch/experimental/durable/__init__.py b/tests/dispatch/experimental/durable/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dispatch/experimental/durable/test_frame.py b/tests/dispatch/experimental/durable/test_frame.py new file mode 100644 index 00000000..11c37ddc --- /dev/null +++ b/tests/dispatch/experimental/durable/test_frame.py @@ -0,0 +1,34 @@ +import unittest +from dispatch.experimental.durable import _frame as ext + + +def generator(a): + yield a + a += 1 + yield a + a += 1 + yield a + + +class TestFrame(unittest.TestCase): + def test_copy(self): + # Create an instance and run it to the first yield point. + g = generator(1) + assert next(g) == 1 + + # Copy the generator. + g2 = generator(1) + ext.set_generator_frame_state(g2, ext.get_generator_frame_state(g)) + ext.set_frame_ip(g2, ext.get_frame_ip(g)) + ext.set_frame_sp(g2, ext.get_frame_sp(g)) + for i in range(ext.get_frame_sp(g)): + is_null, obj = ext.get_frame_stack_at(g, i) + ext.set_frame_stack_at(g2, i, is_null, obj) + + # The copy should start from where the previous generator was suspended. + assert next(g2) == 2 + assert next(g2) == 3 + + # Original generator is not affected. 
+ assert next(g) == 2 + assert next(g) == 3 diff --git a/tests/dispatch/experimental/durable/test_generator.py b/tests/dispatch/experimental/durable/test_generator.py new file mode 100644 index 00000000..c8d3d35f --- /dev/null +++ b/tests/dispatch/experimental/durable/test_generator.py @@ -0,0 +1,59 @@ +import unittest +import pickle +from dispatch.experimental.durable import durable + + +@durable +def durable_generator(a): + yield a + a += 1 + yield a + a += 1 + yield a + + +@durable +def nested_generators(start): + yield from durable_generator(start) + yield from durable_generator(start + 3) + + +class TestGenerator(unittest.TestCase): + def test_pickle(self): + # Create an instance and run it to the first yield point. + g = durable_generator(1) + assert next(g) == 1 + + # Copy the generator by serializing the DurableGenerator instance to bytes + # and back. + state = pickle.dumps(g) + g2 = pickle.loads(state) + + # The copy should start from where the previous generator was suspended. + assert next(g2) == 2 + assert next(g2) == 3 + + # The original generator is not affected. + assert next(g) == 2 + assert next(g) == 3 + + def test_nested(self): + expect = [1, 2, 3, 4, 5, 6] + assert list(nested_generators(1)) == expect + + # Check that the generator can be pickled at every yield point. + for i in range(len(expect)): + # Create a generator and advance to the i'th yield point. + g = nested_generators(1) + for j in range(i): + assert next(g) == expect[j] + + # Create a copy of the generator. + state = pickle.dumps(g) + g2 = pickle.loads(state) + + # Check that both the original and the copy yield the + # remaining expected values. + for j in range(i, len(expect)): + assert next(g) == expect[j] + assert next(g2) == expect[j] From 793fd706e732d79beb919a36f6f0e8a0a053b1f5 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Wed, 31 Jan 2024 09:34:05 -0500 Subject: [PATCH 3/8] Add setup.py to compile the C extension --- setup.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 setup.py diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..a341d680 --- /dev/null +++ b/setup.py @@ -0,0 +1,10 @@ +from setuptools import Extension, setup + +setup( + ext_modules=[ + Extension( + name="dispatch.experimental.durable._frame", + sources=["src/dispatch/experimental/durable/_frame.c"], + ), + ] +) From 9662ddbd34ad5d0cf85558ef15774dee3c745672 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 1 Feb 2024 09:34:47 +1000 Subject: [PATCH 4/8] Remove underscore prefix from C extension module name --- setup.py | 4 ++-- src/dispatch/experimental/durable/{_frame.c => frame.c} | 4 ++-- src/dispatch/experimental/durable/{_frame.pyi => frame.pyi} | 0 src/dispatch/experimental/durable/generator.py | 2 +- tests/dispatch/experimental/durable/test_frame.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) rename src/dispatch/experimental/durable/{_frame.c => frame.c} (98%) rename src/dispatch/experimental/durable/{_frame.pyi => frame.pyi} (100%) diff --git a/setup.py b/setup.py index a341d680..44212f1a 100644 --- a/setup.py +++ b/setup.py @@ -3,8 +3,8 @@ setup( ext_modules=[ Extension( - name="dispatch.experimental.durable._frame", - sources=["src/dispatch/experimental/durable/_frame.c"], + name="dispatch.experimental.durable.frame", + sources=["src/dispatch/experimental/durable/frame.c"], ), ] ) diff --git a/src/dispatch/experimental/durable/_frame.c b/src/dispatch/experimental/durable/frame.c similarity index 98% rename from src/dispatch/experimental/durable/_frame.c 
rename to src/dispatch/experimental/durable/frame.c index 22c47cac..58566519 100644 --- a/src/dispatch/experimental/durable/_frame.c +++ b/src/dispatch/experimental/durable/frame.c @@ -313,8 +313,8 @@ static PyMethodDef methods[] = { {NULL, NULL, 0, NULL} }; -static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, "_frame", NULL, -1, methods}; +static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, "frame", NULL, -1, methods}; -PyMODINIT_FUNC PyInit__frame(void) { +PyMODINIT_FUNC PyInit_frame(void) { return PyModule_Create(&module); } diff --git a/src/dispatch/experimental/durable/_frame.pyi b/src/dispatch/experimental/durable/frame.pyi similarity index 100% rename from src/dispatch/experimental/durable/_frame.pyi rename to src/dispatch/experimental/durable/frame.pyi diff --git a/src/dispatch/experimental/durable/generator.py b/src/dispatch/experimental/durable/generator.py index 326d07ed..eef51b6d 100644 --- a/src/dispatch/experimental/durable/generator.py +++ b/src/dispatch/experimental/durable/generator.py @@ -1,7 +1,7 @@ from types import GeneratorType, TracebackType, CodeType, FrameType from typing import Generator, TypeVar from .registry import lookup_function -from . import _frame as ext +from . import frame as ext _YieldT = TypeVar("_YieldT", covariant=True) diff --git a/tests/dispatch/experimental/durable/test_frame.py b/tests/dispatch/experimental/durable/test_frame.py index 11c37ddc..958f98ab 100644 --- a/tests/dispatch/experimental/durable/test_frame.py +++ b/tests/dispatch/experimental/durable/test_frame.py @@ -1,5 +1,5 @@ import unittest -from dispatch.experimental.durable import _frame as ext +from dispatch.experimental.durable import frame as ext def generator(a): From 0828ba7dac97fcf380453df4f6b5f685aa951f8b Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 1 Feb 2024 09:39:51 +1000 Subject: [PATCH 5/8] Avoid allocating frame object when serializing durable generator --- src/dispatch/experimental/durable/generator.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/dispatch/experimental/durable/generator.py b/src/dispatch/experimental/durable/generator.py index eef51b6d..9caf7945 100644 --- a/src/dispatch/experimental/durable/generator.py +++ b/src/dispatch/experimental/durable/generator.py @@ -58,7 +58,7 @@ def gi_yieldfrom(self) -> GeneratorType | None: def __getstate__(self): # Capture the details necessary to recreate the generator. - frame = self.generator.gi_frame + g = self.generator state = { "function": { "key": self.key, @@ -66,14 +66,13 @@ def __getstate__(self): "kwargs": self.kwargs, }, "generator": { - "frame_state": ext.get_generator_frame_state(self.generator), + "frame_state": ext.get_generator_frame_state(g), }, "frame": { - "ip": ext.get_frame_ip(frame), # aka. 
frame.f_lasti - "sp": ext.get_frame_sp(frame), + "ip": ext.get_frame_ip(g), + "sp": ext.get_frame_sp(g), "stack": [ - ext.get_frame_stack_at(frame, i) - for i in range(ext.get_frame_sp(frame)) + ext.get_frame_stack_at(g, i) for i in range(ext.get_frame_sp(g)) ], }, } From 30277d9300a2469059894192ea794746b5f6be78 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 1 Feb 2024 10:30:49 +1000 Subject: [PATCH 6/8] Support bound methods --- .../experimental/multicolor/compile.py | 40 +++++++++++++------ .../experimental/multicolor/test_compile.py | 29 +++++++++++++- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/dispatch/experimental/multicolor/compile.py b/src/dispatch/experimental/multicolor/compile.py index ea8e8369..e0244abd 100644 --- a/src/dispatch/experimental/multicolor/compile.py +++ b/src/dispatch/experimental/multicolor/compile.py @@ -1,8 +1,9 @@ import ast import inspect import os +import types from enum import Enum -from types import FunctionType, GeneratorType +from types import FunctionType, GeneratorType, MethodType from typing import cast from .desugar import desugar_function from .generator import is_generator, empty_generator @@ -15,7 +16,7 @@ def compile_function( fn: FunctionType, decorator: FunctionType | None = None, cache_key: str = "default" -) -> FunctionType: +) -> FunctionType | MethodType: """Compile a regular function into a generator that yields data passed to functions marked with the @multicolor.yields decorator. Decorated functions can be called from anywhere in the call stack, and functions @@ -87,16 +88,26 @@ class FunctionColor(Enum): def compile_internal( fn: FunctionType, decorator: FunctionType | None, cache_key: str -) -> tuple[FunctionType, FunctionColor]: +) -> tuple[FunctionType | MethodType, FunctionColor]: if hasattr(fn, "_multicolor_yield_type"): raise ValueError("cannot compile a yield point directly") + # Give the function a unique name. + fn_name = fn.__name__ + "__multicolor_" + cache_key + # Check if the function has already been compiled. - if hasattr(fn, "_multicolor_cache"): + cache_holder = fn + if isinstance(fn, MethodType): + cache_holder = fn.__self__ + if hasattr(cache_holder, "_multicolor_cache"): try: - return fn._multicolor_cache[cache_key] + compiled_fn, color = cache_holder._multicolor_cache[fn_name] except KeyError: pass + else: + if isinstance(fn, MethodType): + return MethodType(compiled_fn, fn.__self__), color + return compiled_fn, color # Parse an abstract syntax tree from the function source. try: @@ -120,6 +131,8 @@ def compile_internal( print("[MULTICOLOR] COMPILING:") print(repair_indentation(inspect.getsource(fn)).rstrip()) + fn_def.name = fn_name + # De-sugar the AST to simplify subsequent transformations. desugar_function(fn_def) @@ -143,9 +156,6 @@ def compile_internal( g = ast.Call(func=empty, args=[], keywords=[]) fn_def.body.insert(0, ast.Expr(ast.YieldFrom(value=g))) - name = fn_def.name + "__multicolor_" + cache_key - fn_def.name = name - # Patch AST nodes that were inserted without location info. ast.fix_missing_locations(root) @@ -169,21 +179,25 @@ def compile_internal( # Re-compile. code = compile(root, filename="", mode="exec") exec(code, namespace) - compiled_fn = namespace[name] + compiled_fn = namespace[fn_name] # Apply the custom decorator, if applicable. if decorator is not None: compiled_fn = decorator(compiled_fn) # Cache the compiled function. 
- if hasattr(fn, "_multicolor_cache"): + if hasattr(cache_holder, "_multicolor_cache"): cache = cast( - dict[str, tuple[FunctionType, FunctionColor]], fn._multicolor_cache + dict[str, tuple[FunctionType, FunctionColor]], + cache_holder._multicolor_cache, ) else: cache = {} - setattr(fn, "_multicolor_cache", cache) - cache[cache_key] = (compiled_fn, color) + setattr(cache_holder, "_multicolor_cache", cache) + cache[fn_name] = (compiled_fn, color) + + if isinstance(fn, MethodType): + return MethodType(compiled_fn, fn.__self__), color return compiled_fn, color diff --git a/tests/dispatch/experimental/multicolor/test_compile.py b/tests/dispatch/experimental/multicolor/test_compile.py index cbfa86c6..642bc941 100644 --- a/tests/dispatch/experimental/multicolor/test_compile.py +++ b/tests/dispatch/experimental/multicolor/test_compile.py @@ -14,7 +14,8 @@ class YieldTypes(Enum): SLEEP = 0 ADD = 1 - STAR_ARGS_KWARGS = 2 + MUL = 2 + STAR_ARGS_KWARGS = 3 @yields(type=YieldTypes.SLEEP) @@ -197,6 +198,32 @@ def fn(): ] self.assert_yields(fn, yields=yields, returns=3) + def test_class_method(self): + class Foo: + def sleep_then_fma(self, m, a, b): + sleep(100) + return self.mul(m, self.add_indirect(a, b)) + + @yields(type=YieldTypes.MUL) + def mul(self): + raise RuntimeError("implementation is provided elsewhere") + + def add_indirect(self, a, b): + return add(a, b) + + foo = Foo() + self.assert_yields( + foo.sleep_then_fma, + args=[10, 1, 2], + yields=[ + CustomYield(type=YieldTypes.SLEEP, args=[100]), + CustomYield(type=YieldTypes.ADD, args=[1, 2]), + CustomYield(type=YieldTypes.MUL, args=[10, 3]), + ], + sends=[None, 3, 30], + returns=30, + ) + def assert_yields( self, fn: FunctionType, From 396d0ee2b78f6c4ea26f7ec4737aeafa869a12c5 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 1 Feb 2024 10:41:09 +1000 Subject: [PATCH 7/8] Highlight the eager evaluation issue --- .../experimental/multicolor/test_compile.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/dispatch/experimental/multicolor/test_compile.py b/tests/dispatch/experimental/multicolor/test_compile.py index 642bc941..d867fd23 100644 --- a/tests/dispatch/experimental/multicolor/test_compile.py +++ b/tests/dispatch/experimental/multicolor/test_compile.py @@ -224,6 +224,33 @@ def add_indirect(self, a, b): returns=30, ) + def test_generator_evaluation(self): + self.skipTest( + "highlight how eager evaluation of generators can change the program" + ) + + def generator(n): + for i in range(n): + sleep(i) + yield i + + def zipper(g, n): + return list(zip(g(n), g(n))) + + # The generators are evaluated at their call site, which means + # [0, 1, 2, 0, 1, 2] is observed rather than [0, 0, 1, 1, 2, 2]. 
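+        # (Each g(n) call is run to completion at the point it appears, so its
+        # values are collected up front rather than being pulled lazily by zip.)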
+ yields = [ + CustomYield(type=YieldTypes.SLEEP, args=[0]), + CustomYield(type=YieldTypes.SLEEP, args=[1]), + CustomYield(type=YieldTypes.SLEEP, args=[2]), + CustomYield(type=YieldTypes.SLEEP, args=[0]), + CustomYield(type=YieldTypes.SLEEP, args=[1]), + CustomYield(type=YieldTypes.SLEEP, args=[2]), + ] + self.assert_yields( + zipper, args=[generator, 3], yields=yields, returns=[(0, 0), (1, 1), (2, 2)] + ) + def assert_yields( self, fn: FunctionType, From 9a68fe13ac81f46bd9bf44b611b8545d8bf9e249 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 1 Feb 2024 11:00:45 +1000 Subject: [PATCH 8/8] Document the experimental packages --- src/dispatch/experimental/durable/README.md | 6 ++ .../experimental/multicolor/README.md | 68 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 src/dispatch/experimental/durable/README.md create mode 100644 src/dispatch/experimental/multicolor/README.md diff --git a/src/dispatch/experimental/durable/README.md b/src/dispatch/experimental/durable/README.md new file mode 100644 index 00000000..cd7d8cf9 --- /dev/null +++ b/src/dispatch/experimental/durable/README.md @@ -0,0 +1,6 @@ +This package defines a `@durable` decorator that can be +applied to generator functions. The decorated generators +can be [pickled][pickle]. + + +[pickle]: https://docs.python.org/3/library/pickle.html diff --git a/src/dispatch/experimental/multicolor/README.md b/src/dispatch/experimental/multicolor/README.md new file mode 100644 index 00000000..3657fdbf --- /dev/null +++ b/src/dispatch/experimental/multicolor/README.md @@ -0,0 +1,68 @@ +This package contains a JIT compiler that "recolors" functions on the fly. + +[What color is your function?][what-color] Python has async functions (red), generator functions (green), +async generator functions (yellow) and regular functions (blue): + +```python +>>> async def red(): pass +>>> def green(): yield +>>> async def yellow(): yield +>>> def blue(): pass +``` + +You interact with these functions in different ways. For example, you `await` red and yellow async +functions, and `yield from` (or iterate over) green and yellow generator functions. + +There are rules that make mixing colors painful. For example, you cannot `await` an async red or +yellow function from a non-async blue or green function. Some colors (e.g. red, yellow) tend to +infect a codebase, requiring that you either avoid that color or [go all in][asyncio]. + +Red, green and yellow functions create `coroutine`, `generator` and `async_generator` objects, +respectively: + +``` +>>> red() + +>>> green() + +>>> yellow() + +``` + +These objects are all types of coroutines. They all share a desirable property; they can +be suspended during execution and then later resumed from the same point. There is however +a major caveat, which is that to suspend a coroutine deep within a call stack, there cannot +be regular (blue) function call on the path. Unfortunately, most Python functions in the +standard library and [package index][pypi] are blue. + +`multicolor` solves the issue by providing a `compile_function` that turns blue functions +green. As blue or green functions are called, it recursively turns them green. This turns +regular functions and generator functions into coroutines that can be suspended at any +point, even when there are functions from the standard library or other dependencies that +the user has no control over. 
+ +```python +from multicolor import compile_function + +green = compile_function(blue) # recursively turns functions into green coroutines +``` + +See the internal `compile_function` docs for more usage information. + +Caveats: +* Only functions called explicitly are recolored. Implicit function calls (e.g. via + magic methods) are not supported at this time. +* Red and yellow function support may be added in future, allowing the user to mix and + match all colors. For now, this package only works with synchronous (blue and green) + functions. +* Function calls are not currently supported in `match` case guards, in the parameter + list of a nested function or class definition, in `lambda` functions and as exception + handler type expressions. +* Nested generators are supported, but they're eagerly evaluated at their call site + which may subtly break your program. Nested `yield from` statements are not well + supported. + + +[what-color]: https://journal.stuffwithstuff.com/2015/02/01/what-color-is-your-function/ +[asyncio]: https://docs.python.org/3/library/asyncio.html +[pypi]: https://pypi.org
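+
+To make the first caveat concrete, here is a hypothetical sketch (the names
+`Clock` and `fn` below are illustrative, and `sleep` is assumed to be a
+`@yields`-decorated function as in the `compile_function` documentation):
+
+```python
+class Clock:
+    def __len__(self):
+        sleep(1)  # reached only through len(), an implicit call
+        return 0
+
+def fn():
+    sleep(1)             # explicit call: rewritten into a yield point
+    return len(Clock())  # len() invokes Clock.__len__ implicitly
+
+generator = compile_function(fn)()
+```
+
+Only the explicit `sleep(1)` in `fn` is surfaced as a yield; the call inside
+`__len__` is not rewritten, so it simply runs its regular, synchronous
+implementation instead of yielding.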