Skip to content

Commit

Permalink
Fix session replay joining with v2
Browse files Browse the repository at this point in the history
  • Loading branch information
robbie-c committed Jun 19, 2024
1 parent c45fdc3 commit af276d7
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 29 deletions.
25 changes: 21 additions & 4 deletions posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from posthog.hogql.database.schema.session_replay_events import (
RawSessionReplayEventsTable,
SessionReplayEventsTable,
join_replay_table_to_sessions_table_v2,
)
from posthog.hogql.database.schema.sessions_v1 import RawSessionsTableV1, SessionsTableV1
from posthog.hogql.database.schema.sessions_v2 import (
Expand Down Expand Up @@ -242,13 +243,29 @@ def create_hogql_database(
)

if modifiers.sessionTableVersion in [SessionTableVersion.V2]:
database.sessions = SessionsTableV2()
database.raw_sessions = RawSessionsTableV2()
database.events.fields["session"] = LazyJoin(
raw_sessions = RawSessionsTableV2()
database.raw_sessions = raw_sessions
sessions = SessionsTableV2()
events = database.events
events.fields["session"] = LazyJoin(
from_field=["$session_id"],
join_table=SessionsTableV2(),
join_table=sessions,
join_function=join_events_table_to_sessions_table_v2,
)
replay_events = database.session_replay_events
replay_events.fields["session"] = LazyJoin(
from_field=["session_id"],
join_table=sessions,
join_function=join_replay_table_to_sessions_table_v2,
)
replay_events.fields["events"].join_table = events
raw_replay_events = database.raw_session_replay_events
raw_replay_events.fields["session"] = LazyJoin(
from_field=["session_id"],
join_table=sessions,
join_function=join_replay_table_to_sessions_table_v2,
)
raw_replay_events.fields["events"].join_table = events

database.persons.fields["$virt_initial_referring_domain_type"] = create_initial_domain_type(
"$virt_initial_referring_domain_type"
Expand Down
32 changes: 28 additions & 4 deletions posthog/hogql/database/schema/session_replay_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,21 @@
)
from datetime import datetime

from posthog.hogql.database.schema.sessions_v1 import SessionsTableV1, select_from_sessions_table
from posthog.hogql.database.schema.sessions_v1 import SessionsTableV1, select_from_sessions_table_v1
from posthog.hogql.database.schema.sessions_v2 import select_from_sessions_table_v2, session_id_to_session_id_v7_expr

from posthog.hogql.errors import ResolutionError


def join_replay_table_to_sessions_table(
def join_replay_table_to_sessions_table_v1(
join_to_add: LazyJoinToAdd, context: HogQLContext, node: SelectQuery
) -> JoinExpr:
from posthog.hogql import ast

if not join_to_add.fields_accessed:
raise ResolutionError("No fields requested from replay")

join_expr = ast.JoinExpr(table=select_from_sessions_table(join_to_add.fields_accessed, node, context))
join_expr = ast.JoinExpr(table=select_from_sessions_table_v1(join_to_add.fields_accessed, node, context))
join_expr.join_type = "LEFT JOIN"
join_expr.alias = join_to_add.to_table
join_expr.constraint = ast.JoinConstraint(
Expand All @@ -47,6 +49,28 @@ def join_replay_table_to_sessions_table(
return join_expr


def join_replay_table_to_sessions_table_v2(
join_to_add: LazyJoinToAdd, context: HogQLContext, node: SelectQuery
) -> JoinExpr:
from posthog.hogql import ast

if not join_to_add.fields_accessed:
raise ResolutionError("No fields requested from replay")

join_expr = ast.JoinExpr(table=select_from_sessions_table_v2(join_to_add.fields_accessed, node, context))
join_expr.join_type = "LEFT JOIN"
join_expr.alias = join_to_add.to_table
join_expr.constraint = ast.JoinConstraint(
expr=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=session_id_to_session_id_v7_expr(ast.Field(chain=[join_to_add.from_table, "session_id"])),
right=ast.Field(chain=[join_to_add.to_table, "session_id_v7"]),
),
constraint_type="ON",
)
return join_expr


def join_with_events_table(
join_to_add: LazyJoinToAdd,
context: HogQLContext,
Expand Down Expand Up @@ -177,7 +201,7 @@ def join_with_console_logs_log_entries_table(
"session": LazyJoin(
from_field=["session_id"],
join_table=SessionsTableV1(),
join_function=join_replay_table_to_sessions_table,
join_function=join_replay_table_to_sessions_table_v1,
),
}

Expand Down
6 changes: 3 additions & 3 deletions posthog/hogql/database/schema/sessions_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def avoid_asterisk_fields(self) -> list[str]:
]


def select_from_sessions_table(
def select_from_sessions_table_v1(
requested_fields: dict[str, list[str | int]], node: ast.SelectQuery, context: HogQLContext
):
from posthog.hogql import ast
Expand Down Expand Up @@ -271,7 +271,7 @@ def lazy_select(
context,
node: ast.SelectQuery,
):
return select_from_sessions_table(table_to_add.fields_accessed, node, context)
return select_from_sessions_table_v1(table_to_add.fields_accessed, node, context)

def to_printed_clickhouse(self, context):
return "sessions"
Expand All @@ -293,7 +293,7 @@ def join_events_table_to_sessions_table(
if not join_to_add.fields_accessed:
raise ResolutionError("No fields requested from events")

join_expr = ast.JoinExpr(table=select_from_sessions_table(join_to_add.fields_accessed, node, context))
join_expr = ast.JoinExpr(table=select_from_sessions_table_v1(join_to_add.fields_accessed, node, context))
join_expr.join_type = "LEFT JOIN"
join_expr.alias = join_to_add.to_table
join_expr.constraint = ast.JoinConstraint(
Expand Down
18 changes: 11 additions & 7 deletions posthog/hogql/database/schema/sessions_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def avoid_asterisk_fields(self) -> list[str]:
]


def select_from_sessions_table(
def select_from_sessions_table_v2(
requested_fields: dict[str, list[str | int]], node: ast.SelectQuery, context: HogQLContext
):
from posthog.hogql import ast
Expand Down Expand Up @@ -293,7 +293,7 @@ def lazy_select(
context,
node: ast.SelectQuery,
):
return select_from_sessions_table(table_to_add.fields_accessed, node, context)
return select_from_sessions_table_v2(table_to_add.fields_accessed, node, context)

def to_printed_clickhouse(self, context):
return "sessions"
Expand All @@ -308,6 +308,13 @@ def avoid_asterisk_fields(self) -> list[str]:
]


def session_id_to_session_id_v7_expr(session_id: ast.Expr) -> ast.Expr:
return ast.Call(
name="_toUInt128",
args=[ast.Call(name="toUUID", args=[session_id])],
)


def join_events_table_to_sessions_table_v2(
join_to_add: LazyJoinToAdd, context: HogQLContext, node: ast.SelectQuery
) -> ast.JoinExpr:
Expand All @@ -316,16 +323,13 @@ def join_events_table_to_sessions_table_v2(
if not join_to_add.fields_accessed:
raise ResolutionError("No fields requested from events")

join_expr = ast.JoinExpr(table=select_from_sessions_table(join_to_add.fields_accessed, node, context))
join_expr = ast.JoinExpr(table=select_from_sessions_table_v2(join_to_add.fields_accessed, node, context))
join_expr.join_type = "LEFT JOIN"
join_expr.alias = join_to_add.to_table
join_expr.constraint = ast.JoinConstraint(
expr=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Call(
name="_toUInt128",
args=[ast.Call(name="toUUID", args=[ast.Field(chain=[join_to_add.from_table, "$session_id"])])],
),
left=session_id_to_session_id_v7_expr(ast.Field(chain=[join_to_add.from_table, "$session_id"])),
right=ast.Field(chain=[join_to_add.to_table, "session_id_v7"]),
),
constraint_type="ON",
Expand Down
27 changes: 16 additions & 11 deletions posthog/hogql/database/schema/test/test_session_replay_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from posthog.hogql.parser import parse_select
from posthog.hogql.query import execute_hogql_query
from posthog.models.event.sql import TRUNCATE_EVENTS_TABLE_SQL
from posthog.models.utils import uuid7
from posthog.schema import HogQLQueryModifiers
from posthog.session_recordings.queries.test.session_replay_sql import produce_replay_summary
from posthog.session_recordings.sql.session_replay_event_sql import TRUNCATE_SESSION_REPLAY_EVENTS_TABLE_SQL
Expand All @@ -23,6 +24,10 @@

@freeze_time("2021-01-01T13:46:23")
class TestFilterSessionReplaysBySessions(ClickhouseTestMixin, APIBaseTest):
session_with_one_hour = str(uuid7())
session_with_different_session_and_replay_duration = str(uuid7())
session_with_no_events = str(uuid7())

def setUp(self):
super().setUp()

Expand All @@ -33,29 +38,29 @@ def setUp(self):
produce_replay_summary(
team_id=self.team.pk,
distinct_id="d1",
session_id="session_with_one_hour",
session_id=self.session_with_one_hour,
)

_create_event(
event="$pageview",
team=self.team,
distinct_id="d1",
properties={"$current_url": "https://example.com", "$session_id": "session_with_one_hour"},
properties={"$current_url": "https://example.com", "$session_id": self.session_with_one_hour},
)

_create_event(
event="$pageview",
team=self.team,
distinct_id="d1",
properties={"$current_url": "https://example.com", "$session_id": "session_with_one_hour"},
properties={"$current_url": "https://example.com", "$session_id": self.session_with_one_hour},
timestamp=now() + timedelta(hours=1),
)

# 1-hour session replay
produce_replay_summary(
team_id=self.team.pk,
distinct_id="d1",
session_id="session_with_different_session_and_replay_duration",
session_id=self.session_with_different_session_and_replay_duration,
)

_create_event(
Expand All @@ -64,7 +69,7 @@ def setUp(self):
distinct_id="d1",
properties={
"$current_url": "https://different.com",
"$session_id": "session_with_different_session_and_replay_duration",
"$session_id": self.session_with_different_session_and_replay_duration,
},
)

Expand All @@ -74,14 +79,14 @@ def setUp(self):
distinct_id="d1",
properties={
"$current_url": "https://different.com",
"$session_id": "session_with_different_session_and_replay_duration",
"$session_id": self.session_with_different_session_and_replay_duration,
},
# timestamp is two hours in the future
timestamp=now() + timedelta(hours=2),
)

produce_replay_summary(
team_id=self.team.pk, distinct_id="d1", session_id="session_with_no_events", log_messages=None
team_id=self.team.pk, distinct_id="d1", session_id=self.session_with_no_events, log_messages=None
)

@snapshot_clickhouse_queries
Expand All @@ -100,9 +105,9 @@ def test_select_by_duration_without_session_filter(self):
)

assert response.results == [
("session_with_different_session_and_replay_duration",),
("session_with_no_events",),
("session_with_one_hour",),
(self.session_with_different_session_and_replay_duration,),
(self.session_with_no_events,),
(self.session_with_one_hour,),
]

@snapshot_clickhouse_queries
Expand All @@ -122,7 +127,7 @@ def test_select_by_duration_with_session_duration_filter(self):
)

assert response.results == [
("session_with_different_session_and_replay_duration",),
(self.session_with_different_session_and_replay_duration,),
]


Expand Down
15 changes: 15 additions & 0 deletions posthog/hogql/database/schema/test/test_sessions_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ def test_select_event_sessions_star(self):
1,
)

def test_select_session_replay_session_duration(self):
session_id = str(uuid7())

response = self.__execute(
parse_select(
"select raw_session_replay_events.session.duration from raw_session_replay_events",
placeholders={"session_id": ast.Constant(value=session_id)},
),
)

self.assertEqual(
len(response.results or []),
0, # just making sure the query runs
)

def test_channel_type(self):
session_id = "session_test_channel_type"

Expand Down
15 changes: 15 additions & 0 deletions posthog/hogql/database/schema/test/test_sessions_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ def test_select_event_sessions_star(self):
1,
)

def test_select_session_replay_session_duration(self):
session_id = str(uuid7())

response = self.__execute(
parse_select(
"select raw_session_replay_events.session.duration from raw_session_replay_events",
placeholders={"session_id": ast.Constant(value=session_id)},
),
)

self.assertEqual(
len(response.results or []),
0, # just making sure the query runs
)

def test_channel_type(self):
session_id = str(uuid7())

Expand Down

0 comments on commit af276d7

Please sign in to comment.