Skip to content

Commit

Permalink
Add logic to extract a raw_sessions where clause from a query on the …
Browse files Browse the repository at this point in the history
…sessions table
  • Loading branch information
robbie-c committed Mar 19, 2024
1 parent 22ba3dc commit 53c02e8
Show file tree
Hide file tree
Showing 9 changed files with 611 additions and 5 deletions.
Empty file.
2 changes: 1 addition & 1 deletion posthog/hogql/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class LazyTable(Table):

model_config = ConfigDict(extra="forbid")

def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers) -> Any:
def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers, node) -> Any:
raise NotImplementedException("LazyTable.lazy_select not overridden")


Expand Down
Empty file.
8 changes: 5 additions & 3 deletions posthog/hogql/database/schema/sessions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, List, cast

from posthog.hogql import ast
from posthog.hogql.database.models import (
StringDatabaseField,
DateTimeDatabaseField,
Expand Down Expand Up @@ -62,7 +63,7 @@ def avoid_asterisk_fields(self) -> List[str]:
]


def select_from_sessions_table(requested_fields: Dict[str, List[str | int]]):
def select_from_sessions_table(requested_fields: Dict[str, List[str | int]], node: ast.SelectQuery):
from posthog.hogql import ast

table_name = "raw_sessions"
Expand Down Expand Up @@ -148,8 +149,9 @@ class SessionsTable(LazyTable):
"channel_type": StringDatabaseField(name="channel_type"),
}

def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers):
return select_from_sessions_table(requested_fields)

def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers, node: ast.SelectQuery):
return select_from_sessions_table(requested_fields, node)

def to_printed_clickhouse(self, context):
return "sessions"
Expand Down
Empty file.
317 changes: 317 additions & 0 deletions posthog/hogql/database/schema/util/session_where_clause_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
from typing import Union

from posthog.hogql import ast
from posthog.hogql.ast import CompareOperationOp
from posthog.hogql.database.schema.util.where_clause_visitor import PassThroughHogQLASTVisitor, HogQLASTVisitor

SESSION_BUFFER_DAYS = 3


class AbortOptimisationException(Exception):
pass


class SessionWhereClauseExtractor(PassThroughHogQLASTVisitor):
def get_inner_where(self, parsed_query: ast.SelectQuery) -> Union[ast.Expr, None]:
if not parsed_query.where:
return None

# visit the where clause
where = self.visit(parsed_query.where)

if isinstance(where, ast.Constant):
return None

return where

def visit(self, node: ast.Expr) -> ast.Expr:
try:
return super().visit(node)
except AbortOptimisationException:
return ast.Constant(value=True)

def visit_compare_operation(self, node: ast.CompareOperation) -> ast.Expr:
is_left_constant = is_time_or_interval_constant(node.left)
is_right_constant = is_time_or_interval_constant(node.right)
is_left_timestamp_field = is_simple_timestamp_field_expression(node.left)
is_right_timestamp_field = is_simple_timestamp_field_expression(node.right)

if is_left_constant and is_right_constant:
# just ignore this comparison
return ast.Constant(value=True)

