Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: async/await, return, comprehension ast-based helpers #186

Merged
merged 14 commits into from
Apr 17, 2024
114 changes: 109 additions & 5 deletions python/py_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,42 @@ def find_function(self, func):
return Node(node)
return Node()

def find_async_function(self, func):
if not self._has_body():
return Node()
for node in self.tree.body:
if isinstance(node, ast.AsyncFunctionDef):
if node.name == func:
return Node(node)
return Node()
ojeytonwilliams marked this conversation as resolved.
Show resolved Hide resolved

def find_awaits(self):
return [
node
for node in self._find_all(ast.Expr)
if isinstance(node.tree.value, ast.Await)
]

def has_args(self, arg_str):
if not isinstance(self.tree, ast.FunctionDef):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return False
if id := getattr(self.tree.returns, "id", False):
returns = f"-> {id}"
elif val := getattr(self.tree.returns, "value", False):
returns = f"-> '{val}'"
else:
returns = ""

async_kw = ""
if isinstance(self.tree, ast.AsyncFunctionDef):
async_kw = "async "
body_lines = str(self.find_body()).split("\n")
new_body = "".join([f"\n {line}" for line in body_lines])
func_str = f"def {self.tree.name}({arg_str}) {returns}:{new_body}"
func_str = f"{async_kw}def {self.tree.name}({arg_str}) {returns}:{new_body}"
return self.is_equivalent(func_str)

# returns_str is the annotation of the type returned by the function
def has_returns(self, returns_str):
if not isinstance(self.tree, ast.FunctionDef):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return False
if isinstance(self.tree.returns, ast.Name):
return returns_str == self.tree.returns.id
Expand All @@ -98,9 +116,76 @@ def find_body(self):
return Node()
return Node(ast.Module(self.tree.body, []))

# find the return statement of a function
def find_return(self):
if return_list := self._find_all(ast.Return):
return return_list[0]
return Node()

def has_return(self, return_value):
return self.find_return().is_equivalent(f"return {return_value}")

def find_imports(self):
return self._find_all((ast.Import, ast.ImportFrom))

def find_comps(self):
return [
node
for node in self._find_all(ast.Expr)
if isinstance(
node.tree.value,
(ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp),
)
]

def _find_comp(
self, classes=(ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp)
):
if isinstance(self.tree, classes):
return Node(self.tree)
elif isinstance(self.tree, (ast.Assign, ast.AnnAssign, ast.Return)):
if isinstance(self.tree.value, classes):
return Node(self.tree.value)
return Node()

# find a list of iterables of a comprehension/generator expression
def find_comp_iters(self):
if not (node := self._find_comp()):
return []
return [Node(gen.iter) for gen in node.tree.generators]

# find a list of targets (iteration variables) of a comprehension/generator expression
def find_comp_targets(self):
if not (node := self._find_comp()):
return []
return [Node(gen.target) for gen in node.tree.generators]

# find the key of a dictionary comprehension
def find_comp_key(self):
if not (node := self._find_comp(ast.DictComp)):
return Node()
return Node(node.tree.key)

# find the expression evaluated for a comprehension/generator expression
# which is the value of the key in case of a dictionary comprehension
def find_comp_expr(self):
if not (node := self._find_comp()):
return Node()
if isinstance(node.tree, (ast.ListComp, ast.SetComp, ast.GeneratorExp)):
return Node(node.tree.elt)
elif isinstance(node.tree, ast.DictComp):
return Node(node.tree.value)

# find a list of `IfExpr`s at the end of the comprehension/generator expression
def find_comp_ifs(self):
if not (node := self._find_comp()):
return []
return [
Node(gen.ifs[i])
for gen in node.tree.generators
for i in range(len(gen.ifs))
]

# "has" functions return a boolean indicating whether whatever is being
# searched for exists. In this case, it returns True if the variable exists.

Expand All @@ -112,9 +197,28 @@ def has_import(self, import_str):
import_node.is_equivalent(import_str) for import_node in self.find_imports()
)

# find a list of function calls of the 'name' function
def find_calls(self, name):
call_list = []
for node in self._find_all(ast.Expr):
if func := getattr(node.tree.value, "func", False):
if isinstance(func, ast.Name) and func.id == name:
call_list.append(Node(node.tree.value))
elif isinstance(func, ast.Attribute) and func.attr == name:
call_list.append(Node(node.tree.value))
return call_list

def has_call(self, call):
return any(node.is_equivalent(call) for node in self._find_all(ast.Expr))

def find_call_args(self):
if not isinstance(self.tree, ast.Call):
return []
return [Node(arg) for arg in self.tree.args]

def has_node(self, node_str):
Dario-DC marked this conversation as resolved.
Show resolved Hide resolved
return any(Node(node).is_equivalent(node_str) for node in self.tree.body)
Dario-DC marked this conversation as resolved.
Show resolved Hide resolved

def find_variable(self, name):
if not self._has_body():
return Node()
Expand Down Expand Up @@ -159,7 +263,7 @@ def has_class(self, name):
return self.find_class(name) != Node()

def has_decorators(self, *args):
if not isinstance(self.tree, ast.FunctionDef):
if not isinstance(self.tree, (ast.FunctionDef, ast.AsyncFunctionDef)):
return False
id_list = (node.id for node in self.tree.decorator_list)
return all(arg in id_list for arg in args)
Expand Down
Loading
Loading