Skip to content

Commit

Permalink
feat: add find_variables ast-helper (#212)
Browse files Browse the repository at this point in the history
* feat: find_variables

* fix has_returns

* add test has_returns
  • Loading branch information
Dario-DC authored Jun 20, 2024
1 parent fb25066 commit 8b3181f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
21 changes: 21 additions & 0 deletions python/py_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def has_returns(self, returns_str):
return returns_str == self.tree.returns.id
elif isinstance(self.tree.returns, ast.Constant):
return returns_str == self.tree.returns.value
elif isinstance((ann := self.tree.returns), ast.Subscript):
return Node(ann).is_equivalent(returns_str)
return False

def find_body(self):
Expand Down Expand Up @@ -251,6 +253,25 @@ def find_variable(self, name):
return Node(node)
return Node()

def find_variables(self, name):
assignments = self._find_all((ast.Assign, ast.AnnAssign))
var_list = []
for node in assignments:
if isinstance(node.tree, ast.Assign):
for target in node.tree.targets:
if isinstance(target, ast.Name):
if target.id == name:
var_list.append(node)
if isinstance(target, ast.Attribute):
names = name.split(".")
if target.value.id == names[0] and target.attr == names[1]:
var_list.append(node)
elif isinstance(node.tree, ast.AnnAssign):
if isinstance(node.tree.target, ast.Name):
if node.tree.target.id == name:
var_list.append(node)
return var_list

# find variable incremented or decremented using += or -=
def find_aug_variable(self, name):
if not self._has_body():
Expand Down
23 changes: 22 additions & 1 deletion python/py_helpers.test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,23 @@ def foo():
)
self.assertEqual(node.find_function("foo").find_aug_variable("x"), Node())

def test_find_variables(self):
code_str = """
x: int = 0
a.b = 0
x = 5
a.b = 2
x = 10
"""
node = Node(code_str)
self.assertEqual(len(node.find_variables("x")), 3)
self.assertTrue(node.find_variables("x")[0].is_equivalent("x: int = 0"))
self.assertTrue(node.find_variables("x")[1].is_equivalent("x = 5"))
self.assertTrue(node.find_variables("x")[2].is_equivalent("x = 10"))
self.assertEqual(len(node.find_variables("a.b")), 2)
self.assertTrue(node.find_variables("a.b")[0].is_equivalent("a.b = 0"))
self.assertTrue(node.find_variables("a.b")[1].is_equivalent("a.b = 2"))


class TestFunctionAndClassHelpers(unittest.TestCase):
def test_find_function_returns_node(self):
Expand Down Expand Up @@ -295,12 +312,16 @@ def foo(a: int, b: int) -> int:
def test_has_returns(self):
code_str = """
def foo() -> int:
pass
pass
def spam() -> Dict[str, int]:
pass
"""
node = Node(code_str)

self.assertTrue(node.find_function("foo").has_returns("int"))
self.assertFalse(node.find_function("foo").has_returns("str"))
self.assertTrue(node.find_function("spam").has_returns("Dict[str, int]"))

def test_has_returns_without_returns(self):
code_str = """
Expand Down

0 comments on commit 8b3181f

Please sign in to comment.