Skip to content

Commit

Permalink
Implement basic autofix infrastructure and autofixer for TRIO100
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Mar 8, 2023
1 parent 881b16a commit e7e4a95
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 31 deletions.
17 changes: 13 additions & 4 deletions flake8_trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def main():
cwd=root,
).stdout.splitlines()
except (subprocess.SubprocessError, FileNotFoundError):
print("Doesn't seem to be a git repo; pass filenames to format.")
print(
"Doesn't seem to be a git repo; pass filenames to format.",
file=sys.stderr,
)
sys.exit(1)
all_filenames = [
os.path.join(root, f) for f in all_filenames if _should_format(f)
Expand All @@ -110,6 +113,9 @@ def main():
plugin = Plugin.from_filename(file)
for error in sorted(plugin.run()):
print(f"{file}:{error}")
if plugin.options.autofix:
with open(file, "w") as file:
file.write(plugin.module.code)


class Plugin:
Expand All @@ -122,7 +128,7 @@ def __init__(self, tree: ast.AST, lines: Sequence[str]):
self._tree = tree
source = "".join(lines)

self._module: cst.Module = cst_parse_module_native(source)
self.module: cst.Module = cst_parse_module_native(source)

@classmethod
def from_filename(cls, filename: str | PathLike[str]) -> Plugin: # pragma: no cover
Expand All @@ -137,12 +143,14 @@ def from_source(cls, source: str) -> Plugin:
plugin = Plugin.__new__(cls)
super(Plugin, plugin).__init__()
plugin._tree = ast.parse(source)
plugin._module = cst_parse_module_native(source)
plugin.module = cst_parse_module_native(source)
return plugin

def run(self) -> Iterable[Error]:
yield from Flake8TrioRunner.run(self._tree, self.options)
yield from Flake8TrioRunner_cst(self.options).run(self._module)
cst_runner = Flake8TrioRunner_cst(self.options, self.module)
yield from cst_runner.run()
self.module = cst_runner.module

@staticmethod
def add_options(option_manager: OptionManager | ArgumentParser):
Expand All @@ -157,6 +165,7 @@ def add_options(option_manager: OptionManager | ArgumentParser):
add_argument = functools.partial(
option_manager.add_option, parse_from_config=True
)
add_argument("--autofix", action="store_true", required=False)

add_argument(
"--no-checkpoint-warning-decorators",
Expand Down
8 changes: 4 additions & 4 deletions flake8_trio/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,20 @@ def visit(self, node: ast.AST):


class Flake8TrioRunner_cst:
def __init__(self, options: Namespace):
def __init__(self, options: Namespace, module: Module):
super().__init__()
self.state = SharedState(options)
self.options = options
self.visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple(
v(self.state) for v in ERROR_CLASSES_CST if self.selected(v.error_codes)
)
self.module = module

def run(self, module: Module) -> Iterable[Error]:
def run(self) -> Iterable[Error]:
if not self.visitors:
return
wrapper = cst.MetadataWrapper(module)
for v in self.visitors:
_ = wrapper.visit(v)
self.module = cst.MetadataWrapper(self.module).visit(v)
yield from self.state.problems

def selected(self, error_codes: dict[str, str]) -> bool:
Expand Down
35 changes: 35 additions & 0 deletions flake8_trio/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,38 @@ def func_has_decorator(func: cst.FunctionDef, *names: str) -> bool:
),
)
)


# used in TRIO100
def flatten_preserving_comments(node: cst.BaseCompoundStatement):
if isinstance(node.body, cst.SimpleStatementSuite):
# `with ...: pass;pass;pass` -> pass;pass;pass
return cst.SimpleStatementLine(node.body.body, leading_lines=node.leading_lines)

assert isinstance(node.body, cst.IndentedBlock)
nodes = list(node.body.body)

# nodes[0] is a BaseStatement, whose subclasses are SimpleStatementLine
# and BaseCompoundStatement - both of which has leading_lines
assert isinstance(nodes[0], (cst.SimpleStatementLine, cst.BaseCompoundStatement))

# add leading lines of the original node to the leading lines
# of the first statement in the body
new_leading_lines = list(node.leading_lines)
if node.body.header and node.body.header.comment:
new_leading_lines.append(
cst.EmptyLine(indent=True, comment=node.body.header.comment)
)
new_leading_lines.extend(nodes[0].leading_lines)
nodes[0] = nodes[0].with_changes(leading_lines=new_leading_lines)

