diff --git a/src/dispatch/experimental/multicolor/compile.py b/src/dispatch/experimental/multicolor/compile.py index bbffea55..ceda725e 100644 --- a/src/dispatch/experimental/multicolor/compile.py +++ b/src/dispatch/experimental/multicolor/compile.py @@ -270,38 +270,36 @@ def _build_call_gadget( result = rewrite_template( """ -if hasattr(__fn__, "_multicolor_yield_type"): - _multicolor_result = yield _multicolor_custom_yield(type=__fn__._multicolor_yield_type, args=__args__, kwargs=__kwargs__) - __assign_result__ -else: - _multicolor_result = None - try: - __compiled_fn__, _multicolor_color = _multicolor_compile(__fn__, _multicolor_decorator, _multicolor_cache_key) - except _multicolor_no_source_error: - _multicolor_result = __fn_call__ - else: - _multicolor_generator = __compiled_fn_call__ - if _multicolor_color == _multicolor_generator_color: - _multicolor_result = [] - for _multicolor_yield in _multicolor_generator: - if isinstance(_multicolor_yield, _multicolor_generator_yield): - _multicolor_result.append(_multicolor_yield.value) + if hasattr(__fn__, "_multicolor_yield_type"): + _multicolor_result = yield _multicolor_custom_yield(type=__fn__._multicolor_yield_type, args=__args__, kwargs=__kwargs__) + __assign_result__ + else: + _multicolor_result = None + try: + __compiled_fn__, _multicolor_color = _multicolor_compile(__fn__, _multicolor_decorator, _multicolor_cache_key) + except _multicolor_no_source_error: + _multicolor_result = __fn_call__ else: - yield _multicolor_yield - else: - _multicolor_result = yield from _multicolor_generator - finally: - __assign_result__ + _multicolor_generator = __compiled_fn_call__ + if _multicolor_color == _multicolor_generator_color: + _multicolor_result = [] + for _multicolor_yield in _multicolor_generator: + if isinstance(_multicolor_yield, _multicolor_generator_yield): + _multicolor_result.append(_multicolor_yield.value) + else: + yield _multicolor_yield + else: + _multicolor_result = yield from _multicolor_generator + finally: + __assign_result__ """, - expressions=dict( - __fn__=fn, - __fn_call__=fn_call, - __args__=args, - __kwargs__=kwargs, - __compiled_fn__=compiled_fn, - __compiled_fn_call__=compiled_fn_call, - ), - statements=dict(__assign_result__=assign_result), + __fn__=fn, + __fn_call__=fn_call, + __args__=args, + __kwargs__=kwargs, + __compiled_fn__=compiled_fn, + __compiled_fn_call__=compiled_fn_call, + __assign_result__=assign_result, ) return result[0] diff --git a/src/dispatch/experimental/multicolor/template.py b/src/dispatch/experimental/multicolor/template.py index 190a792c..bc68eece 100644 --- a/src/dispatch/experimental/multicolor/template.py +++ b/src/dispatch/experimental/multicolor/template.py @@ -1,35 +1,50 @@ import ast +import textwrap def rewrite_template( - template: str, expressions: dict[str, ast.expr], statements: dict[str, ast.stmt] + template: str, **replacements: ast.expr | ast.stmt ) -> list[ast.stmt]: """Create an AST by parsing a template string and then replacing - temporary variables with the specified AST nodes.""" - root = ast.parse(template) - root = NameTransformer(expressions=expressions, statements=statements).visit(root) + embedded identifiers with the provided AST nodes. + + Args: + template: String containing source code (one or more statements). + **replacements: Dictionary mapping identifiers to replacement nodes. + + Returns: + list[ast.stmt]: List of AST statements. + """ + root = ast.parse(textwrap.dedent(template)) + root = NameTransformer(**replacements).visit(root) return root.body class NameTransformer(ast.NodeTransformer): """Replace ast.Name nodes in an AST.""" - def __init__( - self, expressions: dict[str, ast.expr], statements: dict[str, ast.stmt] - ): - self.expressions = expressions - self.statements = statements + exprs: dict[str, ast.expr] + stmts: dict[str, ast.stmt] + + def __init__(self, **replacements: ast.expr | ast.stmt): + self.exprs = {} + self.stmts = {} + for key, node in replacements.items(): + if isinstance(node, ast.expr): + self.exprs[key] = node + elif isinstance(node, ast.stmt): + self.stmts[key] = node - def visit_Name(self, node): + def visit_Name(self, node: ast.Name) -> ast.expr: try: - return self.expressions[node.id] + return self.exprs[node.id] except KeyError: return node - def visit_Expr(self, node): + def visit_Expr(self, node: ast.Expr) -> ast.stmt: if not isinstance(node.value, ast.Name): return node try: - return self.statements[node.value.id] + return self.stmts[node.value.id] except KeyError: return node diff --git a/src/dispatch/experimental/multicolor/yields.py b/src/dispatch/experimental/multicolor/yields.py index ce5fb505..73c8970f 100644 --- a/src/dispatch/experimental/multicolor/yields.py +++ b/src/dispatch/experimental/multicolor/yields.py @@ -4,7 +4,11 @@ def yields(type: Any): - """Mark a function as a custom yield point.""" + """Returns a decorator that marks functions as a type of yield. + + Args: + type: Opaque type for this yield. + """ def decorator(fn: FunctionType) -> FunctionType: fn._multicolor_yield_type = type # type: ignore[attr-defined] @@ -19,7 +23,13 @@ class YieldType: @dataclass class CustomYield(YieldType): - """A yield from a function marked with @yields.""" + """A yield from a function marked with @yields. + + Attributes: + type: The type of yield that was specified in the @yields decorator. + args: Positional arguments to the function call. + kwargs: Keyword arguments to the function call. + """ type: Any args: list[Any] @@ -28,6 +38,10 @@ class CustomYield(YieldType): @dataclass class GeneratorYield(YieldType): - """A yield from a generator.""" + """A yield from a generator. + + Attributes: + value: The value that was yielded from the generator. + """ value: Any = None diff --git a/tests/dispatch/experimental/multicolor/test_template.py b/tests/dispatch/experimental/multicolor/test_template.py new file mode 100644 index 00000000..fc390221 --- /dev/null +++ b/tests/dispatch/experimental/multicolor/test_template.py @@ -0,0 +1,31 @@ +import textwrap +import unittest +import ast +from typing import cast +from dispatch.experimental.multicolor.template import rewrite_template + + +class TestTemplate(unittest.TestCase): + def test_rewrite_template(self): + self.assert_rewrite( + """ + a + b = c + """, + dict( + a=ast.Expr(ast.Name(id="d", ctx=ast.Load())), + c=ast.Name(id="e", ctx=ast.Load()), + ), + """ + d + b = e + """, + ) + + def assert_rewrite( + self, template: str, replacements: dict[str, ast.expr | ast.stmt], want: str + ): + result = rewrite_template(template, **replacements) + self.assertEqual( + ast.unparse(cast(ast.AST, result)), textwrap.dedent(want).strip() + )