Skip to content

Commit

Permalink
refactor: Chainable -> Node (#122)
Browse files Browse the repository at this point in the history
* feat: construct with string or AST

* fix: always construct with AST or string

* refactor: Chainable -> ASTExplorer

* feat: remove parse()

* test: improve test coverage

* fix: handle if_bodies on non-If node

* test: improve coverage

* test: more precise test

* test: more demostrations of is_equivalent

* refactor: ASTExplorer -> Node
  • Loading branch information
ojeytonwilliams authored Jan 24, 2024
1 parent 6374cc2 commit 4e91a0e
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 160 deletions.
74 changes: 40 additions & 34 deletions python/py_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
# A chainable class that allows us to call functions on the result of parsing a string


class Chainable:
class Node:
# TODO: allow initialization with a string
def __init__(self, tree=None):
self.tree = tree
if isinstance(tree, str):
self.tree = ast.parse(tree)
elif isinstance(tree, ast.AST) or tree == None:
self.tree = tree
else:
raise TypeError("Node must be initialized with a string or AST")

def __getitem__(self, i):
if getattr(self.tree, "__getitem__", False):
return Chainable(self.tree[i])
return Node(self.tree[i])
else:
return Chainable(self.tree.body[i])
return Node(self.tree.body[i])

def __len__(self):
if getattr(self.tree, "__len__", False):
Expand All @@ -21,7 +26,7 @@ def __len__(self):
return len(self.tree.body)

def __eq__(self, other):
if not isinstance(other, Chainable):
if not isinstance(other, Node):
return False
if self.tree == None:
return other.tree == None
Expand All @@ -33,54 +38,51 @@ def __eq__(self, other):

def __repr__(self):
if self.tree == None:
return "Chainable:\nNone"
return "Chainable:\n" + ast.dump(self.tree, indent=2)

def parse(self, string):
return Chainable(ast.parse(string))
return "Node:\nNone"
return "Node:\n" + ast.dump(self.tree, indent=2)

def _has_body(self):
return bool(getattr(self.tree, "body", False))

# "find" functions return a new chainable with the result of the find
# function. In this case, it returns a new chainable with the function
# "find" functions return a new node with the result of the find
# function. In this case, it returns a new node with the function
# definition (if it exists)

def find_function(self, func):
if not self._has_body():
return Chainable()
return Node()
for node in self.tree.body:
if isinstance(node, ast.FunctionDef):
if node.name == func:
return Chainable(node)
return Chainable()
return Node(node)
return Node()

# "has" functions return a boolean indicating whether whatever is being
# searched for exists. In this case, it returns True if the variable exists.

def has_variable(self, name):
return self.find_variable(name) != Chainable()
return self.find_variable(name) != Node()

def find_variable(self, name):
if not self._has_body():
return Chainable()
return Node()
for node in self.tree.body:
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name):
if target.id == name:
return Chainable(node)
return Chainable()
return Node(node)
return Node()

def get_variable(self, name):
var = self.find_variable(name)
if var != Chainable():
if var != Node():
return var.tree.value.value
else:
return None

def has_function(self, name):
return self.find_function(name) != Chainable()
return self.find_function(name) != Node()

# Checks the variable, name, is in the current scope and is an integer

Expand All @@ -97,7 +99,7 @@ def value_is_call(self, name):
return call.func.id == name
return False

# Takes an string and checks if is equivalent to the chainable's AST. This
# Takes an string and checks if is equivalent to the node's AST. This
# is a loose comparison that tries to find out if the code is essentially
# the same. For example, the string "True" is not represented by the same
# AST as the test in "if True:" (the string could be wrapped in Module,
Expand All @@ -116,12 +118,12 @@ def is_equivalent(self, target_str):

def find_class(self, class_name):
if not self._has_body():
return Chainable()
return Node()
for node in self.tree.body:
if isinstance(node, ast.ClassDef):
if node.name == class_name:
return Chainable(node)
return Chainable()
return Node(node)
return Node()

# Find an array of conditions in an if statement

Expand All @@ -130,7 +132,7 @@ def find_ifs(self):

def _find_all(self, ast_type):
return [
Chainable(node) for node in self.tree.body if isinstance(node, ast_type)
Node(node) for node in self.tree.body if isinstance(node, ast_type)
]

def find_conditions(self):
Expand All @@ -140,22 +142,26 @@ def _find_conditions(tree):
test = tree.test
if self.tree.orelse == []:
return [test]
elif isinstance(tree.orelse[0], ast.If):
if isinstance(tree.orelse[0], ast.If):
return [test] + _find_conditions(tree.orelse[0])
else:
return [test, None]

return [Chainable(test) for test in _find_conditions(self.tree)]
return [test, None]

return [Node(test) for test in _find_conditions(self.tree)]

# Find an array of bodies in an elif statement

def find_if_bodies(self):
def _find_if_bodies(tree):
if not isinstance(tree, ast.If):
return []
if self.tree.orelse == []:
return [tree.body]
elif isinstance(tree.orelse[0], ast.If):
if isinstance(tree.orelse[0], ast.If):
return [tree.body] + _find_if_bodies(tree.orelse[0])
else:
return [tree.body] + [tree.orelse]

return [Chainable(body) for body in _find_if_bodies(self.tree)]
return [tree.body] + [tree.orelse]

return [
Node(ast.Module(body, [])) for body in _find_if_bodies(self.tree)
]
Loading

0 comments on commit 4e91a0e

Please sign in to comment.