diff --git a/python/py_helpers.py b/python/py_helpers.py index f90277c..4aea876 100644 --- a/python/py_helpers.py +++ b/python/py_helpers.py @@ -66,8 +66,24 @@ def find_function(self, func): return Node(node) return Node() + def find_async_function(self, func): + if not self._has_body(): + return Node() + for node in self.tree.body: + if isinstance(node, ast.AsyncFunctionDef): + if node.name == func: + return Node(node) + return Node() + + def find_awaits(self): + return [ + node + for node in self._find_all(ast.Expr) + if isinstance(node.tree.value, ast.Await) + ] + def has_args(self, arg_str): - if not isinstance(self.tree, ast.FunctionDef): + if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)): return False if id := getattr(self.tree.returns, "id", False): returns = f"-> {id}" @@ -75,15 +91,17 @@ def has_args(self, arg_str): returns = f"-> '{val}'" else: returns = "" - + async_kw = "" + if isinstance(self.tree, ast.AsyncFunctionDef): + async_kw = "async " body_lines = str(self.find_body()).split("\n") new_body = "".join([f"\n {line}" for line in body_lines]) - func_str = f"def {self.tree.name}({arg_str}) {returns}:{new_body}" + func_str = f"{async_kw}def {self.tree.name}({arg_str}) {returns}:{new_body}" return self.is_equivalent(func_str) # returns_str is the annotation of the type returned by the function def has_returns(self, returns_str): - if not isinstance(self.tree, ast.FunctionDef): + if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)): return False if isinstance(self.tree.returns, ast.Name): return returns_str == self.tree.returns.id @@ -98,9 +116,76 @@ def find_body(self): return Node() return Node(ast.Module(self.tree.body, [])) + # find the return statement of a function + def find_return(self): + if return_list := self._find_all(ast.Return): + return return_list[0] + return Node() + + def has_return(self, return_value): + return self.find_return().is_equivalent(f"return {return_value}") + def find_imports(self): return self._find_all((ast.Import, ast.ImportFrom)) + def find_comps(self): + return [ + node + for node in self._find_all(ast.Expr) + if isinstance( + node.tree.value, + (ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp), + ) + ] + + def _find_comp( + self, classes=(ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp) + ): + if isinstance(self.tree, classes): + return Node(self.tree) + elif isinstance(self.tree, (ast.Assign, ast.AnnAssign, ast.Return)): + if isinstance(self.tree.value, classes): + return Node(self.tree.value) + return Node() + + # find a list of iterables of a comprehension/generator expression + def find_comp_iters(self): + if not (node := self._find_comp()): + return [] + return [Node(gen.iter) for gen in node.tree.generators] + + # find a list of targets (iteration variables) of a comprehension/generator expression + def find_comp_targets(self): + if not (node := self._find_comp()): + return [] + return [Node(gen.target) for gen in node.tree.generators] + + # find the key of a dictionary comprehension + def find_comp_key(self): + if not (node := self._find_comp(ast.DictComp)): + return Node() + return Node(node.tree.key) + + # find the expression evaluated for a comprehension/generator expression + # which is the value of the key in case of a dictionary comprehension + def find_comp_expr(self): + if not (node := self._find_comp()): + return Node() + if isinstance(node.tree, (ast.ListComp, ast.SetComp, ast.GeneratorExp)): + return Node(node.tree.elt) + elif isinstance(node.tree, ast.DictComp): + return Node(node.tree.value) + + # find a list of `IfExpr`s at the end of the comprehension/generator expression + def find_comp_ifs(self): + if not (node := self._find_comp()): + return [] + return [ + Node(gen.ifs[i]) + for gen in node.tree.generators + for i in range(len(gen.ifs)) + ] + # "has" functions return a boolean indicating whether whatever is being # searched for exists. In this case, it returns True if the variable exists. @@ -112,9 +197,30 @@ def has_import(self, import_str): import_node.is_equivalent(import_str) for import_node in self.find_imports() ) + # find a list of function calls of the 'name' function + def find_calls(self, name): + call_list = [] + for node in self._find_all(ast.Expr): + if func := getattr(node.tree.value, "func", False): + if isinstance(func, ast.Name) and func.id == name: + call_list.append(Node(node.tree.value)) + elif isinstance(func, ast.Attribute) and func.attr == name: + call_list.append(Node(node.tree.value)) + return call_list + def has_call(self, call): return any(node.is_equivalent(call) for node in self._find_all(ast.Expr)) + def find_call_args(self): + if not isinstance(self.tree, ast.Call): + return [] + return [Node(arg) for arg in self.tree.args] + + def has_stmt(self, node_str): + if not self._has_body(): + return False + return any(Node(node).is_equivalent(node_str) for node in self.tree.body) + def find_variable(self, name): if not self._has_body(): return Node() @@ -159,7 +265,7 @@ def has_class(self, name): return self.find_class(name) != Node() def has_decorators(self, *args): - if not isinstance(self.tree, ast.FunctionDef): + if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)): return False id_list = (node.id for node in self.tree.decorator_list) return all(arg in id_list for arg in args) diff --git a/python/py_helpers.test.py b/python/py_helpers.test.py index 1adacf7..1c04bf3 100644 --- a/python/py_helpers.test.py +++ b/python/py_helpers.test.py @@ -225,6 +225,31 @@ def foo(*, a, b, c=0): self.assertTrue(node.find_function("foo").has_args("*, a, b, c=0")) self.assertFalse(node.find_function("foo").has_args("*, a, b, c")) + def test_find_return(self): + code_str = """ +def foo(): + if x == 1: + return False + return True +""" + node = Node(code_str) + + self.assertTrue( + node.find_function("foo").find_return().is_equivalent("return True") + ) + + def test_has_return(self): + code_str = """ +def foo(): + if x == 1: + return False + return True +""" + node = Node(code_str) + + self.assertTrue(node.find_function("foo").has_return("True")) + self.assertTrue(node.find_function("foo").find_ifs()[0].has_return("False")) + def test_has_args_annotations(self): code_str = """ def foo(a: int, b: int) -> int: @@ -254,6 +279,51 @@ def foo(): self.assertFalse(node.find_function("foo").has_args("int")) + def test_find_calls(self): + code_str = """ +print(1) +int("1") +print(2) +foo("spam") +obj.foo("spam") +obj.bar.foo("spam") +""" + node = Node(code_str) + + self.assertEqual(len(node.find_calls("print")), 2) + self.assertTrue(node.find_calls("print")[0].is_equivalent("print(1)")) + self.assertTrue(node.find_calls("print")[1].is_equivalent("print(2)")) + self.assertEqual(len(node.find_calls("int")), 1) + self.assertTrue(node.find_calls("int")[0].is_equivalent("int('1')")) + self.assertEqual(len(node.find_calls("foo")), 3) + self.assertTrue(node.find_calls("foo")[0].is_equivalent("foo('spam')")) + self.assertTrue(node.find_calls("foo")[1].is_equivalent("obj.foo('spam')")) + self.assertTrue(node.find_calls("foo")[2].is_equivalent("obj.bar.foo('spam')")) + + def test_find_call_args(self): + code_str = """ +print(1) +print(2, 3) +obj.foo("spam") +""" + node = Node(code_str) + + self.assertEqual(len(node.find_calls("print")[0].find_call_args()), 1) + self.assertTrue( + node.find_calls("print")[0].find_call_args()[0].is_equivalent("1") + ) + self.assertEqual(len(node.find_calls("print")[1].find_call_args()), 2) + self.assertTrue( + node.find_calls("print")[1].find_call_args()[0].is_equivalent("2") + ) + self.assertTrue( + node.find_calls("print")[1].find_call_args()[1].is_equivalent("3") + ) + self.assertEqual(len(node.find_calls("foo")[0].find_call_args()), 1) + self.assertTrue( + node.find_calls("foo")[0].find_call_args()[0].is_equivalent("'spam'") + ) + def test_has_call(self): code_str = """ print(1) @@ -413,6 +483,81 @@ def bar(): ) +class TestAsyncHelpers(unittest.TestCase): + def test_find_async_function(self): + code_str = """ +async def foo(): + await bar() +""" + node = Node(code_str) + + self.assertTrue( + node.find_async_function("foo").is_equivalent( + "async def foo():\n await bar()" + ) + ) + + def test_find_async_function_args(self): + code_str = """ +async def foo(spam): + await bar() +""" + node = Node(code_str) + + self.assertTrue(node.find_async_function("foo").has_args("spam")) + + def test_find_async_function_return(self): + code_str = """ +async def foo(spam): + await bar() + return True +""" + node = Node(code_str) + + self.assertTrue(node.find_async_function("foo").has_return("True")) + + def test_find_async_function_returns(self): + code_str = """ +async def foo(spam) -> bool: + await bar() + return True +""" + node = Node(code_str) + + self.assertTrue(node.find_async_function("foo").has_returns("bool")) + + def test_find_awaits(self): + code_str = """ +async def foo(spam): + if spam: + await spam() + await bar() + await func() +""" + node = Node(code_str) + + self.assertEqual(len(node.find_async_function("foo").find_awaits()), 2) + self.assertTrue( + node.find_async_function("foo") + .find_awaits()[0] + .is_equivalent("await bar()") + ) + self.assertTrue( + node.find_async_function("foo") + .find_awaits()[1] + .is_equivalent("await func()") + ) + self.assertEqual( + len(node.find_async_function("foo").find_ifs()[0].find_awaits()), 1 + ) + self.assertTrue( + node.find_async_function("foo") + .find_ifs()[0] + .find_awaits()[0] + .is_equivalent("await spam()") + ) + + class TestEquivalenceHelpers(unittest.TestCase): def test_is_equivalent(self): full_str = """def foo(): @@ -1018,7 +1163,150 @@ def test_has_import(self): self.assertTrue(node.has_import("from py_helpers import Node as _Node")) +class TestComprehensionHelpers(unittest.TestCase): + def test_find_comps(self): + code_str = """ +[i**2 for i in lst] +(i for i in lst) +{i * j for i in spam for j in lst} +{k: v for k,v in dict} +""" + node = Node(code_str) + + self.assertEqual(len(node.find_comps()), 4) + self.assertTrue(node.find_comps()[0].is_equivalent("[i**2 for i in lst]")) + self.assertTrue(node.find_comps()[1].is_equivalent("(i for i in lst)")) + self.assertTrue( + node.find_comps()[2].is_equivalent("{i * j for i in spam for j in lst}") + ) + self.assertTrue(node.find_comps()[3].is_equivalent("{k: v for k,v in dict}")) + + def test_find_comp_iters(self): + code_str = """ +x = [i**2 for i in lst] + +def foo(spam): + return [i * j for i in spam for j in lst] +""" + node = Node(code_str) + + self.assertEqual(len(node.find_variable("x").find_comp_iters()), 1) + self.assertTrue( + node.find_variable("x").find_comp_iters()[0].is_equivalent("lst") + ) + self.assertEqual( + len(node.find_function("foo").find_return().find_comp_iters()), 2 + ) + self.assertTrue( + node.find_function("foo") + .find_return() + .find_comp_iters()[0] + .is_equivalent("spam") + ) + self.assertTrue( + node.find_function("foo") + .find_return() + .find_comp_iters()[1] + .is_equivalent("lst") + ) + + def test_find_comp_targets(self): + code_str = """ +x = [i**2 for i in lst] + +def foo(spam): + return [i * j for i in spam for j in lst] +""" + node = Node(code_str) + + self.assertEqual(len(node.find_variable("x").find_comp_targets()), 1) + self.assertTrue( + node.find_variable("x").find_comp_targets()[0].is_equivalent("i") + ) + self.assertEqual( + len(node.find_function("foo").find_return().find_comp_targets()), 2 + ) + self.assertTrue( + node.find_function("foo") + .find_return() + .find_comp_targets()[0] + .is_equivalent("i") + ) + self.assertTrue( + node.find_function("foo") + .find_return() + .find_comp_targets()[1] + .is_equivalent("j") + ) + + def test_find_comp_key(self): + code_str = """ +x = {k: v for k,v in dict} + +def foo(spam): + return {k: v for k in spam for v in lst} +""" + node = Node(code_str) + + self.assertTrue(node.find_variable("x").find_comp_key().is_equivalent("k")) + self.assertTrue( + node.find_function("foo").find_return().find_comp_key().is_equivalent("k") + ) + + def test_find_comp_expr(self): + code_str = """ +x = [i**2 if i else -1 for i in lst] + +def foo(spam): + return [i * j for i in spam for j in lst] +""" + node = Node(code_str) + + self.assertTrue( + node.find_variable("x").find_comp_expr().is_equivalent("i**2 if i else -1") + ) + self.assertTrue( + node.find_function("foo") + .find_return() + .find_comp_expr() + .is_equivalent("i*j") + ) + + def test_find_comp_ifs(self): + code_str = """ +x = [i**2 if i else -1 for i in lst] + +def foo(spam): + return [i * j for i in spam if i for j in lst if j] +""" + node = Node(code_str) + + self.assertEqual(len(node.find_variable("x").find_comp_ifs()), 0) + self.assertEqual( + len(node.find_function("foo").find_return().find_comp_ifs()), 2 + ) + self.assertTrue( + node.find_function("foo") + .find_return() + .find_comp_ifs()[0] + .is_equivalent("i") + ) + self.assertTrue( + node.find_function("foo") + .find_return() + .find_comp_ifs()[1] + .is_equivalent("j") + ) + + class TestGenericHelpers(unittest.TestCase): + def test_has_stmt(self): + self.assertTrue( + Node("name = input('hi')\nself.matrix[1][5] = 3").has_stmt( + "self.matrix[1][5] = 3" + ) + ) + def test_is_empty(self): self.assertTrue(Node().is_empty()) self.assertFalse(Node("x = 1").is_empty())