# if there's comments in the footer of the indented block, add a pass
# statement with the comments as leading lines
if node.body.footer:
nodes.append(
cst.SimpleStatementLine(
[cst.Pass()],
node.body.footer,
)
)
return cst.FlattenSentinel(nodes)
22 changes: 17 additions & 5 deletions flake8_trio/visitors/visitor100.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@
"""
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

import libcst as cst
import libcst.matchers as m

from .flake8triovisitor import Flake8TrioVisitor_cst
from .helpers import AttributeCall, error_class_cst, with_has_call
from .helpers import (
AttributeCall,
error_class_cst,
flatten_preserving_comments,
with_has_call,
)

if TYPE_CHECKING:
pass


@error_class_cst
Expand Down Expand Up @@ -46,12 +54,16 @@ def visit_With(self, node: cst.With) -> None:
else:
self.has_checkpoint_stack.append(True)

def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With:
def leave_With(
self, original_node: cst.With, updated_node: cst.With
) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]:
if not self.has_checkpoint_stack.pop():
for res in self.node_dict[original_node]:
self.error(res.node, res.base, res.function)
# if: autofixing is enabled for this code
# then: remove the with and pop out it's body

if self.options.autofix and len(updated_node.items) == 1:
return flatten_preserving_comments(updated_node)

return updated_node

def visit_For(self, node: cst.For):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.pyright]
strict = ["*.py", "tests/*.py", "flake8_trio/**/*.py"]
exclude = ["**/node_modules", "**/__pycache__", "**/.*"]
exclude = ["**/node_modules", "**/__pycache__", "**/.*", "tests/autofix_files/*"]
reportUnusedCallResult=false
reportUninitializedInstanceVariable=true
reportPropertyTypeMismatch=true
Expand Down
84 changes: 84 additions & 0 deletions tests/autofix_files/trio100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# type: ignore

import trio

# error: 5, "trio", "move_on_after"
...


async def function_name():
# fmt: off
...; ...; ...
# fmt: on
# error: 15, "trio", "fail_after"
...
# error: 15, "trio", "fail_at"
...
# error: 15, "trio", "move_on_after"
...
# error: 15, "trio", "move_on_at"
...
# error: 15, "trio", "CancelScope"
...

with trio.move_on_after(10):
await trio.sleep(1)

with trio.move_on_after(10):
await trio.sleep(1)
print("hello")

with trio.move_on_after(10):
while True:
await trio.sleep(1)
print("hello")

with open("filename") as _:
...

# error: 9, "trio", "fail_after"
...

send_channel, receive_channel = trio.open_memory_channel(0)
async with trio.fail_after(10):
async with send_channel:
...

async with trio.fail_after(10):
async for _ in receive_channel:
...

# error: 15, "trio", "fail_after"
for _ in receive_channel:
...

# fix missed alarm when function is defined inside the with scope
# error: 9, "trio", "move_on_after"

async def foo():
await trio.sleep(1)

# error: 9, "trio", "move_on_after"
if ...:

async def foo():
if ...:
await trio.sleep(1)

async with random_ignored_library.fail_after(10):
...


async def function_name2():
with (
open("") as _,
trio.fail_after(10), # error: 8, "trio", "fail_after"
):
...

with (
trio.fail_after(5), # error: 8, "trio", "fail_after"
open("") as _,
trio.move_on_after(5), # error: 8, "trio", "move_on_after"
):
...
43 changes: 43 additions & 0 deletions tests/autofix_files/trio100_simple_autofix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import trio

# a
# b
# error: 5, "trio", "move_on_after"
# c
# d
print(1) # e
# f
# g
print(2) # h
# i
# j
print(3) # k
# l
# m
pass
# n

# error: 5, "trio", "move_on_after"
...


# a
# b
# fmt: off
...;...;...
# fmt: on
# c
# d

# Doesn't autofix With's with multiple withitems
with (
trio.move_on_after(10), # error: 4, "trio", "move_on_after"
open("") as f,
):
...


# extreme case I'm not gonna care about, i.e. one item in the with, but it's multiline.
# Only these leading comments, and the last one, are kept, the rest are lost.
# this comment is kept
...
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def pytest_addoption(parser: pytest.Parser):
parser.addoption(
"--runfuzz", action="store_true", default=False, help="run fuzz tests"
)
parser.addoption(
"--generate-autofix",
action="store_true",
default=False,
help="generate autofix file content",
)
parser.addoption(
"--enable-visitor-codes-regex",
default=".*",
Expand All @@ -32,6 +38,11 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item
item.add_marker(skip_fuzz)


@pytest.fixture()
def generate_autofix(request: pytest.FixtureRequest):
return request.config.getoption("generate_autofix")


@pytest.fixture()
def enable_visitor_codes_regex(request: pytest.FixtureRequest):
return request.config.getoption("--enable-visitor-codes-regex")
Loading

0 comments on commit e7e4a95

Please sign in to comment.