Skip to content

Commit

Permalink
feat(session-replay): Make session replay support where clause extrac…
Browse files Browse the repository at this point in the history
…tor (#22917)

* Make session replay support where clause extractor

* Remove TODO

* Add test
  • Loading branch information
robbie-c authored Jun 12, 2024
1 parent 50b1cca commit c7ba0ff
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 13 deletions.
1 change: 0 additions & 1 deletion posthog/hogql/database/schema/session_replay_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def join_replay_table_to_sessions_table(
if not join_to_add.fields_accessed:
raise ResolutionError("No fields requested from replay")

# TODO i think this should be fixed in the where_clause_extractor so that it grabs time bounds for us
join_expr = ast.JoinExpr(table=select_from_sessions_table(join_to_add.fields_accessed, node, context))
join_expr.join_type = "LEFT JOIN"
join_expr.alias = join_to_add.to_table
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,36 @@ def test_session_breakdown(self):
breakdown_value
LIMIT {MAX_SELECT_RETURNED_ROWS}"""
assert expected == actual

def test_session_replay_query(self):
# Prints a HogQL query over raw_session_replay_events that filters on a property of the
# lazily-joined session (s.session.$entry_pathname) and on min_first_timestamp bounds,
# then compares the printed SQL against the exact expected text.
actual = self.print_query(
"""
SELECT
s.session_id,
min(s.min_first_timestamp) as start_time
FROM raw_session_replay_events s
WHERE s.session.$entry_pathname = '/home' AND min_first_timestamp >= '2021-01-01:12:34' AND min_first_timestamp < now()
GROUP BY session_id
"""
)
# The expected SQL carries the outer min_first_timestamp bounds into the joined sessions
# subquery as comparisons on sessions.min_timestamp, widened by toIntervalDay(3) on each
# side (plus(...) for the lower bound, minus(...) for the upper bound).
expected = f"""SELECT
s.session_id AS session_id,
min(toTimeZone(s.min_first_timestamp, %(hogql_val_5)s)) AS start_time
FROM
session_replay_events AS s
LEFT JOIN (SELECT
path(nullIf(argMinMerge(sessions.entry_url), %(hogql_val_0)s)) AS `$entry_pathname`,
sessions.session_id AS session_id
FROM
sessions
WHERE
and(equals(sessions.team_id, {self.team.id}), ifNull(greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_1)s), toIntervalDay(3)), %(hogql_val_2)s), 0), ifNull(lessOrEquals(minus(toTimeZone(sessions.min_timestamp, %(hogql_val_3)s), toIntervalDay(3)), now64(6, %(hogql_val_4)s)), 0))
GROUP BY
sessions.session_id,
sessions.session_id) AS s__session ON equals(s.session_id, s__session.session_id)
WHERE
and(equals(s.team_id, {self.team.id}), ifNull(equals(s__session.`$entry_pathname`, %(hogql_val_6)s), 0), ifNull(greaterOrEquals(toTimeZone(s.min_first_timestamp, %(hogql_val_7)s), %(hogql_val_8)s), 0), ifNull(less(toTimeZone(s.min_first_timestamp, %(hogql_val_9)s), now64(6, %(hogql_val_10)s)), 0))
GROUP BY
s.session_id
LIMIT 50000"""
# Exact string comparison: any change in the printed SQL or in placeholder numbering fails.
self.assertEqual(expected, actual)
36 changes: 24 additions & 12 deletions posthog/hogql/database/schema/util/where_clause_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,9 @@ def visit_field(self, node: ast.Field) -> bool:
if node.type and isinstance(node.type, ast.FieldType):
resolved_field = node.type.resolve_database_field(self.context)
if resolved_field and isinstance(resolved_field, DatabaseField) and resolved_field:
return resolved_field.name in ["$start_timestamp", "min_timestamp", "timestamp"]
return resolved_field.name in ["$start_timestamp", "min_timestamp", "timestamp", "min_first_timestamp"]
# no type information, so just use the name of the field
return node.chain[-1] in ["$start_timestamp", "min_timestamp", "timestamp"]
return node.chain[-1] in ["$start_timestamp", "min_timestamp", "timestamp", "min_first_timestamp"]

def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> bool:
# only allow the min_timestamp field to be used on one side of the arithmetic operation
Expand Down Expand Up @@ -498,6 +498,7 @@ def visit_placeholder(self, node: ast.Placeholder) -> bool:
def visit_alias(self, node: ast.Alias) -> bool:
from posthog.hogql.database.schema.events import EventsTable
from posthog.hogql.database.schema.sessions import SessionsTable
from posthog.hogql.database.schema.session_replay_events import RawSessionReplayEventsTable

if node.type and isinstance(node.type, ast.FieldAliasType):
resolved_field = node.type.resolve_database_field(self.context)
Expand All @@ -507,13 +508,21 @@ def visit_alias(self, node: ast.Alias) -> bool:
if isinstance(table_type, ast.TableAliasType):
table_type = table_type.table_type
return (
isinstance(table_type, ast.TableType)
and isinstance(table_type.table, EventsTable)
and resolved_field.name == "timestamp"
) or (
isinstance(table_type, ast.LazyTableType)
and isinstance(table_type.table, SessionsTable)
and resolved_field.name == "$start_timestamp"
(
isinstance(table_type, ast.TableType)
and isinstance(table_type.table, EventsTable)
and resolved_field.name == "timestamp"
)
or (
isinstance(table_type, ast.LazyTableType)
and isinstance(table_type.table, SessionsTable)
and resolved_field.name == "$start_timestamp"
)
or (
isinstance(table_type, ast.TableType)
and isinstance(table_type.table, RawSessionReplayEventsTable)
and resolved_field.name == "min_first_timestamp"
)
)

return self.visit(node.expr)
Expand All @@ -536,6 +545,7 @@ def __init__(self, context: HogQLContext, *args, **kwargs):
def visit_field(self, node: ast.Field) -> ast.Field:
from posthog.hogql.database.schema.events import EventsTable
from posthog.hogql.database.schema.sessions import SessionsTable
from posthog.hogql.database.schema.session_replay_events import RawSessionReplayEventsTable

if node.type and isinstance(node.type, ast.FieldType):
resolved_field = node.type.resolve_database_field(self.context)
Expand All @@ -544,12 +554,14 @@ def visit_field(self, node: ast.Field) -> ast.Field:
table_type = table_type.table_type
table = table_type.table
if resolved_field and isinstance(resolved_field, DatabaseField):
if (isinstance(table, EventsTable) and resolved_field.name == "timestamp") or (
isinstance(table, SessionsTable) and resolved_field.name == "$start_timestamp"
if (
(isinstance(table, EventsTable) and resolved_field.name == "timestamp")
or (isinstance(table, SessionsTable) and resolved_field.name == "$start_timestamp")
or (isinstance(table, RawSessionReplayEventsTable) and resolved_field.name == "min_first_timestamp")
):
return ast.Field(chain=["raw_sessions", "min_timestamp"])
# no type information, so just use the name of the field
if node.chain[-1] in ["$start_timestamp", "min_timestamp", "timestamp"]:
if node.chain[-1] in ["$start_timestamp", "min_timestamp", "timestamp", "min_first_timestamp"]:
return ast.Field(chain=["raw_sessions", "min_timestamp"])
return node

Expand Down

0 comments on commit c7ba0ff

Please sign in to comment.