# handle the left side being a min_timestamp expression and the right being constant
if is_left_timestamp_field and is_right_constant:
if node.op == CompareOperationOp.Eq:
return ast.And(
exprs=[
ast.CompareOperation(
op=ast.CompareOperationOp.LtEq,
left=ast.ArithmeticOperation(
op=ast.ArithmeticOperationOp.Sub,
left=rewrite_timestamp_field(node.left),
right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]),
),
right=node.right,
),
ast.CompareOperation(
op=ast.CompareOperationOp.GtEq,
left=ast.ArithmeticOperation(
op=ast.ArithmeticOperationOp.Add,
left=rewrite_timestamp_field(node.left),
right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]),
),
right=node.right,
),
]
)
elif node.op == CompareOperationOp.Gt or node.op == CompareOperationOp.GtEq:
return ast.CompareOperation(
op=ast.CompareOperationOp.GtEq,
left=ast.ArithmeticOperation(
op=ast.ArithmeticOperationOp.Add,
left=rewrite_timestamp_field(node.left),
right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]),
),
right=node.right,
)
elif node.op == CompareOperationOp.Lt or node.op == CompareOperationOp.LtEq:
return ast.CompareOperation(
op=ast.CompareOperationOp.LtEq,
left=ast.ArithmeticOperation(
op=ast.ArithmeticOperationOp.Sub,
left=rewrite_timestamp_field(node.left),
right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]),
),
right=node.right,
)
elif is_right_timestamp_field and is_left_constant:
# let's not duplicate the logic above, instead just flip and it and recurse
if node.op in [
CompareOperationOp.Eq,
CompareOperationOp.Lt,
CompareOperationOp.LtEq,
CompareOperationOp.Gt,
CompareOperationOp.GtEq,
]:
return self.visit(
ast.CompareOperation(
op=CompareOperationOp.Eq
if node.op == CompareOperationOp.Eq
else CompareOperationOp.Lt
if node.op == CompareOperationOp.Gt
else CompareOperationOp.LtEq
if node.op == CompareOperationOp.GtEq
else CompareOperationOp.Gt
if node.op == CompareOperationOp.Lt
else CompareOperationOp.GtEq,
left=node.right,
right=node.left,
)
)

return ast.Constant(value=True)

def visit_arithmetric_operation(self, node: ast.ArithmeticOperation) -> ast.Expr:
# don't even try to handle complex logic
return ast.Constant(value=True)

def visit_not(self, node: ast.Not) -> ast.Expr:
return ast.Constant(value=True)

def visit_call(self, node: ast.Call) -> ast.Expr:
if node.name.lower() == "and":
return self.visit_and(ast.And(exprs=node.args))
elif node.name.lower() == "or":
return self.visit_or(ast.Or(exprs=node.args))
return ast.Constant(value=True)

def visit_field(self, node: ast.Field) -> ast.Expr:
return ast.Constant(value=True)

def visit_constant(self, node: ast.Constant) -> ast.Expr:
return ast.Constant(value=True)

def visit_placeholder(self, node: ast.Placeholder) -> ast.Expr:
raise Exception() # this should never happen, as placeholders should be resolved before this runs

def visit_and(self, node: ast.And) -> ast.Expr:
exprs = [self.visit(expr) for expr in node.exprs]

flattened = []
for expr in exprs:
if isinstance(expr, ast.And):
flattened.extend(expr.exprs)
else:
flattened.append(expr)

if any(isinstance(expr, ast.Constant) and expr.value is False for expr in flattened):
return ast.Constant(value=False)

filtered = [expr for expr in flattened if not isinstance(expr, ast.Constant) or expr.value is not True]
if len(filtered) == 0:
return ast.Constant(value=True)
elif len(filtered) == 1:
return filtered[0]
else:
return ast.And(exprs=filtered)

def visit_or(self, node: ast.Or) -> ast.Expr:
exprs = [self.visit(expr) for expr in node.exprs]

flattened = []
for expr in exprs:
if isinstance(expr, ast.Or):
flattened.extend(expr.exprs)
else:
flattened.append(expr)

if any(isinstance(expr, ast.Constant) and expr.value is True for expr in flattened):
return ast.Constant(value=True)

filtered = [expr for expr in flattened if not isinstance(expr, ast.Constant) or expr.value is not False]
if len(filtered) == 0:
return ast.Constant(value=False)
elif len(filtered) == 1:
return filtered[0]
else:
return ast.Or(exprs=filtered)


def is_time_or_interval_constant(expr: ast.Expr) -> bool:
return IsTimeOrIntervalConstantVisitor().visit(expr)


class IsTimeOrIntervalConstantVisitor(HogQLASTVisitor[bool]):
def visit_constant(self, node: ast.Constant) -> bool:
return True

def visit_compare_operation(self, node: ast.CompareOperation) -> bool:
return self.visit(node.left) and self.visit(node.right)

