Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: async/await, return, comprehension ast-based helpers #186

Merged
merged 14 commits into from
Apr 17, 2024
106 changes: 100 additions & 6 deletions python/py_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,46 @@ def find_function(self, func):
return Node(node)
return Node()

def find_async_function(self, func):
if not self._has_body():
return Node()
for node in self.tree.body:
if isinstance(node, ast.AsyncFunctionDef):
if node.name == func:
return Node(node)
return Node()
ojeytonwilliams marked this conversation as resolved.
Show resolved Hide resolved

def find_awaits(self):
return [
node
for node in self._find_all(ast.Expr)
if isinstance(node.tree.value, ast.Await)
]

def find_args(self):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return []

def has_args(self, arg_str):
if not isinstance(self.tree, ast.FunctionDef):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return False
if id := getattr(self.tree.returns, "id", False):
returns = f"-> {id}"
elif val := getattr(self.tree.returns, "value", False):
returns = f"-> '{val}'"
else:
returns = ""

async_kw = ""
if isinstance(self.tree, ast.AsyncFunctionDef):
async_kw = "async "
body_lines = str(self.find_body()).split("\n")
new_body = "".join([f"\n {line}" for line in body_lines])
func_str = f"def {self.tree.name}({arg_str}) {returns}:{new_body}"
func_str = f"{async_kw}def {self.tree.name}({arg_str}) {returns}:{new_body}"
return self.is_equivalent(func_str)

# returns_str is the annotation of the type returned by the function
def has_returns(self, returns_str):
if not isinstance(self.tree, ast.FunctionDef):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return False
if isinstance(self.tree.returns, ast.Name):
return returns_str == self.tree.returns.id
Expand All @@ -98,9 +120,69 @@ def find_body(self):
return Node()
return Node(ast.Module(self.tree.body, []))

# find the return statement of a function
def find_return(self):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return Node()
if return_list := self._find_all(ast.Return):
return return_list[0]
return Node()

def has_return(self, return_value):
return self.find_return().is_equivalent(f"return {return_value}")

def find_imports(self):
return self._find_all((ast.Import, ast.ImportFrom))

def _find_comp(
self, classes=(ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp)
):
if isinstance(self.tree, classes):
return Node(self.tree)
elif isinstance(self.tree, (ast.Assign, ast.AnnAssign, ast.Return)):
if isinstance(self.tree.value, classes):
return Node(self.tree.value)
return Node()
# elif isinstance(self.tree, ast.Expr):...

# find a list of iterables of a comprehension/generator expression
def find_comp_iters(self):
if not (node := self._find_comp()):
return []
return [Node(gen.iter) for gen in node.tree.generators]

# find a list of targets (iteration variables) of a comprehension/generator expression
def find_comp_targets(self):
if not (node := self._find_comp()):
return []
return [Node(gen.target) for gen in node.tree.generators]

# find the key of a dictionary comprehension
def find_comp_key(self):
if not (node := self._find_comp(ast.DictComp)):
return Node()
return Node(node.tree.key)

# find the expression evaluated for a comprehension/generator expression
# which is the value of the key in case of a dictionary comprehension
def find_comp_expr(self):
if not (node := self._find_comp()):
return Node()
if isinstance(node.tree, (ast.ListComp, ast.SetComp, ast.GeneratorExp)):
return Node(node.tree.elt)
elif isinstance(node.tree, ast.DictComp):
return Node(node.tree.value)

# find a list of `IfExpr`s at the end of the comprehension/generator expression
def find_comp_ifs(self):
if not (node := self._find_comp()):
return []
return [
Node(gen.ifs[i])
for gen in node.tree.generators
for i in range(len(gen.ifs))
]

# "has" functions return a boolean indicating whether whatever is being
# searched for exists. In this case, it returns True if the variable exists.

Expand All @@ -112,8 +194,20 @@ def has_import(self, import_str):
import_node.is_equivalent(import_str) for import_node in self.find_imports()
)

def find_call(self, call):
return [
Node(node.tree.value)
for node in self._find_all(ast.Expr)
if node.is_equivalent(call)
]

def has_call(self, call):
return any(node.is_equivalent(call) for node in self._find_all(ast.Expr))
return bool(self.find_call(call))

def find_call_args(self):
if not isinstance(self.tree, ast.Call):
return []
return [Node(arg) for arg in self.tree.args]

def find_variable(self, name):
if not self._has_body():
Expand Down Expand Up @@ -159,7 +253,7 @@ def has_class(self, name):
return self.find_class(name) != Node()

def has_decorators(self, *args):
if not isinstance(self.tree, ast.FunctionDef):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return False
id_list = (node.id for node in self.tree.decorator_list)
return all(arg in id_list for arg in args)
Expand Down
218 changes: 218 additions & 0 deletions python/py_helpers.test.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,30 @@ def foo(*, a, b, c=0):
self.assertTrue(node.find_function("foo").has_args("*, a, b, c=0"))
self.assertFalse(node.find_function("foo").has_args("*, a, b, c"))

def test_find_return(self):
code_str = """
def foo():
if x == 1:
return False
return True
"""
node = Node(code_str)

self.assertTrue(
node.find_function("foo").find_return().is_equivalent("return True")
)

def test_has_return(self):
code_str = """
def foo():
if x == 1:
return False
return True
"""
node = Node(code_str)

