diff --git a/posthog/hogql/database/schema/session_replay_events.py b/posthog/hogql/database/schema/session_replay_events.py index 30a9e1f1acea2..214e3379fb2d8 100644 --- a/posthog/hogql/database/schema/session_replay_events.py +++ b/posthog/hogql/database/schema/session_replay_events.py @@ -33,7 +33,6 @@ def join_replay_table_to_sessions_table( if not join_to_add.fields_accessed: raise ResolutionError("No fields requested from replay") - # TODO i think this should be fixed in the where_clause_extractor so that it grabs time bounds for us join_expr = ast.JoinExpr(table=select_from_sessions_table(join_to_add.fields_accessed, node, context)) join_expr.join_type = "LEFT JOIN" join_expr.alias = join_to_add.to_table diff --git a/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py index f9213dd237190..3ae4e408c2959 100644 --- a/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py +++ b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py @@ -458,3 +458,36 @@ def test_session_breakdown(self): breakdown_value LIMIT {MAX_SELECT_RETURNED_ROWS}""" assert expected == actual + + def test_session_replay_query(self): + actual = self.print_query( + """ +SELECT + s.session_id, + min(s.min_first_timestamp) as start_time +FROM raw_session_replay_events s +WHERE s.session.$entry_pathname = '/home' AND min_first_timestamp >= '2021-01-01:12:34' AND min_first_timestamp < now() +GROUP BY session_id + """ + ) + expected = f"""SELECT + s.session_id AS session_id, + min(toTimeZone(s.min_first_timestamp, %(hogql_val_5)s)) AS start_time +FROM + session_replay_events AS s + LEFT JOIN (SELECT + path(nullIf(argMinMerge(sessions.entry_url), %(hogql_val_0)s)) AS `$entry_pathname`, + sessions.session_id AS session_id + FROM + sessions + WHERE + and(equals(sessions.team_id, {self.team.id}), ifNull(greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_1)s), toIntervalDay(3)), %(hogql_val_2)s), 0), ifNull(lessOrEquals(minus(toTimeZone(sessions.min_timestamp, %(hogql_val_3)s), toIntervalDay(3)), now64(6, %(hogql_val_4)s)), 0)) + GROUP BY + sessions.session_id, + sessions.session_id) AS s__session ON equals(s.session_id, s__session.session_id) +WHERE + and(equals(s.team_id, {self.team.id}), ifNull(equals(s__session.`$entry_pathname`, %(hogql_val_6)s), 0), ifNull(greaterOrEquals(toTimeZone(s.min_first_timestamp, %(hogql_val_7)s), %(hogql_val_8)s), 0), ifNull(less(toTimeZone(s.min_first_timestamp, %(hogql_val_9)s), now64(6, %(hogql_val_10)s)), 0)) +GROUP BY + s.session_id +LIMIT 50000""" + self.assertEqual(expected, actual) diff --git a/posthog/hogql/database/schema/util/where_clause_extractor.py b/posthog/hogql/database/schema/util/where_clause_extractor.py index 7cb413960ca80..4f0096af9ff53 100644 --- a/posthog/hogql/database/schema/util/where_clause_extractor.py +++ b/posthog/hogql/database/schema/util/where_clause_extractor.py @@ -441,9 +441,9 @@ def visit_field(self, node: ast.Field) -> bool: if node.type and isinstance(node.type, ast.FieldType): resolved_field = node.type.resolve_database_field(self.context) if resolved_field and isinstance(resolved_field, DatabaseField) and resolved_field: - return resolved_field.name in ["$start_timestamp", "min_timestamp", "timestamp"] + return resolved_field.name in ["$start_timestamp", "min_timestamp", "timestamp", "min_first_timestamp"] # no type information, so just use the name of the field - return node.chain[-1] in ["$start_timestamp", "min_timestamp", "timestamp"] + return node.chain[-1] in ["$start_timestamp", "min_timestamp", "timestamp", "min_first_timestamp"] def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> bool: # only allow the min_timestamp field to be used on one side of the arithmetic operation @@ -498,6 +498,7 @@ def visit_placeholder(self, node: ast.Placeholder) -> bool: def visit_alias(self, node: ast.Alias) -> bool: from posthog.hogql.database.schema.events import EventsTable from posthog.hogql.database.schema.sessions import SessionsTable + from posthog.hogql.database.schema.session_replay_events import RawSessionReplayEventsTable if node.type and isinstance(node.type, ast.FieldAliasType): resolved_field = node.type.resolve_database_field(self.context) @@ -507,13 +508,21 @@ def visit_alias(self, node: ast.Alias) -> bool: if isinstance(table_type, ast.TableAliasType): table_type = table_type.table_type return ( - isinstance(table_type, ast.TableType) - and isinstance(table_type.table, EventsTable) - and resolved_field.name == "timestamp" - ) or ( - isinstance(table_type, ast.LazyTableType) - and isinstance(table_type.table, SessionsTable) - and resolved_field.name == "$start_timestamp" + ( + isinstance(table_type, ast.TableType) + and isinstance(table_type.table, EventsTable) + and resolved_field.name == "timestamp" + ) + or ( + isinstance(table_type, ast.LazyTableType) + and isinstance(table_type.table, SessionsTable) + and resolved_field.name == "$start_timestamp" + ) + or ( + isinstance(table_type, ast.TableType) + and isinstance(table_type.table, RawSessionReplayEventsTable) + and resolved_field.name == "min_first_timestamp" + ) ) return self.visit(node.expr) @@ -536,6 +545,7 @@ def __init__(self, context: HogQLContext, *args, **kwargs): def visit_field(self, node: ast.Field) -> ast.Field: from posthog.hogql.database.schema.events import EventsTable from posthog.hogql.database.schema.sessions import SessionsTable + from posthog.hogql.database.schema.session_replay_events import RawSessionReplayEventsTable if node.type and isinstance(node.type, ast.FieldType): resolved_field = node.type.resolve_database_field(self.context) @@ -544,12 +554,14 @@ def visit_field(self, node: ast.Field) -> ast.Field: table_type = table_type.table_type table = table_type.table if resolved_field and isinstance(resolved_field, DatabaseField): - if (isinstance(table, EventsTable) and resolved_field.name == "timestamp") or ( - isinstance(table, SessionsTable) and resolved_field.name == "$start_timestamp" + if ( + (isinstance(table, EventsTable) and resolved_field.name == "timestamp") + or (isinstance(table, SessionsTable) and resolved_field.name == "$start_timestamp") + or (isinstance(table, RawSessionReplayEventsTable) and resolved_field.name == "min_first_timestamp") ): return ast.Field(chain=["raw_sessions", "min_timestamp"]) # no type information, so just use the name of the field - if node.chain[-1] in ["$start_timestamp", "min_timestamp", "timestamp"]: + if node.chain[-1] in ["$start_timestamp", "min_timestamp", "timestamp", "min_first_timestamp"]: return ast.Field(chain=["raw_sessions", "min_timestamp"]) return node