Skip to content

Commit

Permalink
Add z3 option to CFG-based checkers (#1099)
Browse files Browse the repository at this point in the history
When enabled, these checkers will only consider edges that are feasible (based on the satisfiability of logical constraints determined by Z3).
  • Loading branch information
Raine-Yang-UofT authored Nov 6, 2024
1 parent ba8cfd8 commit 866e31b
Show file tree
Hide file tree
Showing 12 changed files with 486 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Added `include_frames` filter to `snapshot`
- Added `exclude_vars` filter to `snapshot`
- Added new `python_ta.debug` module with an `SnapshotTracer` context manager for generating memory models
- Added `z3` option to `inconsistent-or-missing-returns`, `redundant-assignment`, and `possibly-undefined` checkers to only check for feasible code blocks based on edge z3 constraints
- Included the name of redundant variable in `E9959 redundant-assignment` message
- Update to pylint v3.3 and and astroid v3.3. This added support for Python 3.13 and dropped support for Python 3.8.
- Added a STRICT_NUMERIC_TYPES configuration to `python_ta.contracts` allowing to enable/disable stricter type checking of numeric types
Expand Down
4 changes: 3 additions & 1 deletion python_ta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def _check(

global PYLINT_PATCHED
if not PYLINT_PATCHED:
patch_all(messages_config) # Monkeypatch pylint (override certain methods)
patch_all(
messages_config, linter.config.z3
) # Monkeypatch pylint (override certain methods)
PYLINT_PATCHED = True

# Try to check file, issue error message for invalid files.
Expand Down
36 changes: 26 additions & 10 deletions python_ta/cfg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,32 +166,44 @@ def multiple_link_or_merge(self, source: CFGBlock, targets: List[CFGBlock]) -> N
for target in targets:
self.link(source, target)

def get_blocks(self) -> Generator[CFGBlock, None, None]:
"""Generate a sequence of all blocks in this graph."""
yield from self._get_blocks(self.start, set())
def get_blocks(self, only_feasible: bool = False) -> Generator[CFGBlock, None, None]:
"""Generate a sequence of all blocks in this graph.
def _get_blocks(self, block: CFGBlock, visited: Set[int]) -> Generator[CFGBlock, None, None]:
When only_feasible is True, only generate blocks feasible from start based on edge z3 constraints.
"""
yield from self._get_blocks(self.start, set(), only_feasible)

def _get_blocks(
self, block: CFGBlock, visited: Set[int], only_feasible: bool
) -> Generator[CFGBlock, None, None]:
if block.id in visited:
return

yield block
visited.add(block.id)

for edge in block.successors:
yield from self._get_blocks(edge.target, visited)
if not only_feasible or edge.is_feasible:
yield from self._get_blocks(edge.target, visited, only_feasible)

def get_blocks_postorder(self) -> Generator[CFGBlock, None, None]:
def get_blocks_postorder(self, only_feasible: bool = False) -> Generator[CFGBlock, None, None]:
"""Return the sequence of all blocks in this graph in the order of
a post-order traversal."""
yield from self._get_blocks_postorder(self.start, set())
a post-order traversal.
When only_feasible is True, only generate blocks feasible from start based on edge z3 constraints.
"""
yield from self._get_blocks_postorder(self.start, set(), only_feasible)

def _get_blocks_postorder(self, block: CFGBlock, visited) -> Generator[CFGBlock, None, None]:
def _get_blocks_postorder(
self, block: CFGBlock, visited: Set[int], only_feasible: bool
) -> Generator[CFGBlock, None, None]:
if block.id in visited:
return

visited.add(block.id)
for succ in block.successors:
yield from self._get_blocks_postorder(succ.target, visited)
if not only_feasible or succ.is_feasible:
yield from self._get_blocks_postorder(succ.target, visited, only_feasible)

yield block

Expand Down Expand Up @@ -353,6 +365,10 @@ def jump(self) -> Optional[NodeNG]:
if len(self.statements) > 0:
return self.statements[-1]

@property
def is_feasible(self) -> bool:
return any(edge.is_feasible for edge in self.predecessors)

def is_jump(self) -> bool:
"""Returns True if the block has a statement that branches
the control flow (ex: `break`)"""
Expand Down
20 changes: 20 additions & 0 deletions python_ta/checkers/inconsistent_or_missing_returns_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ class InconsistentReturnChecker(BaseChecker):
"Used when a function does not have a return statement and whose return type is not None",
),
}
options = (
(
"z3",
{
"default": False,
"type": "yn",
"metavar": "<y or n>",
"help": "Use Z3 to restrict control flow checks to paths that are logically feasible.",
},
),
)

def __init__(self, linter: Optional[PyLinter] = None) -> None:
super().__init__(linter=linter)
Expand Down Expand Up @@ -71,6 +82,15 @@ def _check_return_statements(self, node) -> None:
if has_return_annotation or has_return_value:
for block, statement in return_statements.items():
if statement is None:
# ignore unfeasible edges for missing return if z3 option is on
if self.linter.config.z3 and (
not block.is_feasible
or not any(
edge.is_feasible for edge in block.successors if edge.target is end
)
):
continue

# For rendering purpose:
# line: the line where the error occurs, used to calculate indentation
# end_line: the line to insert the error message
Expand Down
27 changes: 24 additions & 3 deletions python_ta/checkers/possibly_undefined_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ class PossiblyUndefinedChecker(BaseChecker):
"Reported when a statement uses a variable that might not be assigned.",
)
}
options = (
(
"z3",
{
"default": False,
"type": "yn",
"metavar": "<y or n>",
"help": "Use Z3 to restrict control flow checks to paths that are logically feasible.",
},
),
)

def __init__(self, linter=None) -> None:
super().__init__(linter=linter)
Expand Down Expand Up @@ -56,7 +67,7 @@ def _analyze(self, node: Union[nodes.Module, nodes.FunctionDef]) -> None:
out_facts = {}
cfg = ControlFlowGraph()
cfg.start = node.cfg_block
blocks = list(cfg.get_blocks_postorder())
blocks = list(cfg.get_blocks_postorder(only_feasible=self.linter.config.z3))
blocks.reverse()

all_assigns = self._get_assigns(node)
Expand All @@ -66,15 +77,25 @@ def _analyze(self, node: Union[nodes.Module, nodes.FunctionDef]) -> None:
worklist = blocks
while len(worklist) != 0:
b = worklist.pop()
outs = [out_facts[p.source] for p in b.predecessors if p.source in out_facts]
outs = [
out_facts[p.source]
for p in b.predecessors
if p.source in out_facts and (not self.linter.config.z3 or p.is_feasible)
]
if outs == []:
in_facts = set()
else:
in_facts = set.intersection(*outs)
temp = self._transfer(b, in_facts, all_assigns)
if temp != out_facts[b]:
out_facts[b] = temp
worklist.extend([succ.target for succ in b.successors])
worklist.extend(
[
succ.target
for succ in b.successors
if not self.linter.config.z3 or succ.is_feasible
]
)

def _transfer(self, block: CFGBlock, in_facts: set[str], local_vars: set[str]) -> set[str]:
gen = in_facts.copy()
Expand Down
27 changes: 24 additions & 3 deletions python_ta/checkers/redundant_assignment_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ class RedundantAssignmentChecker(BaseChecker):
" You can remove the assignment(s) without changing the behaviour of this code.",
)
}
options = (
(
"z3",
{
"default": False,
"type": "yn",
"metavar": "<y or n>",
"help": "Use Z3 to restrict control flow checks to paths that are logically feasible.",
},
),
)

def __init__(self, linter=None) -> None:
super().__init__(linter=linter)
Expand Down Expand Up @@ -88,7 +99,7 @@ def _analyze(self, node: Union[nodes.Module, nodes.FunctionDef]) -> None:
out_facts = {}
cfg = ControlFlowGraph()
cfg.start = node.cfg_block
worklist = list(cfg.get_blocks_postorder())
worklist = list(cfg.get_blocks_postorder(only_feasible=self.linter.config.z3))
worklist.reverse()

all_assigns = self._get_assigns(node)
Expand All @@ -97,15 +108,25 @@ def _analyze(self, node: Union[nodes.Module, nodes.FunctionDef]) -> None:

while len(worklist) != 0:
b = worklist.pop()
outs = [out_facts[p.target] for p in b.successors if p.target in out_facts]
outs = [
out_facts[p.target]
for p in b.successors
if p.target in out_facts and (not self.linter.config.z3 or p.is_feasible)
]
if outs == []:
in_facts = set()
else:
in_facts = set.intersection(*outs)
temp = self._transfer(b, in_facts)
if b in out_facts and temp != out_facts[b]:
out_facts[b] = temp
worklist.extend([pred.source for pred in b.predecessors if pred.source.reachable])
worklist.extend(
[
pred.source
for pred in b.predecessors
if pred.source.reachable and (not self.linter.config.z3 or pred.is_feasible)
]
)

def _transfer(self, block: CFGBlock, out_facts: set[str]) -> set[str]:
gen = out_facts.copy()
Expand Down
4 changes: 2 additions & 2 deletions python_ta/patches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from .transforms import patch_ast_transforms


def patch_all(messages_config: dict):
def patch_all(messages_config: dict, z3: bool):
"""Execute all patches defined in this module."""
patch_checkers()
patch_ast_transforms()
patch_ast_transforms(z3)
patch_messages()
patch_error_messages(messages_config)
5 changes: 4 additions & 1 deletion python_ta/patches/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
from pylint.lint import PyLinter

from ..cfg.visitor import CFGVisitor
from ..transforms.z3_visitor import Z3Visitor


def patch_ast_transforms():
def patch_ast_transforms(z3: bool):
old_get_ast = PyLinter.get_ast

def new_get_ast(self, filepath, modname, data):
ast = old_get_ast(self, filepath, modname, data)
if ast is not None:
try:
if z3:
ast = Z3Visitor().visitor.visit(ast)
ast.accept(CFGVisitor())
except:
pass
Expand Down
84 changes: 84 additions & 0 deletions tests/test_custom_checkers/test_inconsistent_returns.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from python_ta.checkers.inconsistent_or_missing_returns_checker import (
InconsistentReturnChecker,
)
from python_ta.transforms.z3_visitor import Z3Visitor


class TestInconsistentReturnChecker(pylint.testutils.CheckerTestCase):
Expand Down Expand Up @@ -153,3 +154,86 @@ def func():
ignore_position=True,
):
self.checker.visit_functiondef(func_node)


class TestInconsistentReturnCheckerZ3Option(pylint.testutils.CheckerTestCase):
CHECKER_CLASS = InconsistentReturnChecker
CONFIG = {"z3": True}

def test_z3_unfeasible_inconsistent_return(self):
src = """
def func(x: int) -> int:
'''
Preconditions:
- x > 5
'''
if x < 0:
return
return x
"""
z3v = Z3Visitor()
mod = z3v.visitor.visit(astroid.parse(src))
mod.accept(CFGVisitor())
func_node = next(mod.nodes_of_class(nodes.FunctionDef))
inconsistent_return_node, _ = mod.nodes_of_class(nodes.Return)

with self.assertAddsMessages(
pylint.testutils.MessageTest(
msg_id="inconsistent-returns",
node=inconsistent_return_node,
),
ignore_position=True,
):
self.checker.visit_functiondef(func_node)

def test_z3_partially_feasible_inconsistent_return(self):
src = """
def func(x: int) -> int:
'''
Preconditions:
- x > 5
'''
if x < 0:
print(x)
return
"""
z3v = Z3Visitor()
mod = z3v.visitor.visit(astroid.parse(src))
mod.accept(CFGVisitor())
func_node = next(mod.nodes_of_class(nodes.FunctionDef))
inconsistent_return_node = next(mod.nodes_of_class(nodes.Return))

with self.assertAddsMessages(
pylint.testutils.MessageTest(
msg_id="inconsistent-returns",
node=inconsistent_return_node,
),
ignore_position=True,
):
self.checker.visit_functiondef(func_node)

def test_z3_feasible_inconsistent_return(self):
src = """
def func(x: int) -> int:
'''
Preconditions:
- x > 5
'''
if x > 0:
return
return x
"""
z3v = Z3Visitor()
mod = z3v.visitor.visit(astroid.parse(src))
mod.accept(CFGVisitor())
func_node = next(mod.nodes_of_class(nodes.FunctionDef))
inconsistent_return_node, _ = mod.nodes_of_class(nodes.Return)

with self.assertAddsMessages(
pylint.testutils.MessageTest(
msg_id="inconsistent-returns",
node=inconsistent_return_node,
),
ignore_position=True,
):
self.checker.visit_functiondef(func_node)
Loading

0 comments on commit 866e31b

Please sign in to comment.