diff --git a/CHANGELOG.md b/CHANGELOG.md index 73217d546..092510fe9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - Fixed bug in possibly-undefined checker where a comprehension variable is falsely flagged as possibly undefined. - Fixed bug where `check_errors` and `check_all` opens a webpage when a nonexistent or unreadable path is passed as an argument. - Fixed the CFG implementation of pyta to resolve a bug in the possibly-undefined checker where variables were falsely flagged as possibly undefined when the code conditionally raises an exception and the variable was referenced afterwards. +- Fixed bug where the generated CFGs will highlight the except block as unreachable if the same exception it is handling was raised in the body of the tryexcept. ### New checkers diff --git a/python_ta/cfg/visitor.py b/python_ta/cfg/visitor.py index 723595e0d..d78f49cb9 100644 --- a/python_ta/cfg/visitor.py +++ b/python_ta/cfg/visitor.py @@ -235,6 +235,14 @@ def _visit_jump( ) -> None: old_curr = self._current_block for boundary, exits in reversed(self._control_boundaries): + if isinstance(node, nodes.Raise): + exc_name = _get_raise_exc(node) + + if exc_name in exits: + self._current_cfg.link(old_curr, exits[exc_name]) + old_curr.add_statement(node) + break + if type(node).__name__ in exits: self._current_cfg.link(old_curr, exits[type(node).__name__]) old_curr.add_statement(node) @@ -253,17 +261,34 @@ def visit_tryexcept(self, node: nodes.TryExcept) -> None: node.cfg_block = self._current_block - for child in node.body: - child.accept(self) - end_body = self._current_block - + # Construct the exception handlers first + # Initialize a temporary block to later merge with end_body + self._current_block = self._current_cfg.create_block() + temp = self._current_block end_block = self._current_cfg.create_block() + # Case where Raise is not handled in tryexcept + self._control_boundaries.append((node, {nodes.Raise.__name__: end_block})) + cbs_added = 1 after_body = [] - for handler in node.handlers: + # Construct blocks in reverse to give precedence to the first block in overlapping except + # branches + for handler in reversed(node.handlers): h = self._current_cfg.create_block() self._current_block = h handler.cfg_block = h + + exceptions = _extract_exceptions(handler) + # Edge case: catch-all except clause (i.e. except: ...) + if exceptions == []: + self._control_boundaries.append((node, {nodes.Raise.__name__: h})) + cbs_added += 1 + + # General case: specific except clause + for exception in exceptions: + self._control_boundaries.append((node, {f"{nodes.Raise.__name__} {exception}": h})) + cbs_added += 1 + if handler.name is not None: # The name assigned to the caught exception. handler.name.accept(self) for child in handler.body: @@ -281,6 +306,18 @@ def visit_tryexcept(self, node: nodes.TryExcept) -> None: child.accept(self) self._current_cfg.link_or_merge(self._current_block, end_block) + # Construct the try body so reset current block to this node's block + self._current_block = node.cfg_block + + for child in node.body: + child.accept(self) + end_body = self._current_block + + # Remove each control boundary that we added in this method + for _ in range(cbs_added): + self._control_boundaries.pop() + + self._current_cfg.link_or_merge(temp, end_body) self._current_cfg.multiple_link_or_merge(end_body, after_body) self._current_block = end_block @@ -292,3 +329,36 @@ def visit_with(self, node: nodes.With) -> None: for child in node.body: child.accept(self) + + +def _extract_exceptions(node: nodes.ExceptHandler) -> List[str]: + """A helper method that returns a list of all the exceptions handled by this except block as a + list of strings. + """ + exceptions = node.type + exceptions_so_far = [] + # ExceptHandler.type will either be Tuple, NodeNG, or None. + if exceptions is None: + return exceptions_so_far + + # Get all the Name nodes for all exceptions this except block is handling + for exception in exceptions.nodes_of_class(nodes.Name): + exceptions_so_far.append(exception.name) + + return exceptions_so_far + + +def _get_raise_exc(node: nodes.Raise) -> str: + """A helper method that returns a string formatted for the control boundary representing the + exception that this Raise node throws. + + Preconditions: + - the raise statement is of the form 'raise' or 'raise ' + """ + exceptions = node.nodes_of_class(nodes.Name) + + # Return the formatted name of the exception or the just 'Raise' otherwise + try: + return f"{nodes.Raise.__name__} {next(exceptions).name}" + except StopIteration: + return nodes.Raise.__name__ diff --git a/tests/test_cfg/test_tryexcept.py b/tests/test_cfg/test_tryexcept.py index 84ad9fb38..f6e78e4dc 100644 --- a/tests/test_cfg/test_tryexcept.py +++ b/tests/test_cfg/test_tryexcept.py @@ -17,6 +17,10 @@ def _extract_blocks(cfg: ControlFlowGraph) -> List[List[str]]: return [[s.as_string() for s in block.statements] for block in cfg.get_blocks()] +def _extract_unreachable_blocks(cfg: ControlFlowGraph) -> List[List[str]]: + return [[s.as_string() for s in block.statements] for block in cfg.unreachable_blocks] + + def test_simple() -> None: src = """ try: @@ -63,7 +67,7 @@ def test_multiple_exceptions() -> None: else: print('else') """ - expected_blocks = [["print(True)"], ["pass"], [], ["k", "pass"], ["print('else')"]] # end block + expected_blocks = [["print(True)"], ["k", "pass"], [], ["pass"], ["print('else')"]] # end block assert _extract_blocks(build_cfg(src)) == expected_blocks @@ -138,3 +142,99 @@ def test_complex() -> None: ["pass"], ] assert _extract_blocks(build_cfg(src)) == expected_blocks + + +def test_raise_in_tryexcept() -> None: + """Test that the try-body is correctly linked to the except block when the same exception is + raised in the try-body.""" + src = """ + try: + raise NotImplementedError + except NotImplementedError: + pass + """ + expected_unreachable_blocks = [] + assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks + + +def test_overlapping_exceptions() -> None: + """Test that the try-body is linked to the correct except block when there are duplicate except + clauses.""" + src = """ + try: + raise RuntimeError + except RuntimeError: + print('oh no!') + except RuntimeError: + pass + """ + expected_unreachable_blocks = [["pass"]] + assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks + + +def test_catch_all_exception() -> None: + """Test that the try-body is correctly linked to the catch-all except block.""" + src = """ + try: + raise IOError + except NotImplementedError: + pass + except: + print('heh') + """ + expected_unreachable_blocks = [["pass"]] + assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks + + +def test_no_link_catch_all() -> None: + """Test that the try-body correctly links to its corresponding except block even when there's a + general exception clause.""" + src = """ + try: + raise NotImplementedError + except NotImplementedError: + pass + except: + print('heh') + """ + expected_unreachable_blocks = [["print('heh')"]] + assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks + + +def test_nameless_node_raise_exceptions() -> None: + """Test that the try-body correctly links to the end block when the raise statement's exception + is of type None with no general except handler block.""" + src = """ + try: + raise + except NotImplementedError: + pass + """ + expected_unreachable_blocks = [["pass"]] + assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks + + +def test_nameless_node_with_general_catch() -> None: + """Test that the try-body correctly links to the general catch all block when the raise + statement's exception has no .name attribute.""" + src = """ + try: + raise + except: + pass + """ + expected_unreachable_blocks = [] + assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks + + +def test_exception_with_args() -> None: + """Test that the try-body correctly links to its corresponding except block when an exception + instance is passed.""" + src = """ + try: + raise NotImplementedError('oh no!') + except NotImplementedError: + pass + """ + expected_unreachable_blocks = [] + assert _extract_unreachable_blocks(build_cfg(src)) == expected_unreachable_blocks