Skip to content

Commit

Permalink
Implement basic autofix infrastructure and autofixer for TRIO100
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Mar 9, 2023
1 parent 881b16a commit 522c1db
Show file tree
Hide file tree
Showing 15 changed files with 371 additions and 34 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ repos:
rev: 0.10.9
hooks:
- id: shed
exclude: "tests/autofix_files/.*"
args: ['--py39-plus']

- repo: https://github.com/RobertCraigie/pyright-python
Expand Down Expand Up @@ -48,7 +49,7 @@ repos:
rev: 6.0.0
hooks:
- id: flake8
types: ["python", "pyi"]
types_or: ["python", "pyi"]
language_version: python3
additional_dependencies:
- flake8-builtins
Expand Down
17 changes: 13 additions & 4 deletions flake8_trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def main():
cwd=root,
).stdout.splitlines()
except (subprocess.SubprocessError, FileNotFoundError):
print("Doesn't seem to be a git repo; pass filenames to format.")
print(
"Doesn't seem to be a git repo; pass filenames to format.",
file=sys.stderr,
)
sys.exit(1)
all_filenames = [
os.path.join(root, f) for f in all_filenames if _should_format(f)
Expand All @@ -110,6 +113,9 @@ def main():
plugin = Plugin.from_filename(file)
for error in sorted(plugin.run()):
print(f"{file}:{error}")
if plugin.options.autofix:
with open(file, "w") as file:
file.write(plugin.module.code)


class Plugin:
Expand All @@ -122,7 +128,7 @@ def __init__(self, tree: ast.AST, lines: Sequence[str]):
self._tree = tree
source = "".join(lines)

self._module: cst.Module = cst_parse_module_native(source)
self.module: cst.Module = cst_parse_module_native(source)

@classmethod
def from_filename(cls, filename: str | PathLike[str]) -> Plugin: # pragma: no cover
Expand All @@ -137,12 +143,14 @@ def from_source(cls, source: str) -> Plugin:
plugin = Plugin.__new__(cls)
super(Plugin, plugin).__init__()
plugin._tree = ast.parse(source)
plugin._module = cst_parse_module_native(source)
plugin.module = cst_parse_module_native(source)
return plugin

def run(self) -> Iterable[Error]:
yield from Flake8TrioRunner.run(self._tree, self.options)
yield from Flake8TrioRunner_cst(self.options).run(self._module)
cst_runner = Flake8TrioRunner_cst(self.options, self.module)
yield from cst_runner.run()
self.module = cst_runner.module

@staticmethod
def add_options(option_manager: OptionManager | ArgumentParser):
Expand All @@ -157,6 +165,7 @@ def add_options(option_manager: OptionManager | ArgumentParser):
add_argument = functools.partial(
option_manager.add_option, parse_from_config=True
)
add_argument("--autofix", action="store_true", required=False)

add_argument(
"--no-checkpoint-warning-decorators",
Expand Down
8 changes: 4 additions & 4 deletions flake8_trio/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,20 @@ def visit(self, node: ast.AST):


class Flake8TrioRunner_cst:
def __init__(self, options: Namespace):
def __init__(self, options: Namespace, module: Module):
super().__init__()
self.state = SharedState(options)
self.options = options
self.visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple(
v(self.state) for v in ERROR_CLASSES_CST if self.selected(v.error_codes)
)
self.module = module

def run(self, module: Module) -> Iterable[Error]:
def run(self) -> Iterable[Error]:
if not self.visitors:
return
wrapper = cst.MetadataWrapper(module)
for v in self.visitors:
_ = wrapper.visit(v)
self.module = cst.MetadataWrapper(self.module).visit(v)
yield from self.state.problems

def selected(self, error_codes: dict[str, str]) -> bool:
Expand Down
61 changes: 60 additions & 1 deletion flake8_trio/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import ast
from fnmatch import fnmatch
from typing import TYPE_CHECKING, NamedTuple, TypeVar
from typing import TYPE_CHECKING, NamedTuple, TypeVar, cast

import libcst as cst
import libcst.matchers as m
Expand Down Expand Up @@ -341,3 +341,62 @@ def func_has_decorator(func: cst.FunctionDef, *names: str) -> bool:
),
)
)


