Skip to content

Commit

Permalink
fix(hogql): uuid type visitor (#19158)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra authored Dec 7, 2023
1 parent ecb111e commit fbf5355
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
12 changes: 8 additions & 4 deletions posthog/hogql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ class AST:
start: Optional[int] = field(default=None)
end: Optional[int] = field(default=None)

# This is part of the visitor pattern from visitor.py.
def accept(self, visitor):
camel_case_name = camel_case_pattern.sub("_", self.__class__.__name__).lower()
if "hog_qlx" in camel_case_name:
camel_case_name = camel_case_name.replace("hog_qlx", "hogqlx_")
method_name = f"visit_{camel_case_name}"
name = camel_case_pattern.sub("_", self.__class__.__name__).lower()

# NOTE: Sync with ./test/test_visitor.py#test_hogql_visitor_naming_exceptions
replacements = {"hog_qlxtag": "hogqlx_tag", "hog_qlxattribute": "hogqlx_attribute", "uuidtype": "uuid_type"}
for old, new in replacements.items():
name = name.replace(old, new)
method_name = f"visit_{name}"
if hasattr(visitor, method_name):
visit = getattr(visitor, method_name)
return visit(self)
Expand Down
16 changes: 16 additions & 0 deletions posthog/hogql/test/test_visitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from posthog.hogql import ast
from posthog.hogql.ast import UUIDType, HogQLXTag, HogQLXAttribute
from posthog.hogql.errors import HogQLException
from posthog.hogql.parser import parse_expr
from posthog.hogql.visitor import CloningVisitor, Visitor, TraversingVisitor
Expand Down Expand Up @@ -137,3 +138,18 @@ def visit_constant(self, node: ast.Constant):
self.assertEqual(str(e.exception), "You tried accessing a forbidden number, perish!")
self.assertEqual(e.exception.start, 4)
self.assertEqual(e.exception.end, 7)

def test_hogql_visitor_naming_exceptions(self):
class NamingCheck(Visitor):
def visit_uuid_type(self, node: ast.Constant):
return "visit_uuid_type"

def visit_hogqlx_tag(self, node: ast.Constant):
return "visit_hogqlx_tag"

def visit_hogqlx_attribute(self, node: ast.Constant):
return "visit_hogqlx_attribute"

assert NamingCheck().visit(UUIDType()) == "visit_uuid_type"
assert NamingCheck().visit(HogQLXAttribute(name="a", value="a")) == "visit_hogqlx_attribute"
assert NamingCheck().visit(HogQLXTag(kind="", attributes=[])) == "visit_hogqlx_tag"

0 comments on commit fbf5355

Please sign in to comment.