Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: wip #23096

Closed
wants to merge 17 commits into from
4 changes: 2 additions & 2 deletions posthog/api/test/test_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class TestCohort(TestExportMixin, ClickhouseTestMixin, APIBaseTest, QueryMatchingTest):
# select all queries for snapshots
# Capture the queries whose text matches one of the listed prefixes/keywords for snapshot comparison.
# NOTE(review): diff markers were lost in extraction — the two `return` lines below are presumably the
# removed (`capture_queries`) and added (`capture_queries_startswith`) forms of the same statement.
def capture_select_queries(self):
return self.capture_queries(("INSERT INTO cohortpeople", "SELECT", "ALTER", "select", "DELETE"))
return self.capture_queries_startswith(("INSERT INTO cohortpeople", "SELECT", "ALTER", "select", "DELETE"))

def _get_cohort_activity(
self,
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_creating_update_and_calculating(self, patch_sync_execute, patch_calcula
},
)

with self.capture_queries("INSERT INTO cohortpeople") as insert_statements:
with self.capture_queries_startswith("INSERT INTO cohortpeople") as insert_statements:
response = self.client.patch(
f"/api/projects/{self.team.id}/cohorts/{response.json()['id']}",
data={
Expand Down
2 changes: 2 additions & 0 deletions posthog/hogql/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def get_breakdown_limit_for_context(limit_context: LimitContext) -> int:
# Optional per-query settings; each setting defaults to None (unset).
# NOTE(review): indentation was stripped by the page extraction — the lines below are the class body.
class HogQLQuerySettings(BaseModel):
# extra="forbid": presumably makes the model reject unknown setting names — confirm against BaseModel's origin.
model_config = ConfigDict(extra="forbid")
optimize_aggregation_in_order: Optional[bool] = None
# The two settings below are the additions in this hunk. The snapshot hunks in this change print them as
# `SETTINGS use_query_cache=1, query_cache_ttl=600` on the actors "source" subquery.
use_query_cache: Optional[bool] = None
# Presumably a TTL in seconds — TODO confirm.
# NOTE(review): the actors query runner hunk assigns HOGQL_INCREASED_MAX_EXECUTION_TIME to this field, i.e. a
# timeout constant reused as a cache lifetime — confirm that coupling is intended.
query_cache_ttl: Optional[int] = None


# Settings applied on top of all HogQL queries.
Expand Down
36 changes: 31 additions & 5 deletions posthog/hogql/database/schema/person_distinct_ids.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from posthog.hogql.ast import SelectQuery
from posthog.hogql import ast
from posthog.hogql.ast import SelectQuery, And
from posthog.hogql.context import HogQLContext

from posthog.hogql.database.argmax import argmax_select
Expand All @@ -15,6 +16,7 @@
)
from posthog.hogql.database.schema.persons import join_with_persons_table
from posthog.hogql.errors import ResolutionError
from posthog.hogql.visitor import clone_expr

PERSON_DISTINCT_IDS_FIELDS = {
"team_id": IntegerDatabaseField(name="team_id"),
Expand All @@ -28,18 +30,42 @@
}


# Build the SELECT used for the lazy person_distinct_ids table/join from `raw_person_distinct_ids`.
# NOTE(review): this span is diff residue — indentation and +/- markers were lost. The next line is presumably
# the removed one-argument signature, and the three lines after it the added three-argument signature.
def select_from_person_distinct_ids_table(requested_fields: dict[str, list[str | int]]):
def select_from_person_distinct_ids_table(
requested_fields: dict[str, list[str | int]], context: HogQLContext, node: SelectQuery
):
# NOTE(review): `context` is accepted but never read in the visible body — confirm it is needed.
# Always include "person_id", as it's the key we use to make further joins, and it'd be great if it's available
if "person_id" not in requested_fields:
# Builds a new dict instead of mutating the caller's mapping.
requested_fields = {**requested_fields, "person_id": ["person_id"]}
# NOTE(review): the next two lines are presumably the removed (`return ...`) and added (`select = ...`)
# openings of the same argmax_select call.
return argmax_select(
select = argmax_select(
table_name="raw_person_distinct_ids",
select_fields=requested_fields,
group_fields=["distinct_id"],
argmax_field="version",
deleted_field="is_deleted",
)

# When a CTE named "distinct_ids" is registered on the node's type, restrict the rows to
# `distinct_id IN (SELECT distinct_id FROM distinct_ids)`.
# NOTE(review): `node.type` and `node.type.ctes` are dereferenced without a None check, whereas the
# visit_select_query hunk in this same change guards both — confirm they are always set on this path.
if "distinct_ids" in node.type.ctes:
# NOTE(review): the `type=` given to the left Field is presumably discarded by clear_types=True, and
# table_type receives the PersonDistinctIdsTable class itself rather than an instance — confirm.
comparison = clone_expr(
ast.CompareOperation(
op=ast.CompareOperationOp.In,
left=ast.Field(
chain=["distinct_id"], type=ast.FieldType(name="distinct_id", table_type=PersonDistinctIdsTable)
),
right=ast.SelectQuery(
select=[ast.Field(chain=["distinct_id"])],
select_from=ast.JoinExpr(table=ast.Field(chain=["distinct_ids"])),
),
),
clear_types=True,
clear_locations=True,
)
# AND the new filter in front of any WHERE already present on the generated select.
if select.where:
select.where = And(exprs=[comparison, select.where])
else:
select.where = comparison

return select


def join_with_person_distinct_ids_table(
join_to_add: LazyJoinToAdd,
Expand All @@ -50,7 +76,7 @@ def join_with_person_distinct_ids_table(

if not join_to_add.fields_accessed:
raise ResolutionError("No fields requested from person_distinct_ids")
join_expr = ast.JoinExpr(table=select_from_person_distinct_ids_table(join_to_add.fields_accessed))
join_expr = ast.JoinExpr(table=select_from_person_distinct_ids_table(join_to_add.fields_accessed, context, node))
join_expr.join_type = "INNER JOIN"
join_expr.alias = join_to_add.to_table
join_expr.constraint = ast.JoinConstraint(
Expand Down Expand Up @@ -82,7 +108,7 @@ class PersonDistinctIdsTable(LazyTable):
fields: dict[str, FieldOrTable] = PERSON_DISTINCT_IDS_FIELDS

# Produce the SELECT for this lazy table by delegating to select_from_person_distinct_ids_table with the
# fields accessed on it.
# NOTE(review): diff markers were lost — the two `return` lines below are presumably the removed call
# (fields only) and the added call (fields, context, node).
def lazy_select(self, table_to_add: LazyTableToAdd, context, node):
return select_from_person_distinct_ids_table(table_to_add.fields_accessed)
return select_from_person_distinct_ids_table(table_to_add.fields_accessed, context, node)

# Return the fixed table name "person_distinct_id2"; `context` is not read.
def to_printed_clickhouse(self, context):
return "person_distinct_id2"
Expand Down
19 changes: 19 additions & 0 deletions posthog/hogql/database/schema/persons.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import cast
import posthoganalytics

from hogql_parser import parse_expr
from posthog.hogql.ast import SelectQuery, And
from posthog.hogql.constants import HogQLQuerySettings
from posthog.hogql.context import HogQLContext
Expand All @@ -21,6 +22,7 @@
from posthog.hogql.database.schema.util.where_clause_extractor import WhereClauseExtractor
from posthog.hogql.database.schema.persons_pdi import PersonsPDITable, persons_pdi_join
from posthog.hogql.errors import ResolutionError
from posthog.hogql.visitor import clone_expr
from posthog.models.organization import Organization
from posthog.schema import PersonsArgMaxVersion

Expand Down Expand Up @@ -56,6 +58,17 @@ def select_from_persons_table(join_or_table: LazyJoinToAdd | LazyTableToAdd, con
ast.SelectQuery,
parse_select(
"""
SELECT id FROM raw_persons WHERE (id, version) IN (
SELECT id, max(version) as version
FROM raw_persons
WHERE raw_persons.id in (select person_id from person_ids)
GROUP BY id
HAVING equals(argMax(raw_persons.is_deleted, raw_persons.version), 0)
AND argMax(raw_persons.created_at, raw_persons.version) < now() + interval 1 day
)
"""
if "person_ids" in node.type.ctes
else """
SELECT id FROM raw_persons WHERE (id, version) IN (
SELECT id, max(version) as version
FROM raw_persons
Expand Down Expand Up @@ -88,6 +101,12 @@ def select_from_persons_table(join_or_table: LazyJoinToAdd | LazyTableToAdd, con
timestamp_field_to_clamp="created_at",
)
select.settings = HogQLQuerySettings(optimize_aggregation_in_order=True)
if "person_ids" in node.type.ctes:
expr = parse_expr("raw_persons.id in (select person_id from person_ids)")
if select.where:
select.where = And(exprs=[select.where, expr])
else:
select.where = expr

if context.modifiers.optimizeJoinedFilters:
extractor = WhereClauseExtractor(context)
Expand Down
3 changes: 3 additions & 0 deletions posthog/hogql/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def visit_join_expr(self, node: ast.JoinExpr):

def visit_select_query(self, node: ast.SelectQuery):
# :TRICKY: when adding new fields, also add them to visit_select_query of resolver.py
# pass the CTEs of the node to select_froms (needed for nested joins to have access to CTEs)
if node.type is not None and node.type.ctes is not None and node.select_from is not None and hasattr(node.select_from.type, "ctes"):
node.select_from.type.ctes = {**node.type.ctes, **node.select_from.type.ctes}
self.visit(node.select_from)
if node.ctes is not None:
for expr in list(node.ctes.values()):
Expand Down
58 changes: 52 additions & 6 deletions posthog/hogql_queries/actors_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@
from typing import Optional
from collections.abc import Sequence, Iterator
from posthog.hogql import ast
from posthog.hogql.parser import parse_expr, parse_order_expr
from posthog.hogql.constants import HogQLQuerySettings
from posthog.hogql.parser import parse_expr, parse_order_expr, parse_select
from posthog.hogql.property import has_aggregation
from posthog.hogql_queries.actor_strategies import ActorStrategy, PersonStrategy, GroupStrategy
from posthog.hogql_queries.insights.insight_actors_query_runner import InsightActorsQueryRunner
from posthog.hogql_queries.insights.paginators import HogQLHasMorePaginator
from posthog.hogql_queries.query_runner import QueryRunner, get_query_runner
from posthog.schema import ActorsQuery, ActorsQueryResponse, CachedActorsQueryResponse, DashboardFilter
from posthog.schema import (
ActorsQuery,
ActorsQueryResponse,
CachedActorsQueryResponse,
DashboardFilter,
LifecycleQuery,
StickinessQuery,
TrendsQuery,
)
from posthog.settings import HOGQL_INCREASED_MAX_EXECUTION_TIME


class ActorsQueryRunner(QueryRunner):
Expand Down Expand Up @@ -230,12 +240,48 @@ def to_query(self) -> ast.SelectQuery:
order_by = []

with self.timings.measure("select"):
if self.query.source:
join_expr = self.source_table_join()
else:
join_expr = ast.JoinExpr(table=ast.Field(chain=[self.strategy.origin]))
assert self.source_query_runner is not None # For type checking
source_query = self.source_query_runner.to_actors_query()

# SelectUnionQuery (used by Stickiness) doesn't have settings
if hasattr(source_query, "settings"):
if source_query.settings is None:
source_query.settings = HogQLQuerySettings()
source_query.settings.use_query_cache = True
source_query.settings.query_cache_ttl = HOGQL_INCREASED_MAX_EXECUTION_TIME

source_id_chain = self.source_id_column(source_query)
source_alias = "source"

join_expr = ast.JoinExpr(
table=ast.Field(chain=[source_alias]),
next_join=ast.JoinExpr(
table=ast.Field(chain=[self.strategy.origin]),
join_type="INNER JOIN",
constraint=ast.JoinConstraint(
expr=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Field(chain=[self.strategy.origin, self.strategy.origin_id]),
right=ast.Field(chain=[source_alias, *source_id_chain]),
),
constraint_type="ON",
),
),
)

ctes = {
source_alias: ast.CTE(name=source_alias, expr=source_query, cte_type="subquery"),
}
if isinstance(self.strategy, PersonStrategy) and any(
isinstance(x, C) for x in [self.query.source.source] for C in (TrendsQuery,)
):
s = parse_select("SELECT distinct actor_id as person_id FROM source")
s.select_from.table = source_query
# This feels like it adds one extra level of SELECT which is unnecessary
ctes["person_ids"] = ast.CTE(name="person_ids", expr=s, cte_type="subquery")

stmt = ast.SelectQuery(
ctes=ctes,
select=columns,
select_from=join_expr,
where=where,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@
HAVING ifNull(equals(steps, max_steps), isNull(steps)
and isNull(max_steps)))
WHERE ifNull(in(steps, [2, 3]), 0)
ORDER BY aggregation_target ASC) AS source
ORDER BY aggregation_target ASC SETTINGS use_query_cache=1,
query_cache_ttl=600) AS source
INNER JOIN
(SELECT argMax(person.created_at, person.version) AS created_at,
person.id AS id
Expand Down Expand Up @@ -604,7 +605,8 @@
HAVING ifNull(equals(steps, max_steps), isNull(steps)
and isNull(max_steps)))
WHERE ifNull(in(steps, [1, 2, 3]), 0)
ORDER BY aggregation_target ASC) AS source
ORDER BY aggregation_target ASC SETTINGS use_query_cache=1,
query_cache_ttl=600) AS source
INNER JOIN
(SELECT argMax(person.created_at, person.version) AS created_at,
person.id AS id
Expand Down Expand Up @@ -723,7 +725,8 @@
HAVING ifNull(equals(steps, max_steps), isNull(steps)
and isNull(max_steps)))
WHERE ifNull(in(steps, [2, 3]), 0)
ORDER BY aggregation_target ASC) AS source
ORDER BY aggregation_target ASC SETTINGS use_query_cache=1,
query_cache_ttl=600) AS source
INNER JOIN
(SELECT argMax(person.created_at, person.version) AS created_at,
person.id AS id
Expand Down Expand Up @@ -842,7 +845,8 @@
HAVING ifNull(equals(steps, max_steps), isNull(steps)
and isNull(max_steps)))
WHERE ifNull(in(steps, [3]), 0)
ORDER BY aggregation_target ASC) AS source
ORDER BY aggregation_target ASC SETTINGS use_query_cache=1,
query_cache_ttl=600) AS source
INNER JOIN
(SELECT argMax(person.created_at, person.version) AS created_at,
person.id AS id
Expand Down
Loading
Loading