Skip to content

Commit

Permalink
add ability to convert ast directly
Browse files Browse the repository at this point in the history
As title
  • Loading branch information
esc committed Apr 29, 2024
1 parent 4b7d46b commit 41cc5e3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
18 changes: 14 additions & 4 deletions numba_rvsdg/core/datastructures/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class AST2SCFGTransformer:
# Prune noop statements and unreachable/empty blocks from the CFG.
prune: bool
# The code to be transformed.
code: str | Callable[..., Any]
code: str | list[ast.FunctionDef] | Callable[..., Any]
tree: list[type[ast.AST]]
# Monotonically increasing block index, starts at 1.
block_index: int
Expand All @@ -211,7 +211,9 @@ class AST2SCFGTransformer:
loop_stack: list[LoopIndices]

def __init__(
self, code: str | Callable[..., Any], prune: bool = True
self,
code: str | list[ast.FunctionDef] | Callable[..., Any],
prune: bool = True,
) -> None:
self.prune = prune
self.code = code
Expand All @@ -224,12 +226,20 @@ def __init__(
self.loop_stack: list[LoopIndices] = []

@staticmethod
def unparse_code(code: str | Callable[..., Any]) -> list[type[ast.AST]]:
def unparse_code(
code: str | list[ast.FunctionDef] | Callable[..., Any]
) -> list[type[ast.AST]]:
# Convert source code into AST.
if isinstance(code, str):
tree = ast.parse(code).body
elif callable(code):
tree = ast.parse(textwrap.dedent(inspect.getsource(code))).body
elif (
isinstance(code, list)
and len(code) > 0
and all([isinstance(i, ast.AST) for i in code])
):
tree = code # type: ignore
else:
msg = "Type: '{type(self.code}}' is not implemented."
raise NotImplementedError(msg)
Expand Down Expand Up @@ -645,7 +655,7 @@ def render(self) -> None:
self.blocks.to_SCFG().render()


def AST2SCFG(code: str | Callable[..., Any]) -> SCFG:
def AST2SCFG(code: str | list[ast.FunctionDef] | Callable[..., Any]) -> SCFG:
"""Transform Python function into an SCFG."""
return AST2SCFGTransformer(code).transform_to_SCFG()

Expand Down
23 changes: 21 additions & 2 deletions numba_rvsdg/tests/test_ast_transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# mypy: ignore-errors
import ast
import textwrap
from typing import Callable, Any
from unittest import main, TestCase
Expand Down Expand Up @@ -38,10 +39,28 @@ def function() -> int:
self.compare(function, expected)

def test_solo_return_from_string(self):
function = textwrap.dedent("""
function = textwrap.dedent(
"""
def function() -> int:
return 1
""")
"""
)

expected = {
"0": {
"instructions": ["return 1"],
"jump_targets": [],
"name": "0",
}
}
self.compare(function, expected)

def test_solo_return_from_AST(self):
function = ast.parse(textwrap.dedent(
"""
def function() -> int:
return 1
""")).body

expected = {
"0": {
Expand Down

0 comments on commit 41cc5e3

Please sign in to comment.