Skip to content

Commit

Permalink
implement while-else and simpily while handling
Browse files Browse the repository at this point in the history
As title
  • Loading branch information
esc committed Apr 25, 2024
1 parent b7908e5 commit 85a6526
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 92 deletions.
44 changes: 26 additions & 18 deletions numba_rvsdg/core/datastructures/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,27 +347,25 @@ def handle_while(self, node: ast.While) -> None:
# when the previous statement was an if-statement with an empty
# endif_block, for example. This is possible because the Python
# while-loop does not need to modify it's preheader.
if self.current_block.instructions:
# Preallocate header, body and exiting indices.
head_index = self.block_index
body_index = self.block_index + 1
exit_index = self.block_index + 2
self.block_index += 3

self.current_block.set_jump_targets(head_index)
# And create new header block
self.add_block(head_index)
else: # reuse existing current_block
# Preallocate body and exiting indices.
head_index = int(self.current_block.name)
body_index = self.block_index
exit_index = self.block_index + 1
self.block_index += 2

# Preallocate header, body, else and exiting indices.
# (Technically, we could re-use the current block as header if it is
# still empty. We elect to potentially leave a block empty instead,
# since there is a pass to prune empty blocks anyway.)
head_index = self.block_index
body_index = self.block_index + 1
exit_index = self.block_index + 2
else_index = self.block_index + 3
self.block_index += 4

self.current_block.set_jump_targets(head_index)
# And create new header block
self.add_block(head_index)

# Emit comparison expression into header.
self.current_block.instructions.append(node.test)
# Set the jump targets to be the body and the exiting latch.
self.current_block.set_jump_targets(body_index, exit_index)
# Set the jump targets to be the body and the else branch.
self.current_block.set_jump_targets(body_index, else_index)

# Create body block.
self.add_block(body_index)
Expand All @@ -388,6 +386,16 @@ def handle_while(self, node: ast.While) -> None:
loop_indices.head == head_index and loop_indices.exit == exit_index
)

# Create else block.
self.add_block(else_index)

# Recurs into the body of the else-branch, again this may modify the
# current_block.
self.codegen(node.orelse)

# Seal current_block.
self.seal_block(exit_index)

# Create exit block and leave open for modifictaion.
self.add_block(exit_index)

Expand Down
189 changes: 115 additions & 74 deletions numba_rvsdg/tests/test_ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def function(x: int, a: int, b: int) -> int:
}
self.compare(function, expected, empty={"9", "6"})

def test_simple_loop(self):
def test_simple_while(self):
def function() -> int:
x = 0
while x < 10:
Expand Down Expand Up @@ -353,9 +353,9 @@ def function() -> int:
"name": "3",
},
}
self.compare(function, expected)
self.compare(function, expected, empty={"4"})

