Skip to content

Commit

Permalink
Use existing Visitor class hierarchy rather than creating a new one
Browse files Browse the repository at this point in the history
  • Loading branch information
robbie-c committed Mar 20, 2024
1 parent 6fe657b commit 72c1b8a
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 116 deletions.
2 changes: 1 addition & 1 deletion posthog/hogql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def accept(self, visitor):
return visit(self)
if hasattr(visitor, "visit_unknown"):
return visitor.visit_unknown(self)
raise NotImplementedException(f"Visitor has no method {method_name}")
raise NotImplementedException(f"{visitor.__class__.__name__} has no method {method_name}")


@dataclass(kw_only=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from posthog.hogql import ast
from posthog.hogql.ast import CompareOperationOp, ArithmeticOperationOp
from posthog.hogql.database.schema.util.where_clause_visitor import PassThroughHogQLASTVisitor, HogQLASTVisitor
from posthog.hogql.visitor import clone_expr
from posthog.hogql.visitor import clone_expr, CloningVisitor, Visitor

SESSION_BUFFER_DAYS = 3


class SessionWhereClauseExtractor(PassThroughHogQLASTVisitor):
class SessionWhereClauseExtractor(CloningVisitor):
def get_inner_where(self, parsed_query: ast.SelectQuery) -> Optional[ast.Expr]:
if not parsed_query.where:
return None
Expand Down Expand Up @@ -103,7 +102,7 @@ def visit_compare_operation(self, node: ast.CompareOperation) -> ast.Expr:

return ast.Constant(value=True)

def visit_arithmetric_operation(self, node: ast.ArithmeticOperation) -> ast.Expr:
def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> ast.Expr:
# don't even try to handle complex logic
return ast.Constant(value=True)

Expand Down Expand Up @@ -176,14 +175,14 @@ def is_time_or_interval_constant(expr: ast.Expr) -> bool:
return IsTimeOrIntervalConstantVisitor().visit(expr)


class IsTimeOrIntervalConstantVisitor(HogQLASTVisitor[bool]):
class IsTimeOrIntervalConstantVisitor(Visitor):
def visit_constant(self, node: ast.Constant) -> bool:
return True

def visit_compare_operation(self, node: ast.CompareOperation) -> bool:
return self.visit(node.left) and self.visit(node.right)

def visit_arithmetric_operation(self, node: ast.ArithmeticOperation) -> bool:
def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> bool:
return self.visit(node.left) and self.visit(node.right)

def visit_call(self, node: ast.Call) -> bool:
Expand Down Expand Up @@ -240,7 +239,7 @@ def is_simple_timestamp_field_expression(expr: ast.Expr) -> bool:
return IsSimpleTimestampFieldExpressionVisitor().visit(expr)


class IsSimpleTimestampFieldExpressionVisitor(HogQLASTVisitor[bool]):
class IsSimpleTimestampFieldExpressionVisitor(Visitor):
def visit_constant(self, node: ast.Constant) -> bool:
return False

Expand All @@ -255,7 +254,7 @@ def visit_field(self, node: ast.Field) -> bool:
or node.chain == ["e", "timestamp"]
)

def visit_arithmetric_operation(self, node: ast.ArithmeticOperation) -> bool:
def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> bool:
# only allow the min_timestamp field to be used on one side of the arithmetic operation
return (
self.visit(node.left)
Expand All @@ -279,7 +278,7 @@ def visit_call(self, node: ast.Call) -> bool:
return self.visit(node.args[0])

if node.name in ["minus", "add"]:
return self.visit_arithmetric_operation(
return self.visit_arithmetic_operation(
ast.ArithmeticOperation(
op=ArithmeticOperationOp.Sub if node.name == "minus" else ArithmeticOperationOp.Add,
left=node.args[0],
Expand Down Expand Up @@ -313,7 +312,7 @@ def rewrite_timestamp_field(expr: ast.Expr) -> ast.Expr:
return RewriteTimestampFieldVisitor().visit(expr)


class RewriteTimestampFieldVisitor(PassThroughHogQLASTVisitor):
class RewriteTimestampFieldVisitor(CloningVisitor):
def visit_field(self, node: ast.Field) -> ast.Field:
# this is quite leaky, as it doesn't handle aliases, but will handle all of posthog's hogql queries
if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from posthog.hogql import ast
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.schema.util.session_where_clause_extractor import SessionWhereClauseExtractor
from posthog.hogql.database.schema.util.where_clause_visitor import PassThroughHogQLASTVisitor
from posthog.hogql.modifiers import create_default_modifiers_for_team
from posthog.hogql.parser import parse_select, parse_expr
from posthog.hogql.printer import prepare_ast_for_printing
from posthog.hogql.visitor import clone_expr
from posthog.test.base import ClickhouseTestMixin, APIBaseTest


Expand All @@ -17,7 +17,7 @@ def f(s: Union[str, ast.Expr], placeholders: Optional[dict[str, ast.Expr]] = Non
expr = parse_expr(s, placeholders=placeholders)
else:
expr = s
return PassThroughHogQLASTVisitor().visit(expr)
return clone_expr(expr, clear_types=True, clear_locations=True)


class TestSessionTimestampInliner:
Expand Down
103 changes: 0 additions & 103 deletions posthog/hogql/database/schema/util/where_clause_visitor.py

This file was deleted.

0 comments on commit 72c1b8a

Please sign in to comment.