self.assertTrue(node.find_function("foo").has_return("True"))
Dario-DC marked this conversation as resolved.
Show resolved Hide resolved

def test_has_args_annotations(self):
code_str = """
def foo(a: int, b: int) -> int:
Expand Down Expand Up @@ -413,6 +437,81 @@ def bar():
)


class TestAsyncHelpers(unittest.TestCase):
def test_find_async_function(self):
code_str = """
async def foo():
await bar()
"""
node = Node(code_str)

self.assertTrue(
node.find_async_function("foo").is_equivalent(
"async def foo():\n await bar()"
)
)

def test_find_async_function_args(self):
code_str = """
async def foo(spam):
await bar()
"""
node = Node(code_str)

self.assertTrue(node.find_async_function("foo").has_args("spam"))

def test_find_async_function_return(self):
code_str = """
async def foo(spam):
await bar()
return True
"""
node = Node(code_str)

self.assertTrue(node.find_async_function("foo").has_return("True"))

def test_find_async_function_returns(self):
code_str = """
async def foo(spam) -> bool:
await bar()
return True
"""
node = Node(code_str)

self.assertTrue(node.find_async_function("foo").has_returns("bool"))

def test_find_awaits(self):
code_str = """
async def foo(spam):
if spam:
await spam()
await bar()
await func()
"""
node = Node(code_str)

self.assertEqual(len(node.find_async_function("foo").find_awaits()), 2)
self.assertTrue(
node.find_async_function("foo")
.find_awaits()[0]
.is_equivalent("await bar()")
)
self.assertTrue(
node.find_async_function("foo")
.find_awaits()[1]
.is_equivalent("await func()")
)
self.assertEqual(
len(node.find_async_function("foo").find_ifs()[0].find_awaits()), 1
)
self.assertTrue(
node.find_async_function("foo")
.find_ifs()[0]
.find_awaits()[0]
.is_equivalent("await spam()")
)


class TestEquivalenceHelpers(unittest.TestCase):
def test_is_equivalent(self):
full_str = """def foo():
Expand Down Expand Up @@ -1018,6 +1117,125 @@ def test_has_import(self):
self.assertTrue(node.has_import("from py_helpers import Node as _Node"))


class TestComprehensionHelpers(unittest.TestCase):
def test_find_comp_iters(self):
code_str = """
x = [i**2 for i in lst]

def foo(spam):
return [i * j for i in spam for j in lst]
"""
node = Node(code_str)

self.assertEqual(len(node.find_variable("x").find_comp_iters()), 1)
self.assertTrue(
node.find_variable("x").find_comp_iters()[0].is_equivalent("lst")
)
self.assertEqual(
len(node.find_function("foo").find_return().find_comp_iters()), 2
)
self.assertTrue(
node.find_function("foo")
.find_return()
.find_comp_iters()[0]
.is_equivalent("spam")
)
self.assertTrue(
node.find_function("foo")
.find_return()
.find_comp_iters()[1]
.is_equivalent("lst")
)

def test_find_comp_targets(self):
code_str = """
x = [i**2 for i in lst]

def foo(spam):
return [i * j for i in spam for j in lst]
"""
node = Node(code_str)

self.assertEqual(len(node.find_variable("x").find_comp_targets()), 1)
self.assertTrue(
node.find_variable("x").find_comp_targets()[0].is_equivalent("i")
)
self.assertEqual(
len(node.find_function("foo").find_return().find_comp_targets()), 2
)
self.assertTrue(
node.find_function("foo")
.find_return()
.find_comp_targets()[0]
.is_equivalent("i")
)
self.assertTrue(
node.find_function("foo")
.find_return()
.find_comp_targets()[1]
.is_equivalent("j")
)

def test_find_comp_key(self):
code_str = """
x = {k: v for k,v in dict}

def foo(spam):
return {k: v for k in spam for v in lst}
"""
node = Node(code_str)

self.assertTrue(node.find_variable("x").find_comp_key().is_equivalent("k"))
self.assertTrue(
node.find_function("foo").find_return().find_comp_key().is_equivalent("k")
)

def test_find_comp_expr(self):
code_str = """
x = [i**2 if i else -1 for i in lst]

def foo(spam):
return [i * j for i in spam for j in lst]
"""
node = Node(code_str)

self.assertTrue(
node.find_variable("x").find_comp_expr().is_equivalent("i**2 if i else -1")
)
self.assertTrue(
node.find_function("foo")
.find_return()
.find_comp_expr()
.is_equivalent("i*j")
)

def test_find_comp_ifs(self):
code_str = """
x = [i**2 if i else -1 for i in lst]

def foo(spam):
return [i * j for i in spam if i for j in lst if j]
"""
node = Node(code_str)

self.assertEqual(len(node.find_variable("x").find_comp_ifs()), 0)
self.assertEqual(
len(node.find_function("foo").find_return().find_comp_ifs()), 2
)
self.assertTrue(
node.find_function("foo")
.find_return()
.find_comp_ifs()[0]
.is_equivalent("i")
)
self.assertTrue(
node.find_function("foo")
.find_return()
.find_comp_ifs()[1]
.is_equivalent("j")
)


class TestGenericHelpers(unittest.TestCase):
def test_is_empty(self):
self.assertTrue(Node().is_empty())
Expand Down
Loading