Skip to content

Commit

Permalink
feat: more manual join removal (#22893)
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldambra authored Jun 12, 2024
1 parent 2e60b33 commit 46eb8fc
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 194 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Any, NamedTuple, cast, Optional
from datetime import datetime, timedelta

Expand All @@ -6,14 +7,26 @@
from posthog.hogql.parser import parse_select
from posthog.hogql.property import entity_to_expr, property_to_expr
from posthog.hogql_queries.insights.paginators import HogQLHasMorePaginator
from posthog.models import Team
from posthog.models import Team, Property
from posthog.models.filters.session_recordings_filter import SessionRecordingsFilter
from posthog.models.filters.mixins.utils import cached_property
from posthog.models.property import PropertyGroup
from posthog.schema import QueryTiming, HogQLQueryModifiers
from posthog.session_recordings.queries.session_replay_events import ttl_days
from posthog.constants import TREND_FILTER_TYPE_ACTIONS, PropertyOperatorType

import structlog

logger = structlog.get_logger(__name__)


def is_event_property(p: Property) -> bool:
return p.type == "event" or (p.type == "hogql" and bool(re.search(r"(?<!person\.)properties\.", p.key)))


def is_person_property(p: Property) -> bool:
return p.type == "person" or (p.type == "hogql" and "person.properties" in p.key)


class SessionRecordingQueryResult(NamedTuple):
results: list
Expand Down Expand Up @@ -158,8 +171,8 @@ def _where_predicates(self) -> ast.And:
)
)

if self._filter.entities:
events_sub_query = EventsSubQuery(self._team, self._filter, self.ttl_days).get_query()
events_sub_query = EventsSubQuery(self._team, self._filter, self.ttl_days).get_query()
if events_sub_query:
exprs.append(
ast.CompareOperation(
op=ast.CompareOperationOp.In,
Expand All @@ -179,9 +192,12 @@ def _where_predicates(self) -> ast.And:
)
)

non_person_properties = self._strip_person_properties(self._filter.property_groups)
if non_person_properties:
exprs.append(property_to_expr(non_person_properties, team=self._team, scope="replay"))
remaining_properties = self._strip_person_and_event_properties(self._filter.property_groups)
if remaining_properties:
logger.info(
"session_replay_query_builder has unhandled properties", unhandled_properties=remaining_properties
)
exprs.append(property_to_expr(remaining_properties, team=self._team, scope="replay"))

person_id_subquery = PersonsIdSubQuery(self._team, self._filter, self.ttl_days).get_query()
if person_id_subquery:
Expand Down Expand Up @@ -254,8 +270,10 @@ def _having_predicates(self) -> ast.And | Constant:

return ast.And(exprs=exprs) if exprs else Constant(value=True)

def _strip_person_properties(self, property_group: PropertyGroup) -> PropertyGroup | None:
property_groups_to_keep = [g for g in property_group.flat if g.type != "person"]
def _strip_person_and_event_properties(self, property_group: PropertyGroup) -> PropertyGroup | None:
property_groups_to_keep = [
g for g in property_group.flat if not is_event_property(g) and not is_person_property(g)
]

return (
PropertyGroup(
Expand Down Expand Up @@ -294,7 +312,7 @@ def get_query(self) -> ast.SelectQuery | ast.SelectUnionQuery | None:

@cached_property
def person_properties(self) -> PropertyGroup | None:
person_property_groups = [g for g in self._filter.property_groups.flat if g.type == "person" in g.type]
person_property_groups = [g for g in self._filter.property_groups.flat if is_person_property(g)]
return (
PropertyGroup(
type=PropertyOperatorType.AND,
Expand Down Expand Up @@ -340,7 +358,7 @@ def get_query(self) -> ast.SelectQuery | ast.SelectUnionQuery | None:

@cached_property
def person_properties(self) -> PropertyGroup | None:
person_property_groups = [g for g in self._filter.property_groups.flat if g.type == "person" in g.type]
person_property_groups = [g for g in self._filter.property_groups.flat if is_person_property(g)]
return (
PropertyGroup(
type=PropertyOperatorType.AND,
Expand Down Expand Up @@ -392,14 +410,17 @@ def _event_predicates(self):

return event_exprs, list(event_names)

def get_query(self):
return ast.SelectQuery(
select=[ast.Alias(alias="session_id", expr=ast.Field(chain=["$session_id"]))],
select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
where=self._where_predicates(),
having=self._having_predicates(),
group_by=[ast.Field(chain=["$session_id"])],
)
def get_query(self) -> ast.SelectQuery | ast.SelectUnionQuery | None:
if self._filter.entities or self.event_properties:
return ast.SelectQuery(
select=[ast.Alias(alias="session_id", expr=ast.Field(chain=["$session_id"]))],
select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
where=self._where_predicates(),
having=self._having_predicates(),
group_by=[ast.Field(chain=["$session_id"])],
)
else:
return None

def _where_predicates(self) -> ast.Expr:
exprs: list[ast.Expr] = [
Expand Down Expand Up @@ -445,6 +466,9 @@ def _where_predicates(self) -> ast.Expr:
if event_where_exprs:
exprs.append(ast.Or(exprs=event_where_exprs))

if self.event_properties:
exprs.append(property_to_expr(self.event_properties, team=self._team, scope="replay"))

if self._filter.session_ids:
exprs.append(
ast.CompareOperation(
Expand All @@ -470,3 +494,7 @@ def _having_predicates(self) -> ast.Expr:
)

return ast.Constant(value=True)

@cached_property
def event_properties(self):
return [g for g in self._filter.property_groups.flat if is_event_property(g)]
Loading

0 comments on commit 46eb8fc

Please sign in to comment.