def get_comments(node: cst.CSTNode | Iterable[cst.CSTNode]) -> Iterator[cst.EmptyLine]:
# pyright can't use hasattr to narrow the type, so need a bunch of casts
if hasattr(node, "__iter__"):
for n in cast("Iterable[cst.CSTNode]", node):
yield from get_comments(n)
return
yield from (
cst.EmptyLine(comment=ensure_type(c, cst.Comment))
for c in m.findall(cast("cst.CSTNode", node), m.Comment())
)
return


# used in TRIO100
def flatten_preserving_comments(node: cst.BaseCompoundStatement):
# add leading lines (comments and empty lines) for the node to be removed
new_leading_lines = list(node.leading_lines)

# add other comments belonging to the node as empty lines with comments
for attr in "lpar", "items", "rpar":
# pragma, since this is currently only used to flatten `With` statements
if comment_nodes := getattr(node, attr, None): # pragma: no cover
new_leading_lines.extend(get_comments(comment_nodes))

# node.body is a BaseSuite, whose subclasses are SimpleStatementSuite
# and IndentedBlock
if isinstance(node.body, cst.SimpleStatementSuite):
# `with ...: pass;pass;pass` -> pass;pass;pass
return cst.SimpleStatementLine(node.body.body, leading_lines=new_leading_lines)

assert isinstance(node.body, cst.IndentedBlock)
nodes = list(node.body.body)

# nodes[0] is a BaseStatement, whose subclasses are SimpleStatementLine
# and BaseCompoundStatement - both of which has leading_lines
assert isinstance(nodes[0], (cst.SimpleStatementLine, cst.BaseCompoundStatement))

# add body header comment - i.e. comments on the same/last line of the statement
if node.body.header and node.body.header.comment:
new_leading_lines.append(
cst.EmptyLine(indent=True, comment=node.body.header.comment)
)
# add the leading lines of the first node
new_leading_lines.extend(nodes[0].leading_lines)
# update the first node with all the above constructed lines
nodes[0] = nodes[0].with_changes(leading_lines=new_leading_lines)

# if there's comments in the footer of the indented block, add a pass
# statement with the comments as leading lines
if node.body.footer:
nodes.append(
cst.SimpleStatementLine(
[cst.Pass()],
node.body.footer,
)
)
return cst.FlattenSentinel(nodes)
17 changes: 13 additions & 4 deletions flake8_trio/visitors/visitor100.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
import libcst.matchers as m

from .flake8triovisitor import Flake8TrioVisitor_cst
from .helpers import AttributeCall, error_class_cst, with_has_call
from .helpers import (
AttributeCall,
error_class_cst,
flatten_preserving_comments,
with_has_call,
)


@error_class_cst
Expand Down Expand Up @@ -46,12 +51,16 @@ def visit_With(self, node: cst.With) -> None:
else:
self.has_checkpoint_stack.append(True)

def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With:
def leave_With(
self, original_node: cst.With, updated_node: cst.With
) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]:
if not self.has_checkpoint_stack.pop():
for res in self.node_dict[original_node]:
self.error(res.node, res.base, res.function)
# if: autofixing is enabled for this code
# then: remove the with and pop out it's body

if self.options.autofix and len(updated_node.items) == 1:
return flatten_preserving_comments(updated_node)

return updated_node

def visit_For(self, node: cst.For):
Expand Down
2 changes: 1 addition & 1 deletion flake8_trio/visitors/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith):
nursery = get_matching_call(item.context_expr, "open_nursery")

