diff --git a/python_ta/cfg/visitor.py b/python_ta/cfg/visitor.py index d78f49cb9..3eb4a37c9 100644 --- a/python_ta/cfg/visitor.py +++ b/python_ta/cfg/visitor.py @@ -160,7 +160,7 @@ def visit_while(self, node: nodes.While) -> None: ) # Handle "body" branch - body_block = self._current_cfg.create_block(test_block) + body_block = self._current_cfg.create_block(test_block, edge_label="True") self._current_block = body_block for child in node.body: child.accept(self) @@ -171,7 +171,7 @@ def visit_while(self, node: nodes.While) -> None: self._control_boundaries.pop() # Handle "else" branch - else_block = self._current_cfg.create_block(test_block) + else_block = self._current_cfg.create_block(test_block, edge_label="False") self._current_block = else_block for child in node.orelse: child.accept(self) diff --git a/sample_usage/draw_cfg.py b/sample_usage/draw_cfg.py index bc04a0fb5..49b589216 100644 --- a/sample_usage/draw_cfg.py +++ b/sample_usage/draw_cfg.py @@ -8,7 +8,8 @@ from python_ta.cfg import CFGBlock, CFGVisitor, ControlFlowGraph USAGE = "USAGE: python -m sample_usage.draw_cfg " -GRAPH_OPTIONS = {"format": "jpg", "node_attr": {"shape": "box", "fontname": "Courier New"}} +GRAPH_OPTIONS = {"format": "svg", "node_attr": {"shape": "box", "fontname": "Courier New"}} +SUBGRAPH_OPTIONS = {"fontname": "Courier New"} def display(cfgs: Dict[nodes.NodeNG, ControlFlowGraph], filename: str, view: bool = True) -> None: @@ -22,15 +23,15 @@ def display(cfgs: Dict[nodes.NodeNG, ControlFlowGraph], filename: str, view: boo continue with graph.subgraph(name=f"cluster_{id(node)}") as c: visited = set() - _visit(cfg.start, c, visited) + _visit(cfg.start, c, visited, cfg.end) for block in cfg.unreachable_blocks: - _visit(block, c, visited) - c.attr(label=subgraph_label) + _visit(block, c, visited, cfg.end) + c.attr(label=subgraph_label, **SUBGRAPH_OPTIONS) graph.render(filename, view=view) -def _visit(block: CFGBlock, graph: graphviz.Digraph, visited: Set[int]) -> None: +def _visit(block: CFGBlock, graph: graphviz.Digraph, visited: Set[int], end: CFGBlock) -> None: node_id = f"{graph.name}_{block.id}" if node_id in visited: return @@ -42,6 +43,8 @@ def _visit(block: CFGBlock, graph: graphviz.Digraph, visited: Set[int]) -> None: label = label.replace("\n", "\\l") fill_color = "grey93" if not block.reachable else "white" + # Change the fill colour if block is the end of the cfg + fill_color = "black" if block == end else fill_color graph.node(node_id, label=label, fillcolor=fill_color, style="filled") visited.add(node_id) @@ -51,7 +54,7 @@ def _visit(block: CFGBlock, graph: graphviz.Digraph, visited: Set[int]) -> None: graph.edge(node_id, f"{graph.name}_{edge.target.id}", str(edge.label)) else: graph.edge(node_id, f"{graph.name}_{edge.target.id}") - _visit(edge.target, graph, visited) + _visit(edge.target, graph, visited, end) def main(filepath: str) -> None: diff --git a/tests/test_cfg/test_label_while.py b/tests/test_cfg/test_label_while.py new file mode 100644 index 000000000..86cace8d0 --- /dev/null +++ b/tests/test_cfg/test_label_while.py @@ -0,0 +1,161 @@ +from typing import Set + +import astroid + +from python_ta.cfg import CFGVisitor, ControlFlowGraph + + +def build_cfg(src: str) -> ControlFlowGraph: + """Build a CFG for testing.""" + mod = astroid.parse(src) + t = CFGVisitor() + mod.accept(t) + + return t.cfgs[mod] + + +def _extract_labels(cfg: ControlFlowGraph) -> Set[str]: + """Return a set of all the labels in this cfg.""" + labels = {edge.label for edge in cfg.get_edges() if edge.label is not None} + return labels + + +def _extract_num_labels(cfg: ControlFlowGraph) -> int: + """Return the number of labelled edges in the cfg.""" + return sum(1 for edge in cfg.get_edges() if edge.label is not None) + + +def test_num_while_labels() -> None: + """Test that the expected number of labels is produced in a while loop.""" + src = """ + i = 0 + while i < 10: + i += 1 + + print('not else') + """ + expected_num_labels = 2 + assert _extract_num_labels(build_cfg(src)) == expected_num_labels + + +def test_type_while_labels() -> None: + """Test that the content of the labels produced in a while loop is correct.""" + src = """ + i = 0 + while i < 10: + i += 1 + + print('not else') + """ + expected_labels = {"True", "False"} + assert _extract_labels(build_cfg(src)) == expected_labels + + +def test_num_while_else_labels() -> None: + """Test that the expected number of labels is produced in a while-else loop.""" + src = """ + i = 0 + while i < 10: + i += 1 + else: + print('is else') + + print('not else') + """ + expected_num_labels = 2 + assert _extract_num_labels(build_cfg(src)) == expected_num_labels + + +def test_type_while_else_labels() -> None: + """Test that the content of the labels produced in a while-else loop is correct.""" + src = """ + i = 0 + while i < 10: + i += 1 + else: + print('is else') + + print('not else') + """ + expected_labels = {"True", "False"} + assert _extract_labels(build_cfg(src)) == expected_labels + + +def test_num_complex_while_labels() -> None: + """Test that the number of labels in a complex while loop is correct.""" + src = """ + i = 0 + while i < 10: + j = 0 + while j < 5: + j += 1 + i += 1 + + if i > 4: + print('hi') + + print('not else') + """ + expected_num_labels = 6 + assert _extract_num_labels(build_cfg(src)) == expected_num_labels + + +def test_type_complex_while_labels() -> None: + """Test that the content of the labels produced in a complex while loop is correct.""" + src = """ + i = 0 + while i < 10: + j = 0 + while j < 5: + j += 1 + i += 1 + + if i > 4: + print('hi') + + print('not else') + """ + expected_labels = {"True", "False"} + assert _extract_labels(build_cfg(src)) == expected_labels + + +def test_num_complex_while_else_labels() -> None: + """Test that the number of labels in a complex while-else loop is correct.""" + src = """ + i = 0 + while i < 10: + j = 0 + while j < 5: + j += 1 + i += 1 + + if i > 4: + print('hi') + else: + print('is else') + + print('not else') + """ + expected_num_labels = 6 + assert _extract_num_labels(build_cfg(src)) == expected_num_labels + + +def test_type_complex_while_else_labels() -> None: + """Test that the content of the labels produced in a complex while-else loop is correct.""" + src = """ + i = 0 + while i < 10: + j = 0 + while j < 5: + j += 1 + i += 1 + + if i > 4: + print('hi') + else: + print('is else') + + print('not else') + """ + expected_labels = {"True", "False"} + assert _extract_labels(build_cfg(src)) == expected_labels