Skip to content

Commit

Permalink
Push down figuring group_type_index
Browse files: browse the repository at this point in the history
  • Loading branch information
webjunkie committed Dec 17, 2023
1 parent e27bb5e commit a2bf95a
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 26 deletions.
3 changes: 3 additions & 0 deletions posthog/hogql_queries/actor_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


class ActorStrategy:
field: str
origin: str
origin_id: str

Expand All @@ -29,6 +30,7 @@ def filter_conditions(self) -> List[ast.Expr]:


class PersonStrategy(ActorStrategy):
field = "person"
origin = "persons"
origin_id = "id"

Expand Down Expand Up @@ -88,6 +90,7 @@ def filter_conditions(self) -> List[ast.Expr]:


class GroupStrategy(ActorStrategy):
field = "group"
origin = "groups"
origin_id = "key"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ def to_query(self) -> ast.SelectQuery | ast.SelectUnionQuery:
def to_persons_query(self) -> ast.SelectQuery | ast.SelectUnionQuery:
    """Return the query used as the persons source.

    This is a straight delegation: whatever ``to_query`` builds is handed
    back unchanged.
    """
    persons_query = self.to_query()
    return persons_query

@property
def group_type_index(self) -> int | None:
    """Group type index taken from the source runner, when it is a retention runner.

    Returns ``None`` when ``source_runner`` is falsy or is not a
    ``RetentionQueryRunner``; otherwise forwards that runner's own
    ``group_type_index``.
    """
    runner = self.source_runner
    if runner and isinstance(runner, RetentionQueryRunner):
        # cast() is a no-op at runtime; it only narrows the type for checkers.
        return cast(RetentionQueryRunner, runner).group_type_index
    return None

def calculate(self) -> HogQLQueryResponse:
return execute_hogql_query(
query_type="InsightPersonsQuery",
Expand Down
8 changes: 6 additions & 2 deletions posthog/hogql_queries/insights/retention_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def __init__(
):
super().__init__(query, team=team, timings=timings, modifiers=modifiers, limit_context=limit_context)

@property
def group_type_index(self) -> int | None:
    """Expose the query's ``aggregation_group_type_index`` under a shorter name."""
    query = self.query
    return query.aggregation_group_type_index

def get_applicable_entity(self, event_query_type):
default_entity = RetentionEntity(
**{
Expand All @@ -71,8 +75,8 @@ def retention_events_query(self, event_query_type) -> ast.SelectQuery:
event_date_expr = start_of_interval_sql

target_field = "person_id"
if self.query.aggregation_group_type_index is not None:
group_index = int(self.query.aggregation_group_type_index)
if self.group_type_index is not None:
group_index = int(self.group_type_index)
if 0 <= group_index <= 4:
target_field = f"$group_{group_index}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run_actors_query(self, selected_interval, query):
runner = PersonsQueryRunner(
team=self.team,
query={
"select": ["actor", "appearances"],
"select": ["person", "appearances"],
"orderBy": ["appearances_count DESC", "actor_id"],
"source": {
"kind": "InsightPersonsQuery",
Expand Down
40 changes: 17 additions & 23 deletions posthog/hogql_queries/persons_query_runner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from datetime import timedelta
from typing import List, cast, Literal, Generator, Sequence, Iterator
from typing import List, cast, Literal, Generator, Sequence, Iterator, Optional
from posthog.hogql import ast
from posthog.hogql.constants import get_max_limit_for_context, get_default_limit_for_context
from posthog.hogql.parser import parse_expr, parse_order_expr
from posthog.hogql.property import has_aggregation
from posthog.hogql_queries.actor_strategies import ActorStrategy, PersonStrategy, GroupStrategy
from posthog.hogql_queries.insights.insight_persons_query_runner import InsightPersonsQueryRunner
from posthog.hogql_queries.insights.paginators import HogQLHasMorePaginator
from posthog.hogql_queries.query_runner import QueryRunner, get_query_runner
from posthog.schema import PersonsQuery, PersonsQueryResponse, InsightPersonsQuery, StickinessQuery, LifecycleQuery
from posthog.schema import PersonsQuery, PersonsQueryResponse


class PersonsQueryRunner(QueryRunner):
Expand All @@ -17,27 +18,23 @@ class PersonsQueryRunner(QueryRunner):
def __init__(self, *args, **kwargs):
    """Set up pagination, the optional source query runner, and the actor strategy.

    All arguments are forwarded untouched to the parent initializer.
    """
    super().__init__(*args, **kwargs)

    start_offset = self.query.offset or 0
    self.paginator = HogQLHasMorePaginator(limit=self.query_limit(), offset=start_offset)

    # A runner for the source query exists only when the query has a source;
    # otherwise the attribute stays None.
    self.source_query_runner: Optional[QueryRunner] = (
        get_query_runner(self.query.source, self.team, self.timings, self.limit_context)
        if self.query.source
        else None
    )

    # Must run last: it is called only after paginator and source runner are set.
    self.strategy = self.determine_strategy()

@property
def aggregation_group_type_index(self) -> int | None:
if (
not self.query.source
or not isinstance(self.query.source, InsightPersonsQuery)
or isinstance(self.query.source.source, StickinessQuery)
or isinstance(self.query.source.source, LifecycleQuery)
):
return None
try:
return self.query.source.source.aggregation_group_type_index
except AttributeError:
def group_type_index(self) -> int | None:
if not self.source_query_runner or not isinstance(self.source_query_runner, InsightPersonsQueryRunner):
return None

return self.source_query_runner.group_type_index

def determine_strategy(self) -> ActorStrategy:
if self.aggregation_group_type_index is not None:
return GroupStrategy(
self.aggregation_group_type_index, team=self.team, query=self.query, paginator=self.paginator
)
if self.group_type_index is not None:
return GroupStrategy(self.group_type_index, team=self.team, query=self.query, paginator=self.paginator)
return PersonStrategy(team=self.team, query=self.query, paginator=self.paginator)

def enrich_with_actors(self, results, actor_column_index, actors_lookup) -> Generator[List, None, None]:
Expand Down Expand Up @@ -97,8 +94,7 @@ def source_id_column(self, source_query: ast.SelectQuery) -> List[str]:

def source_table_join(self) -> ast.JoinExpr:
assert self.query.source is not None # For type checking
source_query_runner = get_query_runner(self.query.source, self.team, self.timings)
source_query = source_query_runner.to_persons_query()
source_query = self.source_query_runner.to_persons_query()
source_id_chain = self.source_id_column(source_query)
source_alias = "source"

Expand Down Expand Up @@ -126,10 +122,8 @@ def to_query(self) -> ast.SelectQuery:
for expr in self.input_columns():
if expr == "person.$delete":
column = ast.Constant(value=1)
elif expr == "person" or (expr == "actor" and self.aggregation_group_type_index is None):
column = ast.Field(chain=["id"])
elif expr == "group" or (expr == "actor" and self.aggregation_group_type_index is not None):
column = ast.Field(chain=["key"])
elif expr == self.strategy.field:
column = ast.Field(chain=[self.strategy.origin_id])
else:
column = parse_expr(expr)
columns.append(column)
Expand Down

0 comments on commit a2bf95a

Please sign in to comment.