From fbf53559b4115e21bd0b56b734a396c5b87cf874 Mon Sep 17 00:00:00 2001 From: Marius Andra Date: Thu, 7 Dec 2023 13:51:02 +0100 Subject: [PATCH] fix(hogql): uuid type visitor (#19158) --- posthog/hogql/base.py | 12 ++++++++---- posthog/hogql/test/test_visitor.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/posthog/hogql/base.py b/posthog/hogql/base.py index 217ec2d3aeaec..234bce9c449af 100644 --- a/posthog/hogql/base.py +++ b/posthog/hogql/base.py @@ -16,11 +16,15 @@ class AST: start: Optional[int] = field(default=None) end: Optional[int] = field(default=None) + # This is part of the visitor pattern from visitor.py. def accept(self, visitor): - camel_case_name = camel_case_pattern.sub("_", self.__class__.__name__).lower() - if "hog_qlx" in camel_case_name: - camel_case_name = camel_case_name.replace("hog_qlx", "hogqlx_") - method_name = f"visit_{camel_case_name}" + name = camel_case_pattern.sub("_", self.__class__.__name__).lower() + + # NOTE: Sync with ./test/test_visitor.py#test_hogql_visitor_naming_exceptions + replacements = {"hog_qlxtag": "hogqlx_tag", "hog_qlxattribute": "hogqlx_attribute", "uuidtype": "uuid_type"} + for old, new in replacements.items(): + name = name.replace(old, new) + method_name = f"visit_{name}" if hasattr(visitor, method_name): visit = getattr(visitor, method_name) return visit(self) diff --git a/posthog/hogql/test/test_visitor.py b/posthog/hogql/test/test_visitor.py index 78b2d6dc42536..8aa6689328fbf 100644 --- a/posthog/hogql/test/test_visitor.py +++ b/posthog/hogql/test/test_visitor.py @@ -1,4 +1,5 @@ from posthog.hogql import ast +from posthog.hogql.ast import UUIDType, HogQLXTag, HogQLXAttribute from posthog.hogql.errors import HogQLException from posthog.hogql.parser import parse_expr from posthog.hogql.visitor import CloningVisitor, Visitor, TraversingVisitor @@ -137,3 +138,18 @@ def visit_constant(self, node: ast.Constant): self.assertEqual(str(e.exception), "You tried accessing a forbidden number, perish!") self.assertEqual(e.exception.start, 4) self.assertEqual(e.exception.end, 7) + + def test_hogql_visitor_naming_exceptions(self): + class NamingCheck(Visitor): + def visit_uuid_type(self, node: ast.Constant): + return "visit_uuid_type" + + def visit_hogqlx_tag(self, node: ast.Constant): + return "visit_hogqlx_tag" + + def visit_hogqlx_attribute(self, node: ast.Constant): + return "visit_hogqlx_attribute" + + assert NamingCheck().visit(UUIDType()) == "visit_uuid_type" + assert NamingCheck().visit(HogQLXAttribute(name="a", value="a")) == "visit_hogqlx_attribute" + assert NamingCheck().visit(HogQLXTag(kind="", attributes=[])) == "visit_hogqlx_tag"