Skip to content

Commit

Permalink
feat: async/await, return, comprehension ast-based helpers (#186)
Browse files Browse the repository at this point in the history
* feat: return and comprehension helpers

* feat: async/await helpers

* modified comprehension helpers

* fix: find_return

* feat: has_node

* delete comment

* feat: find_calls

* feat: find_comps

* fix: has_stmt

* expand find_calls test
  • Loading branch information
Dario-DC authored Apr 17, 2024
1 parent d0db232 commit 83383df
Show file tree
Hide file tree
Showing 2 changed files with 399 additions and 5 deletions.
116 changes: 111 additions & 5 deletions python/py_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,42 @@ def find_function(self, func):
return Node(node)
return Node()

def find_async_function(self, func):
if not self._has_body():
return Node()
for node in self.tree.body:
if isinstance(node, ast.AsyncFunctionDef):
if node.name == func:
return Node(node)
return Node()

def find_awaits(self):
return [
node
for node in self._find_all(ast.Expr)
if isinstance(node.tree.value, ast.Await)
]

def has_args(self, arg_str):
if not isinstance(self.tree, ast.FunctionDef):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return False
if id := getattr(self.tree.returns, "id", False):
returns = f"-> {id}"
elif val := getattr(self.tree.returns, "value", False):
returns = f"-> '{val}'"
else:
returns = ""

async_kw = ""
if isinstance(self.tree, ast.AsyncFunctionDef):
async_kw = "async "
body_lines = str(self.find_body()).split("\n")
new_body = "".join([f"\n {line}" for line in body_lines])
func_str = f"def {self.tree.name}({arg_str}) {returns}:{new_body}"
func_str = f"{async_kw}def {self.tree.name}({arg_str}) {returns}:{new_body}"
return self.is_equivalent(func_str)

# returns_str is the annotation of the type returned by the function
def has_returns(self, returns_str):
if not isinstance(self.tree, ast.FunctionDef):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return False
if isinstance(self.tree.returns, ast.Name):
return returns_str == self.tree.returns.id
Expand All @@ -98,9 +116,76 @@ def find_body(self):
return Node()
return Node(ast.Module(self.tree.body, []))

# find the return statement of a function
def find_return(self):
if return_list := self._find_all(ast.Return):
return return_list[0]
return Node()

def has_return(self, return_value):
return self.find_return().is_equivalent(f"return {return_value}")

def find_imports(self):
return self._find_all((ast.Import, ast.ImportFrom))

def find_comps(self):
return [
node
for node in self._find_all(ast.Expr)
if isinstance(
node.tree.value,
(ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp),
)
]

def _find_comp(
self, classes=(ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp)
):
if isinstance(self.tree, classes):
return Node(self.tree)
elif isinstance(self.tree, (ast.Assign, ast.AnnAssign, ast.Return)):
if isinstance(self.tree.value, classes):
return Node(self.tree.value)
return Node()

# find a list of iterables of a comprehension/generator expression
def find_comp_iters(self):
if not (node := self._find_comp()):
return []
return [Node(gen.iter) for gen in node.tree.generators]

# find a list of targets (iteration variables) of a comprehension/generator expression
def find_comp_targets(self):
if not (node := self._find_comp()):
return []
return [Node(gen.target) for gen in node.tree.generators]

# find the key of a dictionary comprehension
def find_comp_key(self):
if not (node := self._find_comp(ast.DictComp)):
return Node()
return Node(node.tree.key)

# find the expression evaluated for a comprehension/generator expression
# which is the value of the key in case of a dictionary comprehension
def find_comp_expr(self):
if not (node := self._find_comp()):
return Node()
if isinstance(node.tree, (ast.ListComp, ast.SetComp, ast.GeneratorExp)):
return Node(node.tree.elt)
elif isinstance(node.tree, ast.DictComp):
return Node(node.tree.value)

# find a list of `IfExpr`s at the end of the comprehension/generator expression
def find_comp_ifs(self):
if not (node := self._find_comp()):
return []
return [
Node(gen.ifs[i])
for gen in node.tree.generators
for i in range(len(gen.ifs))
]

# "has" functions return a boolean indicating whether whatever is being
# searched for exists. In this case, it returns True if the variable exists.

Expand All @@ -112,9 +197,30 @@ def has_import(self, import_str):
import_node.is_equivalent(import_str) for import_node in self.find_imports()
)

# find a list of function calls of the 'name' function
def find_calls(self, name):
call_list = []
for node in self._find_all(ast.Expr):
if func := getattr(node.tree.value, "func", False):
if isinstance(func, ast.Name) and func.id == name:
call_list.append(Node(node.tree.value))
elif isinstance(func, ast.Attribute) and func.attr == name:
call_list.append(Node(node.tree.value))
return call_list

def has_call(self, call):
return any(node.is_equivalent(call) for node in self._find_all(ast.Expr))

def find_call_args(self):
if not isinstance(self.tree, ast.Call):
return []
return [Node(arg) for arg in self.tree.args]

def has_stmt(self, node_str):
if not self._has_body():
return False
return any(Node(node).is_equivalent(node_str) for node in self.tree.body)

def find_variable(self, name):
if not self._has_body():
return Node()
Expand Down Expand Up @@ -159,7 +265,7 @@ def has_class(self, name):
return self.find_class(name) != Node()

def has_decorators(self, *args):
if not isinstance(self.tree, ast.FunctionDef):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return False
id_list = (node.id for node in self.tree.decorator_list)
return all(arg in id_list for arg in args)
Expand Down
Loading

0 comments on commit 83383df

Please sign in to comment.