diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 75f37a06e..230d6a98e 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -846,6 +846,7 @@ def __init__(self, provider: "ScopeProvider") -> None: self.scope: Scope = GlobalScope() self.__deferred_accesses: List[DeferredAccess] = [] self.__top_level_attribute_stack: List[Optional[cst.Attribute]] = [None] + self.__in___all___stack: List[bool] = [False] self.__in_annotation_stack: List[bool] = [False] self.__in_type_hint_stack: List[bool] = [False] self.__in_ignored_subscript: Set[cst.Subscript] = set() @@ -950,6 +951,23 @@ def _handle_string_annotation( self, node: Union[cst.SimpleString, cst.ConcatenatedString] ) -> bool: """Returns whether it successfully handled the string annotation""" + if self.__in___all___stack[-1]: + name = node.evaluated_value + + if isinstance(name, bytes): + name = name.decode("utf-8") + + access = Access( + node, + self.scope, + is_annotation=False, + is_type_hint=False, + ) + + list(self.scope.assignments[name])[0].record_access(access) + + return True + if ( self.__in_type_hint_stack[-1] or self.__in_annotation_stack[-1] ) and not self.__in_ignored_subscript: @@ -1091,6 +1109,22 @@ def visit_Nonlocal(self, node: cst.Nonlocal) -> Optional[bool]: self.scope.record_nonlocal_overwrite(name_item.name.value) return False + def is__all___assignment(self, node: cst.Assign) -> bool: + target = next( + (t.target for t in node.targets if isinstance(t.target, cst.Name)), None + ) + if target is None: + return False + return target.value == "__all__" + + def visit_Assign(self, node: cst.Assign) -> Optional[bool]: + if self.is__all___assignment(node): + self.__in___all___stack.append(True) + + def leave_Assign(self, node: cst.Assign) -> None: + if self.is__all___assignment(node): + self.__in___all___stack.pop() + def visit_ListComp(self, node: cst.ListComp) -> Optional[bool]: return self._visit_comp_alike(node) diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index a2087645c..0ee13a3bd 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -2171,7 +2171,7 @@ def test_annotation_refers_to_nested_class(self) -> None: class Outer: class Nested: pass - + type Alias = Nested def meth1[T: Nested](self): pass @@ -2248,3 +2248,18 @@ def f[T: Inner](self): Inner f_scope = scopes[inner_in_func_body] self.assertIn(inner_in_func_body.value, f_scope.accesses) self.assertEqual(list(f_scope.accesses)[0].referents, set()) + + def test___all___assignment(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + import a + import b + + __all__ = ["a", "b"] + """ + ) + import_a_assignment = list(scopes[m]["a"])[0] + import_b_assignment = list(scopes[m]["b"])[0] + + self.assertEqual(len(import_a_assignment.references), 1) + self.assertEqual(len(import_b_assignment.references), 1)