Skip to content

Commit

Permalink
allow source as input for the transformer
Browse files Browse the repository at this point in the history
As title
  • Loading branch information
esc committed Apr 25, 2024
1 parent 3563570 commit 7bb82cf
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions numba_rvsdg/core/datastructures/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class AST2SCFGTransformer:
# Prune noop statements and unreachable/empty blocks from the CFG.
prune: bool
# The code to be transformed.
code: Callable[..., Any]
code: str | Callable[..., Any]
# Monotonically increasing block index, starts at 1.
block_index: int
# The current block being modified
Expand All @@ -209,16 +209,32 @@ class AST2SCFGTransformer:
# Stack for header and exiting block of current loop.
loop_stack: list[LoopIndices]

def __init__(self, code: Callable[..., Any], prune: bool = True) -> None:
def __init__(
self, code: str | Callable[..., Any], prune: bool = True
) -> None:
self.prune = prune
self.code = code
self.tree = self.unparse_code(code)
self.block_index: int = 1 # 0 is reserved for genesis block
self.blocks = ASTCFG()
# Initialize first (genesis) block, assume it's named zero.
# (This also initializes the self.current_block attribute.)
self.add_block(0)
self.loop_stack: list[LoopIndices] = []

def unparse_code(self, code: str | Callable[..., Any]) -> list[ast.AST]:
# Convert source code into AST.
if isinstance(self.code, str):
tree = ast.parse(self.code).body
elif callable(self.code):
tree = ast.parse(
textwrap.dedent(inspect.getsource(self.code))
).body
else:
msg = "Type: '{type(self.cod}}' is not implemented."
raise NotImplementedError(msg)
return tree # type: ignore

def transform_to_ASTCFG(self) -> ASTCFG:
"""Generate ASTCFG from Python function."""
self.transform()
Expand Down Expand Up @@ -248,13 +264,11 @@ def seal_block(self, default_index: int) -> None:

def transform(self) -> None:
"""Transform Python function stored as self.code."""
# Convert source code into AST.
tree = ast.parse(textwrap.dedent(inspect.getsource(self.code))).body
# Assert that the code handed in was a function, we can only transform
# functions.
assert isinstance(tree[0], ast.FunctionDef)
assert isinstance(self.tree[0], ast.FunctionDef)
# Run recursive code generation.
self.codegen(tree)
self.codegen(self.tree)
# Prune if requested.
if self.prune:
_ = self.blocks.prune_unreachable()
Expand Down Expand Up @@ -631,7 +645,7 @@ def render(self) -> None:
self.blocks.to_SCFG().render()


def AST2SCFG(code: Callable[..., Any]) -> SCFG:
def AST2SCFG(code: str | Callable[..., Any]) -> SCFG:
"""Transform Python function into an SCFG."""
return AST2SCFGTransformer(code).transform_to_SCFG()

Expand Down

0 comments on commit 7bb82cf

Please sign in to comment.