def visit_arithmetric_operation(self, node: ast.ArithmeticOperation) -> bool:
return self.visit(node.left) and self.visit(node.right)

def visit_call(self, node: ast.Call) -> bool:
# some functions just return a constant
if node.name in ["today", "now"]:
return True
# some functions return a constant if the first argument is a constant
if node.name in [
"parseDateTime64BestEffortOrNull",
"toDateTime",
"toTimeZone",
"assumeNotNull",
"toIntervalYear",
"toIntervalMonth",
"toIntervalWeek",
"toIntervalDay",
"toIntervalHour",
"toIntervalMinute",
"toIntervalSecond",
"toStartOfDay",
"toStartOfWeek",
"toStartOfMonth",
"toStartOfQuarter",
"toStartOfYear",
]:
return self.visit(node.args[0])

# otherwise we don't know, so return False
return False

def visit_field(self, node: ast.Field) -> bool:
return False

def visit_and(self, node: ast.And) -> bool:
return False

def visit_or(self, node: ast.Or) -> bool:
return False

def visit_not(self, node: ast.Not) -> bool:
return False

def visit_placeholder(self, node: ast.Placeholder) -> bool:
raise Exception()


def is_simple_timestamp_field_expression(expr: ast.Expr) -> bool:
return IsSimpleTimestampFieldExpressionVisitor().visit(expr)


class IsSimpleTimestampFieldExpressionVisitor(HogQLASTVisitor[bool]):
def visit_constant(self, node: ast.Constant) -> bool:
return False

def visit_field(self, node: ast.Field) -> bool:
# this is quite leaky, as it doesn't handle aliases, but will handle all of posthog's hogql queries
return (
node.chain == ["min_timestamp"]
or node.chain == ["sessions", "min_timestamp"]
or node.chain == ["s", "min_timestamp"]
or node.chain == ["timestamp"]
or node.chain == ["events", "timestamp"]
or node.chain == ["e", "timestamp"]
)

def visit_arithmetric_operation(self, node: ast.ArithmeticOperation) -> bool:
# only allow the min_timestamp field to be used on one side of the arithmetic operation
return (
self.visit(node.left)
and is_time_or_interval_constant(node.right)
or (self.visit(node.right) and is_time_or_interval_constant(node.left))
)

def visit_call(self, node: ast.Call) -> bool:
# some functions count as a timestamp field expression if their first argument is
if node.name in [
"parseDateTime64BestEffortOrNull",
"toDateTime",
"toTimeZone",
"assumeNotNull",
"toStartOfDay",
"toStartOfWeek",
"toStartOfMonth",
"toStartOfQuarter",
"toStartOfYear",
]:
return self.visit(node.args[0])

# otherwise we don't know, so return False
return False

def visit_compare_operation(self, node: ast.CompareOperation) -> bool:
return False

def visit_and(self, node: ast.And) -> bool:
return False

def visit_or(self, node: ast.Or) -> bool:
return False

def visit_not(self, node: ast.Not) -> bool:
return False

def visit_placeholder(self, node: ast.Placeholder) -> bool:
raise Exception()


def rewrite_timestamp_field(expr: ast.Expr) -> ast.Expr:
return RewriteTimestampFieldVisitor().visit(expr)


class RewriteTimestampFieldVisitor(PassThroughHogQLASTVisitor):
def visit_field(self, node: ast.Field) -> ast.Field:
# this is quite leaky, as it doesn't handle aliases, but will handle all of posthog's hogql queries
if (
node.chain == ["min_timestamp"]
or node.chain == ["sessions", "min_timestamp"]
or node.chain == ["s", "min_timestamp"]
or node.chain == ["timestamp"]
or node.chain == ["events", "timestamp"]
or node.chain == ["e", "timestamp"]
):
return ast.Field(chain=["sessions", "min_timestamp"])
else:
return node
Loading

0 comments on commit 53c02e8

Please sign in to comment.