-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Test and document the template helper
- Loading branch information
Showing
4 changed files
with
104 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,50 @@ | ||
import ast | ||
import textwrap | ||
|
||
|
||
def rewrite_template( | ||
template: str, expressions: dict[str, ast.expr], statements: dict[str, ast.stmt] | ||
template: str, **replacements: ast.expr | ast.stmt | ||
) -> list[ast.stmt]: | ||
"""Create an AST by parsing a template string and then replacing | ||
temporary variables with the specified AST nodes.""" | ||
root = ast.parse(template) | ||
root = NameTransformer(expressions=expressions, statements=statements).visit(root) | ||
embedded identifiers with the provided AST nodes. | ||
Args: | ||
template: String containing source code (one or more statements). | ||
**replacements: Dictionary mapping identifiers to replacement nodes. | ||
Returns: | ||
list[ast.stmt]: List of AST statements. | ||
""" | ||
root = ast.parse(textwrap.dedent(template)) | ||
root = NameTransformer(**replacements).visit(root) | ||
return root.body | ||
|
||
|
||
class NameTransformer(ast.NodeTransformer): | ||
"""Replace ast.Name nodes in an AST.""" | ||
|
||
def __init__( | ||
self, expressions: dict[str, ast.expr], statements: dict[str, ast.stmt] | ||
): | ||
self.expressions = expressions | ||
self.statements = statements | ||
exprs: dict[str, ast.expr] | ||
stmts: dict[str, ast.stmt] | ||
|
||
def __init__(self, **replacements: ast.expr | ast.stmt): | ||
self.exprs = {} | ||
self.stmts = {} | ||
for key, node in replacements.items(): | ||
if isinstance(node, ast.expr): | ||
self.exprs[key] = node | ||
elif isinstance(node, ast.stmt): | ||
self.stmts[key] = node | ||
|
||
def visit_Name(self, node): | ||
def visit_Name(self, node: ast.Name) -> ast.expr: | ||
try: | ||
return self.expressions[node.id] | ||
return self.exprs[node.id] | ||
except KeyError: | ||
return node | ||
|
||
def visit_Expr(self, node): | ||
def visit_Expr(self, node: ast.Expr) -> ast.stmt: | ||
if not isinstance(node.value, ast.Name): | ||
return node | ||
try: | ||
return self.statements[node.value.id] | ||
return self.stmts[node.value.id] | ||
except KeyError: | ||
return node |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import textwrap | ||
import unittest | ||
import ast | ||
from typing import cast | ||
from dispatch.experimental.multicolor.template import rewrite_template | ||
|
||
|
||
class TestTemplate(unittest.TestCase): | ||
def test_rewrite_template(self): | ||
self.assert_rewrite( | ||
""" | ||
a | ||
b = c | ||
""", | ||
dict( | ||
a=ast.Expr(ast.Name(id="d", ctx=ast.Load())), | ||
c=ast.Name(id="e", ctx=ast.Load()), | ||
), | ||
""" | ||
d | ||
b = e | ||
""", | ||
) | ||
|
||
def assert_rewrite( | ||
self, template: str, replacements: dict[str, ast.expr | ast.stmt], want: str | ||
): | ||
result = rewrite_template(template, **replacements) | ||
self.assertEqual( | ||
ast.unparse(cast(ast.AST, result)), textwrap.dedent(want).strip() | ||
) |