Skip to content

Commit

Permalink
Test and document the template helper
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Feb 3, 2024
1 parent 2f3ad50 commit 1753983
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 46 deletions.
58 changes: 28 additions & 30 deletions src/dispatch/experimental/multicolor/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,38 +270,36 @@ def _build_call_gadget(

result = rewrite_template(
"""
if hasattr(__fn__, "_multicolor_yield_type"):
_multicolor_result = yield _multicolor_custom_yield(type=__fn__._multicolor_yield_type, args=__args__, kwargs=__kwargs__)
__assign_result__
else:
_multicolor_result = None
try:
__compiled_fn__, _multicolor_color = _multicolor_compile(__fn__, _multicolor_decorator, _multicolor_cache_key)
except _multicolor_no_source_error:
_multicolor_result = __fn_call__
else:
_multicolor_generator = __compiled_fn_call__
if _multicolor_color == _multicolor_generator_color:
_multicolor_result = []
for _multicolor_yield in _multicolor_generator:
if isinstance(_multicolor_yield, _multicolor_generator_yield):
_multicolor_result.append(_multicolor_yield.value)
if hasattr(__fn__, "_multicolor_yield_type"):
_multicolor_result = yield _multicolor_custom_yield(type=__fn__._multicolor_yield_type, args=__args__, kwargs=__kwargs__)
__assign_result__
else:
_multicolor_result = None
try:
__compiled_fn__, _multicolor_color = _multicolor_compile(__fn__, _multicolor_decorator, _multicolor_cache_key)
except _multicolor_no_source_error:
_multicolor_result = __fn_call__
else:
yield _multicolor_yield
else:
_multicolor_result = yield from _multicolor_generator
finally:
__assign_result__
_multicolor_generator = __compiled_fn_call__
if _multicolor_color == _multicolor_generator_color:
_multicolor_result = []
for _multicolor_yield in _multicolor_generator:
if isinstance(_multicolor_yield, _multicolor_generator_yield):
_multicolor_result.append(_multicolor_yield.value)
else:
yield _multicolor_yield
else:
_multicolor_result = yield from _multicolor_generator
finally:
__assign_result__
""",
expressions=dict(
__fn__=fn,
__fn_call__=fn_call,
__args__=args,
__kwargs__=kwargs,
__compiled_fn__=compiled_fn,
__compiled_fn_call__=compiled_fn_call,
),
statements=dict(__assign_result__=assign_result),
__fn__=fn,
__fn_call__=fn_call,
__args__=args,
__kwargs__=kwargs,
__compiled_fn__=compiled_fn,
__compiled_fn_call__=compiled_fn_call,
__assign_result__=assign_result,
)

return result[0]
41 changes: 28 additions & 13 deletions src/dispatch/experimental/multicolor/template.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,50 @@
import ast
import textwrap


def rewrite_template(
template: str, expressions: dict[str, ast.expr], statements: dict[str, ast.stmt]
template: str, **replacements: ast.expr | ast.stmt
) -> list[ast.stmt]:
"""Create an AST by parsing a template string and then replacing
temporary variables with the specified AST nodes."""
root = ast.parse(template)
root = NameTransformer(expressions=expressions, statements=statements).visit(root)
embedded identifiers with the provided AST nodes.
Args:
template: String containing source code (one or more statements).
**replacements: Dictionary mapping identifiers to replacement nodes.
Returns:
list[ast.stmt]: List of AST statements.
"""
root = ast.parse(textwrap.dedent(template))
root = NameTransformer(**replacements).visit(root)
return root.body


class NameTransformer(ast.NodeTransformer):
"""Replace ast.Name nodes in an AST."""

def __init__(
self, expressions: dict[str, ast.expr], statements: dict[str, ast.stmt]
):
self.expressions = expressions
self.statements = statements
exprs: dict[str, ast.expr]
stmts: dict[str, ast.stmt]

def __init__(self, **replacements: ast.expr | ast.stmt):
self.exprs = {}
self.stmts = {}
for key, node in replacements.items():
if isinstance(node, ast.expr):
self.exprs[key] = node
elif isinstance(node, ast.stmt):
self.stmts[key] = node

def visit_Name(self, node):
def visit_Name(self, node: ast.Name) -> ast.expr:
try:
return self.expressions[node.id]
return self.exprs[node.id]
except KeyError:
return node

def visit_Expr(self, node):
def visit_Expr(self, node: ast.Expr) -> ast.stmt:
if not isinstance(node.value, ast.Name):
return node
try:
return self.statements[node.value.id]
return self.stmts[node.value.id]
except KeyError:
return node
20 changes: 17 additions & 3 deletions src/dispatch/experimental/multicolor/yields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@


def yields(type: Any):
"""Mark a function as a custom yield point."""
"""Returns a decorator that marks functions as a type of yield.
Args:
type: Opaque type for this yield.
"""

def decorator(fn: FunctionType) -> FunctionType:
fn._multicolor_yield_type = type # type: ignore[attr-defined]
Expand All @@ -19,7 +23,13 @@ class YieldType:

@dataclass
class CustomYield(YieldType):
"""A yield from a function marked with @yields."""
"""A yield from a function marked with @yields.
Attributes:
type: The type of yield that was specified in the @yields decorator.
args: Positional arguments to the function call.
kwargs: Keyword arguments to the function call.
"""

type: Any
args: list[Any]
Expand All @@ -28,6 +38,10 @@ class CustomYield(YieldType):

@dataclass
class GeneratorYield(YieldType):
"""A yield from a generator."""
"""A yield from a generator.
Attributes:
value: The value that was yielded from the generator.
"""

value: Any = None
31 changes: 31 additions & 0 deletions tests/dispatch/experimental/multicolor/test_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import textwrap
import unittest
import ast
from typing import cast
from dispatch.experimental.multicolor.template import rewrite_template


class TestTemplate(unittest.TestCase):
def test_rewrite_template(self):
self.assert_rewrite(
"""
a
b = c
""",
dict(
a=ast.Expr(ast.Name(id="d", ctx=ast.Load())),
c=ast.Name(id="e", ctx=ast.Load()),
),
"""
d
b = e
""",
)

def assert_rewrite(
self, template: str, replacements: dict[str, ast.expr | ast.stmt], want: str
):
result = rewrite_template(template, **replacements)
self.assertEqual(
ast.unparse(cast(ast.AST, result)), textwrap.dedent(want).strip()
)

0 comments on commit 1753983

Please sign in to comment.