def test_nested_loop(self):
def test_nested_while(self):
def function() -> tuple[int, int]:
x, y = 0, 0
while x < 10:
Expand All @@ -373,33 +373,34 @@ def function() -> tuple[int, int]:
},
"1": {
"instructions": ["x < 10"],
"jump_targets": ["2", "3"],
"jump_targets": ["5", "3"],
"name": "1",
},
"2": {
"instructions": ["y < 5"],
"jump_targets": ["4", "5"],
"name": "2",
},
"3": {
"instructions": ["return (x, y)"],
"jump_targets": [],
"name": "3",
},
"4": {
"5": {
"instructions": ["y < 5"],
"jump_targets": ["6", "7"],
"name": "5",
},
"6": {
"instructions": ["x += 1", "y += 1"],
"jump_targets": ["2"],
"name": "4",
"jump_targets": ["5"],
"name": "6",
},
"5": {
"7": {
"instructions": ["x += 1"],
"jump_targets": ["1"],
"name": "5",
"name": "7",
},
}
self.compare(function, expected)

def test_if_in_loop(self):
self.compare(function, expected, empty={"2", "4", "8"})

def test_if_in_while(self):
def function() -> int:
x = 0
while x < 10:
Expand All @@ -422,28 +423,28 @@ def function() -> int:
},
"2": {
"instructions": ["x < 5"],
"jump_targets": ["4", "5"],
"jump_targets": ["5", "6"],
"name": "2",
},
"3": {
"instructions": ["return x"],
"jump_targets": [],
"name": "3",
},
"4": {
"5": {
"instructions": ["x += 2"],
"jump_targets": ["1"],
"name": "4",
"name": "5",
},
"5": {
"6": {
"instructions": ["x += 1"],
"jump_targets": ["1"],
"name": "5",
"name": "6",
},
}
self.compare(function, expected, empty={"6"})
self.compare(function, expected, empty={"4", "7"})

def test_loop_in_if(self):
def test_while_in_if(self):
def function(a: bool) -> int:
x = 0
if a is True:
Expand All @@ -457,38 +458,40 @@ def function(a: bool) -> int:
expected = {
"0": {
"instructions": ["x = 0", "a is True"],
"jump_targets": ["1", "2"],
"jump_targets": ["4", "8"],
"name": "0",
},
"1": {
"instructions": ["x < 10"],
"jump_targets": ["4", "3"],
"name": "1",
},
"2": {
"instructions": ["x < 10"],
"jump_targets": ["6", "3"],
"name": "2",
},
"3": {
"instructions": ["return x"],
"jump_targets": [],
"name": "3",
},
"4": {
"instructions": ["x += 2"],
"jump_targets": ["1"],
"instructions": ["x < 10"],
"jump_targets": ["5", "3"],
"name": "4",
},
"6": {
"5": {
"instructions": ["x += 2"],
"jump_targets": ["4"],
"name": "5",
},
"8": {
"instructions": ["x < 10"],
"jump_targets": ["9", "3"],
"name": "8",
},
"9": {
"instructions": ["x += 1"],
"jump_targets": ["2"],
"name": "6",
"jump_targets": ["8"],
"name": "9",
},
}
self.compare(function, expected, empty={"5", "7"})
self.compare(
function, expected, empty={"1", "2", "6", "7", "10", "11"}
)

def test_loop_break_continue(self):
def test_while_break_continue(self):
def function() -> int:
x = 0
while x < 10:
Expand All @@ -514,26 +517,64 @@ def function() -> int:
},
"2": {
"instructions": ["x += 1", "x % 2 == 0"],
"jump_targets": ["1", "5"],
"jump_targets": ["1", "6"],
"name": "2",
},
"3": {
"instructions": ["return x"],
"jump_targets": [],
"name": "3",
},
"5": {
"6": {
"instructions": ["x == 9"],
"jump_targets": ["3", "8"],
"name": "5",
"jump_targets": ["3", "9"],
"name": "6",
},
"8": {
"9": {
"instructions": ["x += 1"],
"jump_targets": ["1"],
"name": "8",
"name": "9",
},
}
self.compare(function, expected, empty={"4", "5", "7", "8", "10"})

def test_while_else(self):
def function() -> int:
x = 0
while x < 10:
x += 1
else:
x += 1
return x

expected = {
"0": {
"instructions": ["x = 0"],
"jump_targets": ["1"],
"name": "0",
},
"1": {
"instructions": ["x < 10"],
"jump_targets": ["2", "4"],
"name": "1",
},
"2": {
"instructions": ["x += 1"],
"jump_targets": ["1"],
"name": "2",
},
"3": {
"instructions": ["return x"],
"jump_targets": [],
"name": "3",
},
"4": {
"instructions": ["x += 1"],
"jump_targets": ["3"],
"name": "4",
},
}
self.compare(function, expected, empty={"4", "6", "7", "9"})
self.compare(function, expected)

def test_simple_for(self):
def function() -> int:
Expand Down Expand Up @@ -911,50 +952,50 @@ def function(a: int, b: int, c: int, d: int, e: int, f: int) -> int:
"jump_targets": ["1"],
"name": "11",
},
"12": {
"14": {
"instructions": ["i < 10"],
"jump_targets": ["14", "1"],
"name": "12",
"jump_targets": ["15", "1"],
"name": "14",
},
"14": {
"15": {
"instructions": ["i += 1", "i == d"],
"jump_targets": ["16", "17"],
"name": "14",
"jump_targets": ["18", "19"],
"name": "15",
},
"16": {
"18": {
"instructions": ["i = 3", "return i"],
"jump_targets": [],
"name": "16",
},
"17": {
"instructions": ["i == e"],
"jump_targets": ["19", "20"],
"name": "17",
"name": "18",
},
"19": {
"instructions": ["i = 4"],
"jump_targets": ["1"],
"instructions": ["i == e"],
"jump_targets": ["21", "22"],
"name": "19",
},
"2": {
"instructions": ["i == a"],
"jump_targets": ["5", "6"],
"name": "2",
},
"20": {
"instructions": ["i == f"],
"jump_targets": ["22", "23"],
"name": "20",
"21": {
"instructions": ["i = 4"],
"jump_targets": ["1"],
"name": "21",
},
"22": {
"instructions": ["i = 5"],
"jump_targets": ["12"],
"instructions": ["i == f"],
"jump_targets": ["24", "25"],
"name": "22",
},
"23": {
"24": {
"instructions": ["i = 5"],
"jump_targets": ["14"],
"name": "24",
},
"25": {
"instructions": ["i += 1"],
"jump_targets": ["12"],
"name": "23",
"jump_targets": ["14"],
"name": "25",
},
"3": {
"instructions": ["i = __iter_last_1__"],
Expand Down Expand Up @@ -983,11 +1024,11 @@ def function(a: int, b: int, c: int, d: int, e: int, f: int) -> int:
},
"9": {
"instructions": ["i == c"],
"jump_targets": ["11", "12"],
"jump_targets": ["11", "14"],
"name": "9",
},
}
empty = {"7", "10", "13", "15", "18", "21", "24"}
empty = {"7", "10", "12", "13", "16", "17", "20", "23", "26"}
self.compare(function, expected, empty=empty)


Expand Down

0 comments on commit 85a6526

Please sign in to comment.