Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python 3.11 compatibility to 4.6.x #11362

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ Samuele Pedroni
Sankt Petersbug
Segev Finer
Serhii Mozghovyi
Shantanu Jain
Simon Gomizelj
Skylar Downes
Srinivas Reddy Thatiparthy
Expand Down
1 change: 1 addition & 0 deletions changelog/8539.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed assertion rewriting on Python 3.10.
1 change: 1 addition & 0 deletions changelog/9163.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The end line number and end column offset are now properly set for rewritten assert statements.
42 changes: 23 additions & 19 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,19 +575,12 @@ def _NameConstant(c):
return ast.Name(str(c), ast.Load())


def set_location(node, lineno, col_offset):
"""Set node location information recursively."""

def _fix(node, lineno, col_offset):
if "lineno" in node._attributes:
node.lineno = lineno
if "col_offset" in node._attributes:
node.col_offset = col_offset
for child in ast.iter_child_nodes(node):
_fix(child, lineno, col_offset)

_fix(node, lineno, col_offset)
return node
def traverse_node(node):
"""Recursively yield node and all its children in depth-first order."""
yield node
for child in ast.iter_child_nodes(node):
for descendant in traverse_node(child):
yield descendant


class AssertionRewriter(ast.NodeVisitor):
Expand Down Expand Up @@ -652,12 +645,9 @@ def run(self, mod):
if not mod.body:
# Nothing to do.
return

# Insert some special imports at the top of the module but after any
# docstrings and __future__ imports.
aliases = [
ast.alias(six.moves.builtins.__name__, "@py_builtins"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
]
doc = getattr(mod, "docstring", None)
expect_docstring = doc is None
if doc is not None and self.is_rewrite_disabled(doc):
Expand All @@ -684,6 +674,19 @@ def run(self, mod):
pos += 1
else:
lineno = item.lineno
if sys.version_info >= (3, 10):
aliases = [
ast.alias(six.moves.builtins.__name__, "@py_builtins", lineno=lineno, col_offset=0),
ast.alias(
"_pytest.assertion.rewrite", "@pytest_ar",
lineno=lineno, col_offset=0
),
]
else:
aliases = [
ast.alias(six.moves.builtins.__name__, "@py_builtins"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
]
imports = [
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
]
Expand Down Expand Up @@ -858,9 +861,10 @@ def visit_Assert(self, assert_):
variables = [ast.Name(name, ast.Store()) for name in self.variables]
clear = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear)
# Fix line numbers.
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
set_location(stmt, assert_.lineno, assert_.col_offset)
for node in traverse_node(stmt):
ast.copy_location(node, assert_)
return self.statements

def warn_about_none_ast(self, node, module_path, lineno):
Expand Down
22 changes: 22 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,28 @@ def test_place_initial_imports(self):
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Expr)

def test_location_is_set(self):
s = textwrap.dedent(
"""

assert False, (

"Ouch"
)

"""
)
m = rewrite(s)
for node in m.body:
if isinstance(node, ast.Import):
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert n.lineno == 3
assert n.col_offset == 0
if sys.version_info >= (3, 8):
assert n.end_lineno == 6
assert n.end_col_offset == 3

def test_dont_rewrite(self):
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
m = rewrite(s)
Expand Down
Loading