Skip to content

Commit

Permalink
cfg: Updated raise handling in try-except blocks (#886)
Browse files Browse the repository at this point in the history
  • Loading branch information
sushimon authored Apr 8, 2023
1 parent 4895648 commit 1551fc2
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Fixed bug in possibly-undefined checker where a comprehension variable is falsely flagged as possibly undefined.
- Fixed bug where `check_errors` and `check_all` opens a webpage when a nonexistent or unreadable path is passed as an argument.
- Fixed the CFG implementation of pyta to resolve a bug in the possibly-undefined checker where variables were falsely flagged as possibly undefined when the code conditionally raises an exception and the variable was referenced afterwards.
- Fixed bug where the generated CFGs will highlight the except block as unreachable if the same exception it is handling was raised in the body of the tryexcept.

### New checkers

Expand Down
80 changes: 75 additions & 5 deletions python_ta/cfg/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ def _visit_jump(
) -> None:
old_curr = self._current_block
for boundary, exits in reversed(self._control_boundaries):
if isinstance(node, nodes.Raise):
exc_name = _get_raise_exc(node)

if exc_name in exits:
self._current_cfg.link(old_curr, exits[exc_name])
old_curr.add_statement(node)
break

if type(node).__name__ in exits:
self._current_cfg.link(old_curr, exits[type(node).__name__])
old_curr.add_statement(node)
Expand All @@ -253,17 +261,34 @@ def visit_tryexcept(self, node: nodes.TryExcept) -> None:

node.cfg_block = self._current_block

for child in node.body:
child.accept(self)
end_body = self._current_block

# Construct the exception handlers first
# Initialize a temporary block to later merge with end_body
self._current_block = self._current_cfg.create_block()
temp = self._current_block
end_block = self._current_cfg.create_block()
# Case where Raise is not handled in tryexcept
self._control_boundaries.append((node, {nodes.Raise.__name__: end_block}))
cbs_added = 1

after_body = []
for handler in node.handlers:
# Construct blocks in reverse to give precedence to the first block in overlapping except
# branches
for handler in reversed(node.handlers):
h = self._current_cfg.create_block()
self._current_block = h
handler.cfg_block = h

exceptions = _extract_exceptions(handler)
# Edge case: catch-all except clause (i.e. except: ...)
if exceptions == []:
self._control_boundaries.append((node, {nodes.Raise.__name__: h}))
cbs_added += 1

# General case: specific except clause
for exception in exceptions:
self._control_boundaries.append((node, {f"{nodes.Raise.__name__} {exception}": h}))
cbs_added += 1

if handler.name is not None: # The name assigned to the caught exception.
handler.name.accept(self)
for child in handler.body:
Expand All @@ -281,6 +306,18 @@ def visit_tryexcept(self, node: nodes.TryExcept) -> None:
child.accept(self)
self._current_cfg.link_or_merge(self._current_block, end_block)

# Construct the try body so reset current block to this node's block
self._current_block = node.cfg_block

for child in node.body:
child.accept(self)
end_body = self._current_block

# Remove each control boundary that we added in this method
for _ in range(cbs_added):
self._control_boundaries.pop()

self._current_cfg.link_or_merge(temp, end_body)
self._current_cfg.multiple_link_or_merge(end_body, after_body)
self._current_block = end_block

Expand All @@ -292,3 +329,36 @@ def visit_with(self, node: nodes.With) -> None:

for child in node.body:
child.accept(self)


def _extract_exceptions(node: nodes.ExceptHandler) -> List[str]:
"""A helper method that returns a list of all the exceptions handled by this except block as a
list of strings.
"""
exceptions = node.type
exceptions_so_far = []
# ExceptHandler.type will either be Tuple, NodeNG, or None.
if exceptions is None:
return exceptions_so_far

# Get all the Name nodes for all exceptions this except block is handling
for exception in exceptions.nodes_of_class(nodes.Name):
exceptions_so_far.append(exception.name)

return exceptions_so_far


def _get_raise_exc(node: nodes.Raise) -> str:
"""A helper method that returns a string formatted for the control boundary representing the
exception that this Raise node throws.
Preconditions:
- the raise statement is of the form 'raise' or 'raise <exception_class>'
"""
exceptions = node.nodes_of_class(nodes.Name)

# Return the formatted name of the exception or the just 'Raise' otherwise
try:
return f"{nodes.Raise.__name__} {next(exceptions).name}"
except StopIteration:
return nodes.Raise.__name__
102 changes: 101 additions & 1 deletion tests/test_cfg/test_tryexcept.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def _extract_blocks(cfg: ControlFlowGraph) -> List[List[str]]:
return [[s.as_string() for s in block.statements] for block in cfg.get_blocks()]


def _extract_unreachable_blocks(cfg: ControlFlowGraph) -> List[List[str]]:
return [[s.as_string() for s in block.statements] for block in cfg.unreachable_blocks]


def test_simple() -> None:
src = """
try:
Expand Down Expand Up @@ -63,7 +67,7 @@ def test_multiple_exceptions() -> None:
else:
print('else')
"""
expected_blocks = [["print(True)"], ["pass"], [], ["k", "pass"], ["print('else')"]] # end block
expected_blocks = [["print(True)"], ["k", "pass"], [], ["pass"], ["print('else')"]] # end block
assert _extract_blocks(build_cfg(src)) == expected_blocks


Expand Down Expand Up @@ -138,3 +142,99 @@ def test_complex() -> None:
["pass"],
]
assert _extract_blocks(build_cfg(src)) == expected_blocks


def test_raise_in_tryexcept() -> None:
"""Test that the try-body is correctly linked to the except block when the same exception is
raised in the try-body."""
src = """
try:
raise NotImplementedError
except NotImplementedError:
pass
"""
expected_unreachable_blocks = []
assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks


def test_overlapping_exceptions() -> None:
"""Test that the try-body is linked to the correct except block when there are duplicate except
clauses."""
src = """
try:
raise RuntimeError
except RuntimeError:
print('oh no!')
except RuntimeError:
pass
"""
expected_unreachable_blocks = [["pass"]]
assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks


def test_catch_all_exception() -> None:
"""Test that the try-body is correctly linked to the catch-all except block."""
src = """
try:
raise IOError
except NotImplementedError:
pass
except:
print('heh')
"""
expected_unreachable_blocks = [["pass"]]
assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks


def test_no_link_catch_all() -> None:
"""Test that the try-body correctly links to its corresponding except block even when there's a
general exception clause."""
src = """
try:
raise NotImplementedError
except NotImplementedError:
pass
except:
print('heh')
"""
expected_unreachable_blocks = [["print('heh')"]]
assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks


def test_nameless_node_raise_exceptions() -> None:
"""Test that the try-body correctly links to the end block when the raise statement's exception
is of type None with no general except handler block."""
src = """
try:
raise
except NotImplementedError:
pass
"""
expected_unreachable_blocks = [["pass"]]
assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks


def test_nameless_node_with_general_catch() -> None:
"""Test that the try-body correctly links to the general catch all block when the raise
statement's exception has no .name attribute."""
src = """
try:
raise
except:
pass
"""
expected_unreachable_blocks = []
assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks


def test_exception_with_args() -> None:
"""Test that the try-body correctly links to its corresponding except block when an exception
instance is passed."""
src = """
try:
raise NotImplementedError('oh no!')
except NotImplementedError:
pass
"""
expected_unreachable_blocks = []
assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks

0 comments on commit 1551fc2

Please sign in to comment.