Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implementation of an AST -> SCFG transformer #114

Merged
merged 22 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 91 additions & 39 deletions numba_rvsdg/core/datastructures/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class WritableASTBlock:

"""

name: str
instructions: list[ast.AST]
jump_targets: list[str]

def __init__(
self,
name: str,
Expand Down Expand Up @@ -70,10 +74,8 @@ def seal_inside_loop(
"""
if self.is_continue():
self.set_jump_targets(head_index)
self.instructions.pop()
elif self.is_break():
self.set_jump_targets(exit_index)
self.instructions.pop()
elif self.is_return():
pass
else:
Expand All @@ -89,13 +91,17 @@ def to_dict(self) -> dict[str, Any]:
def __repr__(self) -> str:
return (
f"WritableASTBlock({self.name}, "
"{self.instructions}, {self.jump_targets})"
f"{self.instructions}, {self.jump_targets})"
)


class ASTCFG(dict[str, WritableASTBlock]):
"""A CFG consisting of WritableASTBlocks."""

unreachable: set[WritableASTBlock]
empty: set[WritableASTBlock]
noops: set[type[ast.AST]]

def convert_blocks(self) -> MutableMapping[str, Any]:
"""Convert WritableASTBlocks to PythonASTBlocks."""
return {
Expand Down Expand Up @@ -134,6 +140,20 @@ def prune_unreachable(self) -> set[WritableASTBlock]:
self.unreachable = unreachable
return unreachable

def prune_noops(self) -> set[type[ast.AST]]:
    """Prune no-op instructions from the CFG.

    Every ``ast.Pass``, ``ast.Continue`` and ``ast.Break`` node is removed
    from the instruction list of each block in this CFG.

    Returns
    -------
    The set of removed nodes, which is also stored on ``self.noops``. The
    set holds the removed node instances themselves, hence the
    ``type: ignore`` markers against the declared element type.
    """
    noops = set()
    exclude = (ast.Pass, ast.Continue, ast.Break)
    for block in self.values():
        kept = []
        # Partition in a single pass. Collecting the no-ops only after
        # block.instructions has been reassigned would never find any.
        for instruction in block.instructions:
            if isinstance(instruction, exclude):
                noops.add(instruction)
            else:
                kept.append(instruction)
        block.instructions = kept
    self.noops = noops  # type: ignore
    return noops  # type: ignore

def prune_empty(self) -> set[WritableASTBlock]:
"""Prune empty blocks from the CFG."""
empty = set()
Expand Down Expand Up @@ -175,22 +195,46 @@ class AST2SCFGTransformer:

"""

def __init__(self, code: Callable[..., Any], prune: bool = True) -> None:
# Prune empty and unreachable blocks from the CFG.
self.prune: bool = prune
# Save the code for transformation.
self.code: Callable[..., Any] = code
# Monotonically increasing block index, 0 is reserved for genesis.
self.block_index: int = 1
# Dict mapping block indices as strings to WritableASTBlocks.
# (This is the data structure to hold the CFG.)
self.blocks: ASTCFG = ASTCFG()
# Prune noop statements and unreachable/empty blocks from the CFG.
prune: bool
# The code to be transformed.
code: str | Callable[..., Any]
tree: list[type[ast.AST]]
# Monotonically increasing block index, starts at 1.
block_index: int
# The current block being modified
current_block: WritableASTBlock
# Dict mapping block indices as strings to WritableASTBlocks.
# (This is the data structure to hold the CFG.)
blocks: ASTCFG
# Stack for header and exiting block of current loop.
loop_stack: list[LoopIndices]

def __init__(
    self, code: str | Callable[..., Any], prune: bool = True
) -> None:
    """Set up the transformer state for ``code``.

    Parameters
    ----------
    code
        Source text or a callable; it is converted to a list of AST nodes
        via ``unparse_code`` and stored on ``self.tree``.
    prune
        Stored on ``self.prune``; controls whether the pruning passes run
        during ``transform``.
    """
    self.prune = prune
    self.code = code
    self.tree = AST2SCFGTransformer.unparse_code(code)
    # The attribute types are declared once in the class body, so they are
    # not annotated a second time here.
    self.block_index = 1  # 0 is reserved for genesis block
    self.blocks = ASTCFG()
    # Initialize first (genesis) block, assume it's named zero.
    # (This also initializes the self.current_block attribute.)
    self.add_block(0)
    # Stack for header and exiting block of current loop.
    self.loop_stack = []

@staticmethod
def unparse_code(code: str | Callable[..., Any]) -> list[type[ast.AST]]:
# Convert source code into AST.
if isinstance(code, str):
tree = ast.parse(code).body
esc marked this conversation as resolved.
Show resolved Hide resolved
elif callable(code):
tree = ast.parse(textwrap.dedent(inspect.getsource(code))).body
else:
msg = "Type: '{type(self.code}}' is not implemented."
raise NotImplementedError(msg)
return tree # type: ignore

def transform_to_ASTCFG(self) -> ASTCFG:
"""Generate ASTCFG from Python function."""
self.transform()
Expand Down Expand Up @@ -220,16 +264,15 @@ def seal_block(self, default_index: int) -> None:

def transform(self) -> None:
"""Transform Python function stored as self.code."""
# Convert source code into AST.
tree = ast.parse(textwrap.dedent(inspect.getsource(self.code))).body
# Assert that the code handed in was a function, we can only transform
# functions.
assert isinstance(tree[0], ast.FunctionDef)
assert isinstance(self.tree[0], ast.FunctionDef)
# Run recursive code generation.
self.codegen(tree)
self.codegen(self.tree)
# Prune if requested.
if self.prune:
_ = self.blocks.prune_unreachable()
_ = self.blocks.prune_noops()
_ = self.blocks.prune_empty()

def codegen(self, tree: list[type[ast.AST]] | list[ast.stmt]) -> None:
Expand Down Expand Up @@ -257,6 +300,7 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None:
ast.Return,
ast.Break,
ast.Continue,
ast.Pass,
),
):
self.current_block.instructions.append(node)
Expand Down Expand Up @@ -317,27 +361,25 @@ def handle_while(self, node: ast.While) -> None:
# when the previous statement was an if-statement with an empty
# endif_block, for example. This is possible because the Python
# while-loop does not need to modify its preheader.
if self.current_block.instructions:
# Preallocate header, body and exiting indices.
head_index = self.block_index
body_index = self.block_index + 1
exit_index = self.block_index + 2
self.block_index += 3

self.current_block.set_jump_targets(head_index)
# And create new header block
self.add_block(head_index)
else: # reuse existing current_block
# Preallocate body and exiting indices.
head_index = int(self.current_block.name)
body_index = self.block_index
exit_index = self.block_index + 1
self.block_index += 2

# Preallocate header, body, else and exiting indices.
# (Technically, we could re-use the current block as header if it is
# still empty. We elect to potentially leave a block empty instead,
# since there is a pass to prune empty blocks anyway.)
head_index = self.block_index
body_index = self.block_index + 1
exit_index = self.block_index + 2
else_index = self.block_index + 3
self.block_index += 4

self.current_block.set_jump_targets(head_index)
# And create new header block
self.add_block(head_index)

# Emit comparison expression into header.
self.current_block.instructions.append(node.test)
# Set the jump targets to be the body and the exiting latch.
self.current_block.set_jump_targets(body_index, exit_index)
# Set the jump targets to be the body and the else branch.
self.current_block.set_jump_targets(body_index, else_index)

# Create body block.
self.add_block(body_index)
Expand All @@ -358,6 +400,16 @@ def handle_while(self, node: ast.While) -> None:
loop_indices.head == head_index and loop_indices.exit == exit_index
)

# Create else block.
self.add_block(else_index)

# Recurse into the body of the else-branch; again, this may modify the
# current_block.
self.codegen(node.orelse)

# Seal current_block.
self.seal_block(exit_index)

# Create exit block and leave open for modification.
self.add_block(exit_index)

Expand Down Expand Up @@ -577,8 +629,8 @@ def function(a: int) -> None
# Recurse into the body of the else-branch.
self.codegen(node.orelse)

# Set jump_target of current block, whatever it may be.
self.current_block.set_jump_targets(exit_index)
# Seal current block, whatever it may be.
self.seal_block(exit_index)

# Create exit block and leave open for modification
self.add_block(exit_index)
Expand All @@ -593,7 +645,7 @@ def render(self) -> None:
self.blocks.to_SCFG().render()


def AST2SCFG(code: Callable[..., Any]) -> SCFG:
def AST2SCFG(code: str | Callable[..., Any]) -> SCFG:
"""Transform Python function into an SCFG."""
return AST2SCFGTransformer(code).transform_to_SCFG()

Expand Down
Loading
Loading