diff --git a/python_ta/checkers/top_level_code_checker.py b/python_ta/checkers/top_level_code_checker.py index bdd9ea5bf..4b86beb2c 100644 --- a/python_ta/checkers/top_level_code_checker.py +++ b/python_ta/checkers/top_level_code_checker.py @@ -29,7 +29,7 @@ def visit_module(self, node): or _is_constant_assignment(statement) or _is_main_block(statement) ): - self.add_message("forbidden-top-level-code", node=node, args=statement.lineno) + self.add_message("forbidden-top-level-code", node=statement, args=statement.lineno) # Helper functions @@ -51,9 +51,14 @@ def _is_constant_assignment(statement) -> bool: """ Return whether or not is a constant assignment. """ - return isinstance(statement, nodes.Assign) and re.match( - UpperCaseStyle.CONST_NAME_RGX, statement.targets[0].name - ) + if not isinstance(statement, nodes.Assign): + return False + + names = [] + for target in statement.targets: + names.extend(node.name for node in target.nodes_of_class(nodes.AssignName, nodes.Name)) + + return all(re.match(UpperCaseStyle.CONST_NAME_RGX, name) for name in names) def _is_main_block(statement) -> bool: @@ -62,7 +67,13 @@ def _is_main_block(statement) -> bool: """ return ( isinstance(statement, nodes.If) + and isinstance(statement.test, nodes.Compare) + and isinstance(statement.test.left, nodes.Name) + and isinstance(statement.test.left, nodes.Name) and statement.test.left.name == "__name__" + and len(statement.test.ops) == 1 + and statement.test.ops[0][0] == "==" + and isinstance(statement.test.ops[0][1], nodes.Const) and statement.test.ops[0][1].value == "__main__" ) diff --git a/tests/test_custom_checkers/test_top_level_code_checker.py b/tests/test_custom_checkers/test_top_level_code_checker.py index f6051a696..5fc2e1528 100644 --- a/tests/test_custom_checkers/test_top_level_code_checker.py +++ b/tests/test_custom_checkers/test_top_level_code_checker.py @@ -8,9 +8,6 @@ class TestTopLevelCodeChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = TopLevelCodeChecker CONFIG = {} - def setup(self): - self.setup_method() - def test_message_simple(self): """Top level code not allowed, raises a message.""" src = """ @@ -18,7 +15,9 @@ def test_message_simple(self): """ mod = astroid.parse(src) with self.assertAddsMessages( - pylint.testutils.MessageTest(msg_id="forbidden-top-level-code", node=mod, args=2), + pylint.testutils.MessageTest( + msg_id="forbidden-top-level-code", node=mod.body[0], args=2 + ), ignore_position=True, ): self.checker.visit_module(mod) @@ -32,7 +31,9 @@ def test_message_complex(self): """ mod = astroid.parse(src) with self.assertAddsMessages( - pylint.testutils.MessageTest(msg_id="forbidden-top-level-code", node=mod, args=4), + pylint.testutils.MessageTest( + msg_id="forbidden-top-level-code", node=mod.body[1], args=4 + ), ignore_position=True, ): self.checker.visit_module(mod) @@ -92,7 +93,37 @@ def test_message_regular_assignment(self): """ mod = astroid.parse(src) with self.assertAddsMessages( - pylint.testutils.MessageTest(msg_id="forbidden-top-level-code", node=mod, args=2), + pylint.testutils.MessageTest( + msg_id="forbidden-top-level-code", node=mod.body[0], args=2 + ), + ignore_position=True, + ): + self.checker.visit_module(mod) + + def test_message_regular_assignment_unpacking(self): + """Top level regular unpacking assignment not allowed, raises a message.""" + src = """ + name, CONST = "George", 3 + """ + mod = astroid.parse(src) + with self.assertAddsMessages( + pylint.testutils.MessageTest( + msg_id="forbidden-top-level-code", node=mod.body[0], args=2 + ), + ignore_position=True, + ): + self.checker.visit_module(mod) + + def test_message_regular_assignment_starred(self): + """Top level regular assignment with a starred target not allowed, raises a message.""" + src = """ + NAME, *nums = ["George", 3, 4] + """ + mod = astroid.parse(src) + with self.assertAddsMessages( + pylint.testutils.MessageTest( + msg_id="forbidden-top-level-code", node=mod.body[0], args=2 + ), ignore_position=True, ): self.checker.visit_module(mod)