From f43d7a9c7a0089e84ce6670c763358b8723172f1 Mon Sep 17 00:00:00 2001 From: Robbie Date: Fri, 22 Mar 2024 08:15:50 +0000 Subject: [PATCH] feat(web-analytics): HogQL sessions where extractor (#20986) * Add logic to extract a raw_sessions where clause from a query on the sessions table * Add failing test - need to figure out types * Replace sessions with raw_sessions * Use existing Visitor class hierarchy rather than creating a new one * Clear types * Add working test * Formatting * Fix typing * Better fix typing * Fix tests * Add a big comment * Increase comment * Add types to Visitor * Handles aliases better * Add test for joining events and sessions * Fix typing * More robust way of looking up tables * Add a test to make sure that collapsing ands works --- posthog/hogql/ast.py | 12 + posthog/hogql/base.py | 2 +- posthog/hogql/database/__init__.py | 0 posthog/hogql/database/models.py | 7 +- posthog/hogql/database/schema/__init__.py | 0 .../hogql/database/schema/cohort_people.py | 3 +- posthog/hogql/database/schema/groups.py | 3 +- posthog/hogql/database/schema/log_entries.py | 5 +- .../database/schema/person_distinct_ids.py | 3 +- posthog/hogql/database/schema/persons.py | 4 +- posthog/hogql/database/schema/persons_pdi.py | 3 +- .../database/schema/session_replay_events.py | 3 +- posthog/hogql/database/schema/sessions.py | 15 +- .../hogql/database/schema/util/__init__.py | 0 .../util/session_where_clause_extractor.py | 398 ++++++++++++++++++ .../test_session_where_clause_extractor.py | 284 +++++++++++++ posthog/hogql/test/test_bytecode.py | 2 +- posthog/hogql/test/test_visitor.py | 2 +- posthog/hogql/transforms/lazy_tables.py | 2 +- posthog/hogql/visitor.py | 13 +- 20 files changed, 730 insertions(+), 31 deletions(-) create mode 100644 posthog/hogql/database/__init__.py create mode 100644 posthog/hogql/database/schema/__init__.py create mode 100644 posthog/hogql/database/schema/util/__init__.py create mode 100644 posthog/hogql/database/schema/util/session_where_clause_extractor.py create mode 100644 posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py diff --git a/posthog/hogql/ast.py b/posthog/hogql/ast.py index a459514f2524f..806226b8f1b9e 100644 --- a/posthog/hogql/ast.py +++ b/posthog/hogql/ast.py @@ -46,8 +46,17 @@ def resolve_constant_type(self, context: HogQLContext): def resolve_database_field(self, context: HogQLContext): if isinstance(self.type, FieldType): return self.type.resolve_database_field(context) + if isinstance(self.type, PropertyType): + return self.type.field_type.resolve_database_field(context) raise NotImplementedException("FieldAliasType.resolve_database_field not implemented") + def resolve_table_type(self, context: HogQLContext): + if isinstance(self.type, FieldType): + return self.type.table_type + if isinstance(self.type, PropertyType): + return self.type.field_type.table_type + raise NotImplementedException("FieldAliasType.resolve_table_type not implemented") + @dataclass(kw_only=True) class BaseTableType(Type): @@ -339,6 +348,9 @@ def get_child(self, name: str | int, context: HogQLContext) -> Type: f'Can not access property "{name}" on field "{self.name}" of type: {type(database_field).__name__}' ) + def resolve_table_type(self, context: HogQLContext): + return self.table_type + @dataclass(kw_only=True) class PropertyType(Type): diff --git a/posthog/hogql/base.py b/posthog/hogql/base.py index fbdafffb2d08c..e8a74025b78be 100644 --- a/posthog/hogql/base.py +++ b/posthog/hogql/base.py @@ -32,7 +32,7 @@ def accept(self, visitor): return visit(self) if hasattr(visitor, "visit_unknown"): return visitor.visit_unknown(self) - raise NotImplementedException(f"Visitor has no method {method_name}") + raise NotImplementedException(f"{visitor.__class__.__name__} has no method {method_name}") @dataclass(kw_only=True) diff --git a/posthog/hogql/database/__init__.py b/posthog/hogql/database/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/posthog/hogql/database/models.py b/posthog/hogql/database/models.py index 95a00595c6472..d2da7868a7f9c 100644 --- a/posthog/hogql/database/models.py +++ b/posthog/hogql/database/models.py @@ -3,7 +3,6 @@ from posthog.hogql.base import Expr from posthog.hogql.errors import HogQLException, NotImplementedException -from posthog.schema import HogQLQueryModifiers if TYPE_CHECKING: from posthog.hogql.context import HogQLContext @@ -126,12 +125,14 @@ def resolve_table(self, context: "HogQLContext") -> Table: class LazyTable(Table): """ - A table that is replaced with a subquery returned from `lazy_select(requested_fields: Dict[name, chain], modifiers: HogQLQueryModifiers)` + A table that is replaced with a subquery returned from `lazy_select(requested_fields: Dict[name, chain], modifiers: HogQLQueryModifiers, node: SelectQuery)` """ model_config = ConfigDict(extra="forbid") - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers) -> Any: + def lazy_select( + self, requested_fields: Dict[str, List[str | int]], context: "HogQLContext", node: "SelectQuery" + ) -> Any: raise NotImplementedException("LazyTable.lazy_select not overridden") diff --git a/posthog/hogql/database/schema/__init__.py b/posthog/hogql/database/schema/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/posthog/hogql/database/schema/cohort_people.py b/posthog/hogql/database/schema/cohort_people.py index 72080419b7355..11723f0194619 100644 --- a/posthog/hogql/database/schema/cohort_people.py +++ b/posthog/hogql/database/schema/cohort_people.py @@ -9,7 +9,6 @@ FieldOrTable, ) from posthog.hogql.database.schema.persons import join_with_persons_table -from posthog.schema import HogQLQueryModifiers COHORT_PEOPLE_FIELDS = { "person_id": StringDatabaseField(name="person_id"), @@ -67,7 +66,7 @@ def to_printed_hogql(self): class CohortPeople(LazyTable): fields: Dict[str, FieldOrTable] = COHORT_PEOPLE_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): + def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): return select_from_cohort_people_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/groups.py b/posthog/hogql/database/schema/groups.py index bb237d68e8070..3b9de7f08befc 100644 --- a/posthog/hogql/database/schema/groups.py +++ b/posthog/hogql/database/schema/groups.py @@ -13,7 +13,6 @@ FieldOrTable, ) from posthog.hogql.errors import HogQLException -from posthog.schema import HogQLQueryModifiers GROUPS_TABLE_FIELDS = { "index": IntegerDatabaseField(name="group_type_index"), @@ -83,7 +82,7 @@ def to_printed_hogql(self): class GroupsTable(LazyTable): fields: Dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): + def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): return select_from_groups_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/log_entries.py b/posthog/hogql/database/schema/log_entries.py index c14e90e26da50..9f5dc816ac4b0 100644 --- a/posthog/hogql/database/schema/log_entries.py +++ b/posthog/hogql/database/schema/log_entries.py @@ -9,7 +9,6 @@ LazyTable, FieldOrTable, ) -from posthog.schema import HogQLQueryModifiers LOG_ENTRIES_FIELDS: Dict[str, FieldOrTable] = { "team_id": IntegerDatabaseField(name="team_id"), @@ -35,7 +34,7 @@ def to_printed_hogql(self): class ReplayConsoleLogsLogEntriesTable(LazyTable): fields: Dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): + def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): fields: List[ast.Expr] = [ast.Field(chain=["log_entries"] + chain) for name, chain in requested_fields.items()] return ast.SelectQuery( @@ -58,7 +57,7 @@ def to_printed_hogql(self): class BatchExportLogEntriesTable(LazyTable): fields: Dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): + def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): fields: List[ast.Expr] = [ast.Field(chain=["log_entries"] + chain) for name, chain in requested_fields.items()] return ast.SelectQuery( diff --git a/posthog/hogql/database/schema/person_distinct_ids.py b/posthog/hogql/database/schema/person_distinct_ids.py index 02144b35fc3d8..3304eccda862e 100644 --- a/posthog/hogql/database/schema/person_distinct_ids.py +++ b/posthog/hogql/database/schema/person_distinct_ids.py @@ -14,7 +14,6 @@ ) from posthog.hogql.database.schema.persons import join_with_persons_table from posthog.hogql.errors import HogQLException -from posthog.schema import HogQLQueryModifiers PERSON_DISTINCT_IDS_FIELDS = { "team_id": IntegerDatabaseField(name="team_id"), @@ -82,7 +81,7 @@ def to_printed_hogql(self): class PersonDistinctIdsTable(LazyTable): fields: Dict[str, FieldOrTable] = PERSON_DISTINCT_IDS_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): + def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): return select_from_person_distinct_ids_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/persons.py b/posthog/hogql/database/schema/persons.py index a248da56b7307..c7abdd89e14c6 100644 --- a/posthog/hogql/database/schema/persons.py +++ b/posthog/hogql/database/schema/persons.py @@ -123,8 +123,8 @@ def to_printed_hogql(self): class PersonsTable(LazyTable): fields: Dict[str, FieldOrTable] = PERSONS_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): - return select_from_persons_table(requested_fields, modifiers) + def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + return select_from_persons_table(requested_fields, context.modifiers) def to_printed_clickhouse(self, context): return "person" diff --git a/posthog/hogql/database/schema/persons_pdi.py b/posthog/hogql/database/schema/persons_pdi.py index 9f476f407b4d2..195643b90c08c 100644 --- a/posthog/hogql/database/schema/persons_pdi.py +++ b/posthog/hogql/database/schema/persons_pdi.py @@ -10,7 +10,6 @@ FieldOrTable, ) from posthog.hogql.errors import HogQLException -from posthog.schema import HogQLQueryModifiers # :NOTE: We already have person_distinct_ids.py, which most tables link to. This persons_pdi.py is a hack to @@ -63,7 +62,7 @@ class PersonsPDITable(LazyTable): "person_id": StringDatabaseField(name="person_id"), } - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): + def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): return persons_pdi_select(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/session_replay_events.py b/posthog/hogql/database/schema/session_replay_events.py index c9d564c7d4588..baaecef89e049 100644 --- a/posthog/hogql/database/schema/session_replay_events.py +++ b/posthog/hogql/database/schema/session_replay_events.py @@ -15,7 +15,6 @@ PersonDistinctIdsTable, join_with_person_distinct_ids_table, ) -from posthog.schema import HogQLQueryModifiers RAW_ONLY_FIELDS = ["min_first_timestamp", "max_last_timestamp"] @@ -115,7 +114,7 @@ class SessionReplayEventsTable(LazyTable): "first_url": StringDatabaseField(name="first_url"), } - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): + def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): return select_from_session_replay_events_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/sessions.py b/posthog/hogql/database/schema/sessions.py index 2a4865798eeb8..770daceaa23c5 100644 --- a/posthog/hogql/database/schema/sessions.py +++ b/posthog/hogql/database/schema/sessions.py @@ -1,5 +1,7 @@ from typing import Dict, List, cast +from posthog.hogql import ast +from posthog.hogql.context import HogQLContext from posthog.hogql.database.models import ( StringDatabaseField, DateTimeDatabaseField, @@ -11,7 +13,7 @@ LazyTable, ) from posthog.hogql.database.schema.channel_type import create_channel_type_expr -from posthog.schema import HogQLQueryModifiers +from posthog.hogql.database.schema.util.session_where_clause_extractor import SessionMinTimestampWhereClauseExtractor SESSIONS_COMMON_FIELDS: Dict[str, FieldOrTable] = { @@ -62,7 +64,9 @@ def avoid_asterisk_fields(self) -> List[str]: ] -def select_from_sessions_table(requested_fields: Dict[str, List[str | int]]): +def select_from_sessions_table( + requested_fields: Dict[str, List[str | int]], node: ast.SelectQuery, context: HogQLContext +): from posthog.hogql import ast table_name = "raw_sessions" @@ -134,10 +138,13 @@ def select_from_sessions_table(requested_fields: Dict[str, List[str | int]]): ) group_by_fields.append(ast.Field(chain=cast(list[str | int], [table_name]) + chain)) + where = SessionMinTimestampWhereClauseExtractor(context).get_inner_where(node) + return ast.SelectQuery( select=select_fields, select_from=ast.JoinExpr(table=ast.Field(chain=[table_name])), group_by=group_by_fields, + where=where, ) @@ -148,8 +155,8 @@ class SessionsTable(LazyTable): "channel_type": StringDatabaseField(name="channel_type"), } - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): - return select_from_sessions_table(requested_fields) + def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node: ast.SelectQuery): + return select_from_sessions_table(requested_fields, node, context) def to_printed_clickhouse(self, context): return "sessions" diff --git a/posthog/hogql/database/schema/util/__init__.py b/posthog/hogql/database/schema/util/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/posthog/hogql/database/schema/util/session_where_clause_extractor.py b/posthog/hogql/database/schema/util/session_where_clause_extractor.py new file mode 100644 index 0000000000000..83933bdde8b85 --- /dev/null +++ b/posthog/hogql/database/schema/util/session_where_clause_extractor.py @@ -0,0 +1,398 @@ +from dataclasses import dataclass +from typing import Optional + +from posthog.hogql import ast +from posthog.hogql.ast import CompareOperationOp, ArithmeticOperationOp +from posthog.hogql.context import HogQLContext +from posthog.hogql.database.models import DatabaseField + +from posthog.hogql.visitor import clone_expr, CloningVisitor, Visitor + +SESSION_BUFFER_DAYS = 3 + + +@dataclass +class SessionMinTimestampWhereClauseExtractor(CloningVisitor): + """This class extracts the Where clause from the lazy sessions table, to the clickhouse sessions table. + + The sessions table in Clickhouse is an AggregatingMergeTree, and will have one row per session per day. This means that + when we want to query sessions, we need to pre-group these rows, so that we only have one row per session. + + We hide this detail using a lazy table, but to make querying the underlying Clickhouse table faster, we can inline the + min_timestamp where conditions from the select on the outer lazy table to the select on the inner real table. + + This class is called on the select query of the lazy table, and will return the where clause that should be applied to + the inner table. + + As a query can be unreasonably complex, we only handle simple cases, but this class is designed to fail-safe. If it + can't reason about a particular expression, it will just return a constant True, i.e. fetch more rows than necessary. + + This means that we can incrementally add support for more complex queries, without breaking existing queries, by + handling more cases. + + Some examples of failing-safe: + + `SELECT * FROM sessions where min_timestamp > '2022-01-01' AND f(session_id)` + only the` min_timestamp > '2022-01-01'` part is relevant, so we can ignore the `f(session_id)` part, and it is safe + to replace it with a constant True, which collapses the AND to just the `min_timestamp > '2022-01-01'` part. + + `SELECT * FROM sessions where min_timestamp > '2022-01-01' OR f(session_id)` + only the` min_timestamp > '2022-01-01'` part is relevant, and turning the `f(session_id)` part into a constant True + would collapse the OR to True. In this case we return None as no pre-filtering is possible. + + All min_timestamp comparisons are given a buffer of SESSION_BUFFER_DAYS each side, to ensure that we collect all the + relevant rows for each session. + """ + + context: HogQLContext + clear_types: bool = False + clear_locations: bool = False + + def get_inner_where(self, parsed_query: ast.SelectQuery) -> Optional[ast.Expr]: + if not parsed_query.where: + return None + + # visit the where clause + where = self.visit(parsed_query.where) + + if isinstance(where, ast.Constant): + return None + + return clone_expr(where, clear_types=True, clear_locations=True) + + def visit_compare_operation(self, node: ast.CompareOperation) -> ast.Expr: + is_left_constant = is_time_or_interval_constant(node.left) + is_right_constant = is_time_or_interval_constant(node.right) + is_left_timestamp_field = is_simple_timestamp_field_expression(node.left, self.context) + is_right_timestamp_field = is_simple_timestamp_field_expression(node.right, self.context) + + if is_left_constant and is_right_constant: + # just ignore this comparison + return ast.Constant(value=True) + + # handle the left side being a min_timestamp expression and the right being constant + if is_left_timestamp_field and is_right_constant: + if node.op == CompareOperationOp.Eq: + return ast.And( + exprs=[ + ast.CompareOperation( + op=ast.CompareOperationOp.LtEq, + left=ast.ArithmeticOperation( + op=ast.ArithmeticOperationOp.Sub, + left=rewrite_timestamp_field(node.left, self.context), + right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]), + ), + right=node.right, + ), + ast.CompareOperation( + op=ast.CompareOperationOp.GtEq, + left=ast.ArithmeticOperation( + op=ast.ArithmeticOperationOp.Add, + left=rewrite_timestamp_field(node.left, self.context), + right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]), + ), + right=node.right, + ), + ] + ) + elif node.op == CompareOperationOp.Gt or node.op == CompareOperationOp.GtEq: + return ast.CompareOperation( + op=ast.CompareOperationOp.GtEq, + left=ast.ArithmeticOperation( + op=ast.ArithmeticOperationOp.Add, + left=rewrite_timestamp_field(node.left, self.context), + right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]), + ), + right=node.right, + ) + elif node.op == CompareOperationOp.Lt or node.op == CompareOperationOp.LtEq: + return ast.CompareOperation( + op=ast.CompareOperationOp.LtEq, + left=ast.ArithmeticOperation( + op=ast.ArithmeticOperationOp.Sub, + left=rewrite_timestamp_field(node.left, self.context), + right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]), + ), + right=node.right, + ) + elif is_right_timestamp_field and is_left_constant: + # let's not duplicate the logic above, instead just flip and it and recurse + if node.op in [ + CompareOperationOp.Eq, + CompareOperationOp.Lt, + CompareOperationOp.LtEq, + CompareOperationOp.Gt, + CompareOperationOp.GtEq, + ]: + return self.visit( + ast.CompareOperation( + op=CompareOperationOp.Eq + if node.op == CompareOperationOp.Eq + else CompareOperationOp.Lt + if node.op == CompareOperationOp.Gt + else CompareOperationOp.LtEq + if node.op == CompareOperationOp.GtEq + else CompareOperationOp.Gt + if node.op == CompareOperationOp.Lt + else CompareOperationOp.GtEq, + left=node.right, + right=node.left, + ) + ) + + return ast.Constant(value=True) + + def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> ast.Expr: + # don't even try to handle complex logic + return ast.Constant(value=True) + + def visit_not(self, node: ast.Not) -> ast.Expr: + return ast.Constant(value=True) + + def visit_call(self, node: ast.Call) -> ast.Expr: + if node.name == "and": + return self.visit_and(ast.And(exprs=node.args)) + elif node.name == "or": + return self.visit_or(ast.Or(exprs=node.args)) + return ast.Constant(value=True) + + def visit_field(self, node: ast.Field) -> ast.Expr: + return ast.Constant(value=True) + + def visit_constant(self, node: ast.Constant) -> ast.Expr: + return ast.Constant(value=True) + + def visit_placeholder(self, node: ast.Placeholder) -> ast.Expr: + raise Exception() # this should never happen, as placeholders should be resolved before this runs + + def visit_and(self, node: ast.And) -> ast.Expr: + exprs = [self.visit(expr) for expr in node.exprs] + + flattened = [] + for expr in exprs: + if isinstance(expr, ast.And): + flattened.extend(expr.exprs) + else: + flattened.append(expr) + + if any(isinstance(expr, ast.Constant) and expr.value is False for expr in flattened): + return ast.Constant(value=False) + + filtered = [expr for expr in flattened if not isinstance(expr, ast.Constant) or expr.value is not True] + if len(filtered) == 0: + return ast.Constant(value=True) + elif len(filtered) == 1: + return filtered[0] + else: + return ast.And(exprs=filtered) + + def visit_or(self, node: ast.Or) -> ast.Expr: + exprs = [self.visit(expr) for expr in node.exprs] + + flattened = [] + for expr in exprs: + if isinstance(expr, ast.Or): + flattened.extend(expr.exprs) + else: + flattened.append(expr) + + if any(isinstance(expr, ast.Constant) and expr.value is True for expr in flattened): + return ast.Constant(value=True) + + filtered = [expr for expr in flattened if not isinstance(expr, ast.Constant) or expr.value is not False] + if len(filtered) == 0: + return ast.Constant(value=False) + elif len(filtered) == 1: + return filtered[0] + else: + return ast.Or(exprs=filtered) + + def visit_alias(self, node: ast.Alias) -> ast.Expr: + return self.visit(node.expr) + + +def is_time_or_interval_constant(expr: ast.Expr) -> bool: + return IsTimeOrIntervalConstantVisitor().visit(expr) + + +class IsTimeOrIntervalConstantVisitor(Visitor[bool]): + def visit_constant(self, node: ast.Constant) -> bool: + return True + + def visit_compare_operation(self, node: ast.CompareOperation) -> bool: + return self.visit(node.left) and self.visit(node.right) + + def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> bool: + return self.visit(node.left) and self.visit(node.right) + + def visit_call(self, node: ast.Call) -> bool: + # some functions just return a constant + if node.name in ["today", "now"]: + return True + # some functions return a constant if the first argument is a constant + if node.name in [ + "parseDateTime64BestEffortOrNull", + "toDateTime", + "toTimeZone", + "assumeNotNull", + "toIntervalYear", + "toIntervalMonth", + "toIntervalWeek", + "toIntervalDay", + "toIntervalHour", + "toIntervalMinute", + "toIntervalSecond", + "toStartOfDay", + "toStartOfWeek", + "toStartOfMonth", + "toStartOfQuarter", + "toStartOfYear", + ]: + return self.visit(node.args[0]) + + if node.name in ["minus", "add"]: + return all(self.visit(arg) for arg in node.args) + + # otherwise we don't know, so return False + return False + + def visit_field(self, node: ast.Field) -> bool: + return False + + def visit_and(self, node: ast.And) -> bool: + return False + + def visit_or(self, node: ast.Or) -> bool: + return False + + def visit_not(self, node: ast.Not) -> bool: + return False + + def visit_placeholder(self, node: ast.Placeholder) -> bool: + raise Exception() + + def visit_alias(self, node: ast.Alias) -> bool: + return self.visit(node.expr) + + +def is_simple_timestamp_field_expression(expr: ast.Expr, context: HogQLContext) -> bool: + return IsSimpleTimestampFieldExpressionVisitor(context).visit(expr) + + +@dataclass +class IsSimpleTimestampFieldExpressionVisitor(Visitor[bool]): + context: HogQLContext + + def visit_constant(self, node: ast.Constant) -> bool: + return False + + def visit_field(self, node: ast.Field) -> bool: + if node.type and isinstance(node.type, ast.FieldType): + resolved_field = node.type.resolve_database_field(self.context) + if resolved_field and isinstance(resolved_field, DatabaseField) and resolved_field: + return resolved_field.name in ["min_timestamp", "timestamp"] + # no type information, so just use the name of the field + return node.chain[-1] in ["min_timestamp", "timestamp"] + + def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> bool: + # only allow the min_timestamp field to be used on one side of the arithmetic operation + return ( + self.visit(node.left) + and is_time_or_interval_constant(node.right) + or (self.visit(node.right) and is_time_or_interval_constant(node.left)) + ) + + def visit_call(self, node: ast.Call) -> bool: + # some functions count as a timestamp field expression if their first argument is + if node.name in [ + "parseDateTime64BestEffortOrNull", + "toDateTime", + "toTimeZone", + "assumeNotNull", + "toStartOfDay", + "toStartOfWeek", + "toStartOfMonth", + "toStartOfQuarter", + "toStartOfYear", + ]: + return self.visit(node.args[0]) + + if node.name in ["minus", "add"]: + return self.visit_arithmetic_operation( + ast.ArithmeticOperation( + op=ArithmeticOperationOp.Sub if node.name == "minus" else ArithmeticOperationOp.Add, + left=node.args[0], + right=node.args[1], + ) + ) + + # otherwise we don't know, so return False + return False + + def visit_compare_operation(self, node: ast.CompareOperation) -> bool: + return False + + def visit_and(self, node: ast.And) -> bool: + return False + + def visit_or(self, node: ast.Or) -> bool: + return False + + def visit_not(self, node: ast.Not) -> bool: + return False + + def visit_placeholder(self, node: ast.Placeholder) -> bool: + raise Exception() + + def visit_alias(self, node: ast.Alias) -> bool: + from posthog.hogql.database.schema.events import EventsTable + from posthog.hogql.database.schema.sessions import SessionsTable + + if node.type and isinstance(node.type, ast.FieldAliasType): + resolved_field = node.type.resolve_database_field(self.context) + table_type = node.type.resolve_table_type(self.context) + if not table_type: + return False + return ( + isinstance(table_type, ast.TableType) + and isinstance(table_type.table, EventsTable) + and resolved_field.name == "timestamp" + ) or ( + isinstance(table_type, ast.LazyTableType) + and isinstance(table_type.table, SessionsTable) + and resolved_field.name == "min_timestamp" + ) + + return self.visit(node.expr) + + +def rewrite_timestamp_field(expr: ast.Expr, context: HogQLContext) -> ast.Expr: + return RewriteTimestampFieldVisitor(context).visit(expr) + + +class RewriteTimestampFieldVisitor(CloningVisitor): + context: HogQLContext + + def __init__(self, context: HogQLContext, *args, **kwargs): + super().__init__(*args, **kwargs) + self.context = context + + def visit_field(self, node: ast.Field) -> ast.Field: + from posthog.hogql.database.schema.events import EventsTable + from posthog.hogql.database.schema.sessions import SessionsTable + + if node.type and isinstance(node.type, ast.FieldType): + resolved_field = node.type.resolve_database_field(self.context) + table = node.type.resolve_table_type(self.context).table + if resolved_field and isinstance(resolved_field, DatabaseField): + if (isinstance(table, EventsTable) and resolved_field.name == "timestamp") or ( + isinstance(table, SessionsTable) and resolved_field.name == "min_timestamp" + ): + return ast.Field(chain=["raw_sessions", "min_timestamp"]) + # no type information, so just use the name of the field + if node.chain[-1] in ["min_timestamp", "timestamp"]: + return ast.Field(chain=["raw_sessions", "min_timestamp"]) + return node + + def visit_alias(self, node: ast.Alias) -> ast.Expr: + return self.visit(node.expr) diff --git a/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py new file mode 100644 index 0000000000000..bc5324e739ad9 --- /dev/null +++ b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py @@ -0,0 +1,284 @@ +from typing import Union, Optional, Dict + +from posthog.hogql import ast +from posthog.hogql.context import HogQLContext +from posthog.hogql.database.schema.util.session_where_clause_extractor import SessionMinTimestampWhereClauseExtractor +from posthog.hogql.modifiers import create_default_modifiers_for_team +from posthog.hogql.parser import parse_select, parse_expr +from posthog.hogql.printer import prepare_ast_for_printing, print_prepared_ast +from posthog.hogql.visitor import clone_expr +from posthog.test.base import ClickhouseTestMixin, APIBaseTest + + +def f(s: Union[str, ast.Expr, None], placeholders: Optional[dict[str, ast.Expr]] = None) -> Union[ast.Expr, None]: + if s is None: + return None + if isinstance(s, str): + expr = parse_expr(s, placeholders=placeholders) + else: + expr = s + return clone_expr(expr, clear_types=True, clear_locations=True) + + +def parse( + s: str, + placeholders: Optional[Dict[str, ast.Expr]] = None, +) -> ast.SelectQuery: + parsed = parse_select(s, placeholders=placeholders) + assert isinstance(parsed, ast.SelectQuery) + return parsed + + +class TestSessionTimestampInliner(ClickhouseTestMixin, APIBaseTest): + @property + def inliner(self): + team = self.team + modifiers = create_default_modifiers_for_team(team) + context = HogQLContext( + team_id=team.pk, + team=team, + enable_select_queries=True, + modifiers=modifiers, + ) + return SessionMinTimestampWhereClauseExtractor(context) + + def test_handles_select_with_no_where_claus(self): + inner_where = self.inliner.get_inner_where(parse("SELECT * FROM sessions")) + assert inner_where is None + + def test_handles_select_with_eq(self): + actual = f(self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE min_timestamp = '2021-01-01'"))) + expected = f( + "((raw_sessions.min_timestamp - toIntervalDay(3)) <= '2021-01-01') AND ((raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-01')" + ) + assert expected == actual + + def test_handles_select_with_eq_flipped(self): + actual = f(self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE '2021-01-01' = min_timestamp"))) + expected = f( + "((raw_sessions.min_timestamp - toIntervalDay(3)) <= '2021-01-01') AND ((raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-01')" + ) + assert expected == actual + + def test_handles_select_with_simple_gt(self): + actual = f(self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE min_timestamp > '2021-01-01'"))) + expected = f("((raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-01')") + assert expected == actual + + def test_handles_select_with_simple_gte(self): + actual = f(self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE min_timestamp >= '2021-01-01'"))) + expected = f("((raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-01')") + assert expected == actual + + def test_handles_select_with_simple_lt(self): + actual = f(self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE min_timestamp < '2021-01-01'"))) + expected = f("((raw_sessions.min_timestamp - toIntervalDay(3)) <= '2021-01-01')") + assert expected == actual + + def test_handles_select_with_simple_lte(self): + actual = f(self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE min_timestamp <= '2021-01-01'"))) + expected = f("((raw_sessions.min_timestamp - toIntervalDay(3)) <= '2021-01-01')") + assert expected == actual + + def test_select_with_placeholder(self): + actual = f( + self.inliner.get_inner_where( + parse( + "SELECT * FROM sessions WHERE min_timestamp > {timestamp}", + placeholders={"timestamp": ast.Constant(value="2021-01-01")}, + ) + ) + ) + expected = f("((raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-01')") + assert expected == actual + + def test_unrelated_equals(self): + actual = self.inliner.get_inner_where( + parse("SELECT * FROM sessions WHERE initial_utm_campaign = initial_utm_source") + ) + assert actual is None + + def test_timestamp_and(self): + actual = f( + self.inliner.get_inner_where( + parse("SELECT * FROM sessions WHERE and(min_timestamp >= '2021-01-01', min_timestamp <= '2021-01-03')") + ) + ) + expected = f( + "((raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-01') AND ((raw_sessions.min_timestamp - toIntervalDay(3)) <= '2021-01-03')" + ) + assert expected == actual + + def test_timestamp_or(self): + actual = f( + self.inliner.get_inner_where( + parse("SELECT * FROM sessions WHERE and(min_timestamp <= '2021-01-01', min_timestamp >= '2021-01-03')") + ) + ) + expected = f( + "((raw_sessions.min_timestamp - toIntervalDay(3)) <= '2021-01-01') AND ((raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-03')" + ) + assert expected == actual + + def test_unrelated_function(self): + actual = f(self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE like('a', 'b')"))) + assert actual is None + + def test_timestamp_unrelated_function(self): + actual = f( + self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE like(toString(min_timestamp), 'b')")) + ) + assert actual is None + + def test_timestamp_unrelated_function_timestamp(self): + actual = f( + self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE like(toString(min_timestamp), 'b')")) + ) + assert actual is None + + def test_ambiguous_or(self): + actual = f( + self.inliner.get_inner_where( + parse( + "SELECT * FROM sessions WHERE or(min_timestamp > '2021-01-03', like(toString(min_timestamp), 'b'))" + ) + ) + ) + assert actual is None + + def test_ambiguous_and(self): + actual = f( + self.inliner.get_inner_where( + parse( + "SELECT * FROM sessions WHERE and(min_timestamp > '2021-01-03', like(toString(min_timestamp), 'b'))" + ) + ) + ) + assert actual == f("(raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-03'") + + def test_join(self): + actual = f( + self.inliner.get_inner_where( + parse( + "SELECT * FROM events JOIN sessions ON events.session_id = raw_sessions.session_id WHERE min_timestamp > '2021-01-03'" + ) + ) + ) + expected = f("((raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-03')") + assert expected == actual + + def test_join_using_events_timestamp_filter(self): + actual = f( + self.inliner.get_inner_where( + parse( + "SELECT * FROM events JOIN sessions ON events.session_id = raw_sessions.session_id WHERE timestamp > '2021-01-03'" + ) + ) + ) + expected = f("((raw_sessions.min_timestamp + toIntervalDay(3)) >= '2021-01-03')") + assert expected == actual + + def test_minus(self): + actual = f(self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE min_timestamp >= today() - 2"))) + expected = f("((raw_sessions.min_timestamp + toIntervalDay(3)) >= (today() - 2))") + assert expected == actual + + def test_minus_function(self): + actual = f( + self.inliner.get_inner_where(parse("SELECT * FROM sessions WHERE min_timestamp >= minus(today() , 2)")) + ) + expected = f("((raw_sessions.min_timestamp + toIntervalDay(3)) >= minus(today(), 2))") + assert expected == actual + + def test_real_example(self): + actual = f( + self.inliner.get_inner_where( + parse( + "SELECT * FROM events JOIN sessions ON events.session_id = raw_sessions.session_id WHERE event = '$pageview' AND toTimeZone(timestamp, 'US/Pacific') >= toDateTime('2024-03-12 00:00:00', 'US/Pacific') AND toTimeZone(timestamp, 'US/Pacific') <= toDateTime('2024-03-19 23:59:59', 'US/Pacific')" + ) + ) + ) + expected = f( + "(toTimeZone(raw_sessions.min_timestamp, 'US/Pacific') + toIntervalDay(3)) >= toDateTime('2024-03-12 00:00:00', 'US/Pacific') AND (toTimeZone(raw_sessions.min_timestamp, 'US/Pacific') - toIntervalDay(3)) <= toDateTime('2024-03-19 23:59:59', 'US/Pacific') " + ) + assert expected == actual + + def test_collapse_and(self): + actual = f( + self.inliner.get_inner_where( + parse( + "SELECT * FROM sesions WHERE event = '$pageview' AND (TRUE AND (TRUE AND TRUE AND (timestamp >= '2024-03-12' AND TRUE)))" + ) + ) + ) + expected = f("(raw_sessions.min_timestamp + toIntervalDay(3)) >= '2024-03-12'") + assert expected == actual + + +class TestSessionsQueriesHogQLToClickhouse(ClickhouseTestMixin, APIBaseTest): + def print_query(self, query: str) -> str: + team = self.team + modifiers = create_default_modifiers_for_team(team) + context = HogQLContext( + team_id=team.pk, + team=team, + enable_select_queries=True, + modifiers=modifiers, + ) + prepared_ast = prepare_ast_for_printing(node=parse(query), context=context, dialect="clickhouse") + pretty = print_prepared_ast(prepared_ast, context=context, dialect="clickhouse", pretty=True) + return pretty + + def test_select_with_timestamp(self): + actual = self.print_query("SELECT session_id FROM sessions WHERE min_timestamp > '2021-01-01'") + expected = f"""SELECT + sessions.session_id AS session_id +FROM + (SELECT + sessions.session_id AS session_id, + min(sessions.min_timestamp) AS min_timestamp + FROM + sessions + WHERE + and(equals(sessions.team_id, {self.team.id}), ifNull(greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_0)s), toIntervalDay(3)), %(hogql_val_1)s), 0)) + GROUP BY + sessions.session_id, + sessions.session_id) AS sessions +WHERE + ifNull(greater(toTimeZone(sessions.min_timestamp, %(hogql_val_2)s), %(hogql_val_3)s), 0) +LIMIT 10000""" + assert expected == actual + + def test_join_with_events(self): + actual = self.print_query( + """ +SELECT + sessions.session_id, + uniq(uuid) +FROM events +JOIN sessions +ON events.$session_id = sessions.session_id +WHERE events.timestamp > '2021-01-01' +GROUP BY sessions.session_id +""" + ) + expected = f"""SELECT + sessions.session_id AS session_id, + uniq(events.uuid) +FROM + events + JOIN (SELECT + sessions.session_id AS session_id + FROM + sessions + WHERE + and(equals(sessions.team_id, {self.team.id}), ifNull(greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_0)s), toIntervalDay(3)), %(hogql_val_1)s), 0)) + GROUP BY + sessions.session_id, + sessions.session_id) AS sessions ON equals(events.`$session_id`, sessions.session_id) +WHERE + and(equals(events.team_id, {self.team.id}), greater(toTimeZone(events.timestamp, %(hogql_val_2)s), %(hogql_val_3)s)) +GROUP BY + sessions.session_id +LIMIT 10000""" + assert expected == actual diff --git a/posthog/hogql/test/test_bytecode.py b/posthog/hogql/test/test_bytecode.py index cf0b8113b574d..f7d810700e74a 100644 --- a/posthog/hogql/test/test_bytecode.py +++ b/posthog/hogql/test/test_bytecode.py @@ -130,7 +130,7 @@ def test_bytecode_create(self): def test_bytecode_create_error(self): with self.assertRaises(NotImplementedException) as e: to_bytecode("(select 1)") - self.assertEqual(str(e.exception), "Visitor has no method visit_select_query") + self.assertEqual(str(e.exception), "BytecodeBuilder has no method visit_select_query") with self.assertRaises(NotImplementedException) as e: to_bytecode("1 in cohort 2") diff --git a/posthog/hogql/test/test_visitor.py b/posthog/hogql/test/test_visitor.py index 8aa6689328fbf..a01193f788d5f 100644 --- a/posthog/hogql/test/test_visitor.py +++ b/posthog/hogql/test/test_visitor.py @@ -125,7 +125,7 @@ def visit_arithmetic_operation(self, node: ast.ArithmeticOperation): with self.assertRaises(HogQLException) as e: UnknownNotDefinedVisitor().visit(parse_expr("1 + 3 / 'asd2'")) - self.assertEqual(str(e.exception), "Visitor has no method visit_constant") + self.assertEqual(str(e.exception), "UnknownNotDefinedVisitor has no method visit_constant") def test_hogql_exception_start_end(self): class EternalVisitor(TraversingVisitor): diff --git a/posthog/hogql/transforms/lazy_tables.py b/posthog/hogql/transforms/lazy_tables.py index bdbb322d54397..df8ce6962259c 100644 --- a/posthog/hogql/transforms/lazy_tables.py +++ b/posthog/hogql/transforms/lazy_tables.py @@ -309,7 +309,7 @@ def create_override(table_name: str, field_chain: List[str | int]) -> None: # For all the collected tables, create the subqueries, and add them to the table. for table_name, table_to_add in tables_to_add.items(): - subquery = table_to_add.lazy_table.lazy_select(table_to_add.fields_accessed, self.context.modifiers) + subquery = table_to_add.lazy_table.lazy_select(table_to_add.fields_accessed, self.context, node=node) subquery = cast(ast.SelectQuery, clone_expr(subquery, clear_locations=True)) subquery = cast(ast.SelectQuery, resolve_types(subquery, self.context, self.dialect, [node.type])) old_table_type = select_type.tables[table_name] diff --git a/posthog/hogql/visitor.py b/posthog/hogql/visitor.py index c11856169297f..2bf968abf2ab0 100644 --- a/posthog/hogql/visitor.py +++ b/posthog/hogql/visitor.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, TypeVar, Generic, Any from posthog.hogql import ast from posthog.hogql.base import AST, Expr @@ -14,8 +14,11 @@ def clear_locations(expr: Expr) -> Expr: return CloningVisitor(clear_locations=True).visit(expr) -class Visitor(object): - def visit(self, node: AST): +T = TypeVar("T") + + +class Visitor(Generic[T]): + def visit(self, node: AST) -> T: if node is None: return node @@ -28,7 +31,7 @@ def visit(self, node: AST): raise e -class TraversingVisitor(Visitor): +class TraversingVisitor(Visitor[None]): """Visitor that traverses the AST tree without returning anything""" def visit_expr(self, node: Expr): @@ -258,7 +261,7 @@ def visit_hogqlx_attribute(self, node: ast.HogQLXAttribute): self.visit(node.value) -class CloningVisitor(Visitor): +class CloningVisitor(Visitor[Any]): """Visitor that traverses and clones the AST tree. Clears types.""" def __init__(