# `isinstance(..., ast.Call)` is done in get_matching_call
body_call = cast(ast.Call, node.body[0].value)
body_call = cast("ast.Call", node.body[0].value)

if (
nursery is not None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.pyright]
strict = ["*.py", "tests/*.py", "flake8_trio/**/*.py"]
exclude = ["**/node_modules", "**/__pycache__", "**/.*"]
exclude = ["**/node_modules", "**/__pycache__", "**/.*", "tests/autofix_files/*"]
reportUnusedCallResult=false
reportUninitializedInstanceVariable=true
reportPropertyTypeMismatch=true
Expand Down
84 changes: 84 additions & 0 deletions tests/autofix_files/trio100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# type: ignore

import trio

# error: 5, "trio", "move_on_after"
...


async def function_name():
# fmt: off
...; ...; ...
# fmt: on
# error: 15, "trio", "fail_after"
...
# error: 15, "trio", "fail_at"
...
# error: 15, "trio", "move_on_after"
...
# error: 15, "trio", "move_on_at"
...
# error: 15, "trio", "CancelScope"
...

with trio.move_on_after(10):
await trio.sleep(1)

with trio.move_on_after(10):
await trio.sleep(1)
print("hello")

with trio.move_on_after(10):
while True:
await trio.sleep(1)
print("hello")

with open("filename") as _:
...

# error: 9, "trio", "fail_after"
...

send_channel, receive_channel = trio.open_memory_channel(0)
async with trio.fail_after(10):
async with send_channel:
...

async with trio.fail_after(10):
async for _ in receive_channel:
...

# error: 15, "trio", "fail_after"
for _ in receive_channel:
...

# fix missed alarm when function is defined inside the with scope
# error: 9, "trio", "move_on_after"

async def foo():
await trio.sleep(1)

# error: 9, "trio", "move_on_after"
if ...:

async def foo():
if ...:
await trio.sleep(1)

async with random_ignored_library.fail_after(10):
...


async def function_name2():
with (
open("") as _,
trio.fail_after(10), # error: 8, "trio", "fail_after"
):
...

with (
trio.fail_after(5), # error: 8, "trio", "fail_after"
open("") as _,
trio.move_on_after(5), # error: 8, "trio", "move_on_after"
):
...
59 changes: 59 additions & 0 deletions tests/autofix_files/trio100_simple_autofix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import trio

# a
# b
# error: 5, "trio", "move_on_after"
# c
# d
print(1) # e
# f
# g
print(2) # h
# i
# j
print(3) # k
# l
# m
pass
# n

# error: 5, "trio", "move_on_after"
...


# a
# b
# fmt: off
...;...;...
# fmt: on
# c
# d

# Doesn't autofix With's with multiple withitems
with (
trio.move_on_after(10), # error: 4, "trio", "move_on_after"
open("") as f,
):
...


# multiline with, despite only being one statement
# a
# b
# c
# error: 4, "trio", "move_on_after"
# d
# e
# f
# g
# h
# this comment is kept
...

# fmt: off
# a
# b
# error: 4, "trio", "move_on_after"
# c
...; ...; ...
# fmt: on
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def pytest_addoption(parser: pytest.Parser):
parser.addoption(
"--runfuzz", action="store_true", default=False, help="run fuzz tests"
)
parser.addoption(
"--generate-autofix",
action="store_true",
default=False,
help="generate autofix file content",
)
parser.addoption(
"--enable-visitor-codes-regex",
default=".*",
Expand All @@ -32,6 +38,11 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item
item.add_marker(skip_fuzz)


@pytest.fixture()
def generate_autofix(request: pytest.FixtureRequest):
return request.config.getoption("generate_autofix")


@pytest.fixture()
def enable_visitor_codes_regex(request: pytest.FixtureRequest):
return request.config.getoption("--enable-visitor-codes-regex")
Loading

0 comments on commit 522c1db

Please sign in to comment.