Commit

hm
aspicer committed Jun 20, 2024
1 parent ad5d43a commit e8bfd53
Showing 7 changed files with 112 additions and 54 deletions.
4 changes: 2 additions & 2 deletions posthog/api/test/test_cohort.py
@@ -33,7 +33,7 @@
class TestCohort(TestExportMixin, ClickhouseTestMixin, APIBaseTest, QueryMatchingTest):
# select all queries for snapshots
def capture_select_queries(self):
return self.capture_queries(("INSERT INTO cohortpeople", "SELECT", "ALTER", "select", "DELETE"))
return self.capture_queries_startswith(("INSERT INTO cohortpeople", "SELECT", "ALTER", "select", "DELETE"))

def _get_cohort_activity(
self,
@@ -101,7 +101,7 @@ def test_creating_update_and_calculating(self, patch_sync_execute, patch_calcula
},
)

with self.capture_queries("INSERT INTO cohortpeople") as insert_statements:
with self.capture_queries_startswith("INSERT INTO cohortpeople") as insert_statements:
response = self.client.patch(
f"/api/projects/{self.team.id}/cohorts/{response.json()['id']}",
data={
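The rename at these call sites is mechanical: `capture_queries_startswith` (added to posthog/test/base.py at the bottom of this commit) preserves the old tuple-of-prefixes behavior by delegating to the generalized `capture_queries`. A minimal, self-contained sketch of the `str.startswith` behavior the new helper relies on:

```python
# str.startswith accepts a tuple of prefixes, which is what the renamed
# capture_queries_startswith helper relies on under the hood.
prefixes = ("INSERT INTO cohortpeople", "SELECT", "ALTER", "select", "DELETE")

assert "SELECT 1".startswith(prefixes)
assert "INSERT INTO cohortpeople VALUES (1)".startswith(prefixes)
assert not "OPTIMIZE TABLE sharded_events".startswith(prefixes)
```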
2 changes: 1 addition & 1 deletion posthog/hogql/visitor.py
@@ -114,7 +114,7 @@ def visit_join_expr(self, node: ast.JoinExpr):

def visit_select_query(self, node: ast.SelectQuery):
# :TRICKY: when adding new fields, also add them to visit_select_query of resolver.py
# pass the CTEs of the node to its children
# pass the CTEs of the node to select_from (needed for nested joins to have access to CTEs)
if node.type is not None and node.type.ctes is not None and node.select_from is not None and hasattr(node.select_from.type, "ctes"):
node.select_from.type.ctes = {**node.type.ctes, **node.select_from.type.ctes}
self.visit(node.select_from)
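A toy illustration of the merge semantics used above, runnable on its own: because the parent's CTEs are spread first, any CTE the subquery's own type already defines wins.

```python
# {**parent, **child}: later spreads win, so a CTE defined on the subquery's own
# type shadows the same-named CTE inherited from the enclosing SELECT.
parent_ctes = {"person_ids": "parent definition", "shared": "parent definition"}
child_ctes = {"shared": "child definition"}

merged = {**parent_ctes, **child_ctes}
assert merged == {
    "person_ids": "parent definition",  # inherited from the enclosing query
    "shared": "child definition",  # subquery's own CTE shadows the parent's
}
```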
15 changes: 7 additions & 8 deletions posthog/hogql_queries/actors_query_runner.py
@@ -272,14 +272,13 @@ def to_query(self) -> ast.SelectQuery:
ctes = {
source_alias: ast.CTE(name=source_alias, expr=source_query, cte_type="subquery"),
}
if True:
if isinstance(self.strategy, PersonStrategy) and any(
isinstance(x, C) for x in [self.query.source.source] for C in (TrendsQuery,)
):
s = parse_select("SELECT distinct actor_id as person_id FROM source")
s.select_from.table = source_query
# How to get rid of the extra superfluous select
ctes["person_ids"] = ast.CTE(name="person_ids", expr=s, cte_type="subquery")
if isinstance(self.strategy, PersonStrategy) and any(
isinstance(x, C) for x in [self.query.source.source] for C in (TrendsQuery,)
):
s = parse_select("SELECT distinct actor_id as person_id FROM source")
s.select_from.table = source_query
# This feels like it adds one extra level of SELECT which is unnecessary
ctes["person_ids"] = ast.CTE(name="person_ids", expr=s, cte_type="subquery")

stmt = ast.SelectQuery(
ctes=ctes,
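A sketch of what the new branch constructs, pulled out as a standalone function for readability. It assumes the PostHog HogQL modules (`posthog.hogql.ast`, `posthog.hogql.parser.parse_select`) and is an illustration, not a drop-in replacement for the code above:

```python
# Hypothetical helper mirroring the branch above: expose the set of actor ids from
# the trends source query as a `person_ids` CTE, so the persons table scan can be
# pre-filtered with in(person.id, ...) — visible in the snapshot changes below.
from posthog.hogql import ast
from posthog.hogql.parser import parse_select


def build_person_ids_cte(source_query: ast.SelectQuery) -> ast.CTE:
    s = parse_select("SELECT DISTINCT actor_id AS person_id FROM source")
    # Splice the real source query in place of the placeholder "source" table.
    s.select_from.table = source_query
    return ast.CTE(name="person_ids", expr=s, cte_type="subquery")
```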
@@ -1,4 +1,5 @@
from typing import Any, Optional
import re

from freezegun import freeze_time

@@ -217,47 +218,53 @@ def test_insight_persons_trends_query_with_argmaxV2(self):
self.team.timezone = "US/Pacific"
self.team.save()

response = self.select(
"""
select * from (
<ActorsQuery select={['properties.name']}>
<InsightActorsQuery day='2020-01-09'>
<TrendsQuery
dateRange={<InsightDateRange date_from='2020-01-09' date_to='2020-01-19' />}
series={[<EventsNode event='$pageview' />]}
/>
</InsightActorsQuery>
</ActorsQuery>
with self.capture_queries(lambda query: bool(re.match(r"^SELECT\s+name\s+AS\s+name", query))) as queries:
response = self.select(
"""
select * from (
<ActorsQuery select={['properties.name']}>
<InsightActorsQuery day='2020-01-09'>
<TrendsQuery
dateRange={<InsightDateRange date_from='2020-01-09' date_to='2020-01-19' />}
series={[<EventsNode event='$pageview' />]}
/>
</InsightActorsQuery>
</ActorsQuery>
)
""",
modifiers={"personsArgMaxVersion": PersonsArgMaxVersion.V2},
)
""",
modifiers={"personsArgMaxVersion": PersonsArgMaxVersion.V2},
)

self.assertEqual([("p2",)], response.results)
assert "in(distinct_id" in queries[0]
assert "in(person.id" in queries[0]

@snapshot_clickhouse_queries
def test_insight_persons_trends_query_with_argmaxV1(self):
self._create_test_events()
self.team.timezone = "US/Pacific"
self.team.save()

response = self.select(
"""
select * from (
<ActorsQuery select={['properties.name']}>
<InsightActorsQuery day='2020-01-09'>
<TrendsQuery
dateRange={<InsightDateRange date_from='2020-01-09' date_to='2020-01-19' />}
series={[<EventsNode event='$pageview' />]}
/>
</InsightActorsQuery>
</ActorsQuery>
with self.capture_queries(lambda query: bool(re.match(r"^SELECT\s+name\s+AS\s+name", query))) as queries:
response = self.select(
"""
select * from (
<ActorsQuery select={['properties.name']}>
<InsightActorsQuery day='2020-01-09'>
<TrendsQuery
dateRange={<InsightDateRange date_from='2020-01-09' date_to='2020-01-19' />}
series={[<EventsNode event='$pageview' />]}
/>
</InsightActorsQuery>
</ActorsQuery>
)
""",
modifiers={"personsArgMaxVersion": PersonsArgMaxVersion.V1},
)
""",
modifiers={"personsArgMaxVersion": PersonsArgMaxVersion.V1},
)

self.assertEqual([("p2",)], response.results)
assert "in(distinct_id" in queries[0]
assert "in(person.id" in queries[0]

@snapshot_clickhouse_queries
def test_insight_persons_trends_groups_query(self):
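Both tests filter the captured queries with the same regex. A self-contained check of what that filter accepts — the outer persons query projects `name AS name` first, while inner event-scan and CTE queries start differently (the sample strings below are illustrative, not actual snapshots):

```python
import re

# The lambda passed to capture_queries above keeps only the outermost persons
# query, whose first projected column is `name AS name`.
pattern = r"^SELECT\s+name\s+AS\s+name"

assert re.match(pattern, "SELECT name AS name, id AS id FROM persons")
assert not re.match(pattern, "SELECT DISTINCT actor_id AS person_id FROM source")
```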
@@ -245,6 +245,7 @@
groupUniqArray(100)(tuple(timestamp, uuid, `$session_id`, `$window_id`)) AS matching_events
FROM
(SELECT e.person_id AS actor_id,
e.distinct_id AS distinct_id,
toTimeZone(e.timestamp, 'UTC') AS timestamp,
e.uuid AS uuid,
e.`$session_id` AS `$session_id`,
@@ -259,12 +260,28 @@
GROUP BY groups.group_type_index,
groups.group_key) AS e__group_0 ON equals(e.`$group_0`, e__group_0.key)
WHERE and(equals(e.team_id, 2), equals(e.event, 'sign up'), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-02 00:00:00.000000', 6, 'UTC')), less(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-03 00:00:00.000000', 6, 'UTC')), ifNull(equals(e__group_0.properties___industry, 'technology'), 0)))
GROUP BY actor_id) AS source
GROUP BY actor_id SETTINGS use_query_cache=1,
query_cache_ttl=600) AS source
INNER JOIN
(SELECT argMax(person.created_at, person.version) AS created_at,
person.id AS id
FROM person
WHERE equals(person.team_id, 2)
WHERE and(equals(person.team_id, 2), in(person.id,
(SELECT person_ids.person_id AS person_id
FROM
(SELECT DISTINCT actor_id AS person_id
FROM
(SELECT actor_id AS actor_id, count() AS event_count, groupUniqArray(100)(tuple(timestamp, uuid, `$session_id`, `$window_id`)) AS matching_events
FROM
(SELECT e.person_id AS actor_id, e.distinct_id AS distinct_id, e.timestamp AS timestamp, e.uuid AS uuid, e.`$session_id` AS `$session_id`, e.`$window_id` AS `$window_id`
FROM events AS e
LEFT JOIN
(SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(groups.group_properties, 'industry'), ''), 'null'), '^"|"$', ''), groups._timestamp) AS properties___industry, groups.group_type_index AS index, groups.group_key AS key
FROM groups
WHERE and(equals(groups.team_id, 2), ifNull(equals(index, 0), 0))
GROUP BY groups.group_type_index, groups.group_key) AS e__group_0 ON equals(e.`$group_0`, e__group_0.key)
WHERE and(equals(e.team_id, 2), equals(e.event, 'sign up'), greaterOrEquals(e.timestamp, toDateTime64('2020-01-02 00:00:00.000000', 6, 'UTC')), less(e.timestamp, toDateTime64('2020-01-03 00:00:00.000000', 6, 'UTC')), ifNull(equals(e__group_0.properties___industry, 'technology'), 0)))
GROUP BY actor_id SETTINGS use_query_cache=1, query_cache_ttl=600)) AS person_ids)))
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(person.created_at, person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)) SETTINGS optimize_aggregation_in_order=1) AS persons ON equals(persons.id, source.actor_id)
ORDER BY source.event_count DESC
@@ -1050,6 +1067,7 @@
groupUniqArray(100)(tuple(timestamp, uuid, `$session_id`, `$window_id`)) AS matching_events
FROM
(SELECT e.person_id AS actor_id,
e.distinct_id AS distinct_id,
toTimeZone(e.timestamp, 'UTC') AS timestamp,
e.uuid AS uuid,
e.`$session_id` AS `$session_id`,
@@ -1072,12 +1090,33 @@
GROUP BY groups.group_type_index,
groups.group_key) AS e__group_0 ON equals(e.`$group_0`, e__group_0.key)
WHERE and(equals(e.team_id, 2), equals(e.event, 'sign up'), and(ifNull(equals(e__group_0.properties___industry, 'finance'), 0), ifNull(equals(e__group_2.properties___name, 'six'), 0)), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-02 00:00:00.000000', 6, 'UTC')), less(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-03 00:00:00.000000', 6, 'UTC'))))
GROUP BY actor_id) AS source
GROUP BY actor_id SETTINGS use_query_cache=1,
query_cache_ttl=600) AS source
INNER JOIN
(SELECT argMax(person.created_at, person.version) AS created_at,
person.id AS id
FROM person
WHERE equals(person.team_id, 2)
WHERE and(equals(person.team_id, 2), in(person.id,
(SELECT person_ids.person_id AS person_id
FROM
(SELECT DISTINCT actor_id AS person_id
FROM
(SELECT actor_id AS actor_id, count() AS event_count, groupUniqArray(100)(tuple(timestamp, uuid, `$session_id`, `$window_id`)) AS matching_events
FROM
(SELECT e.person_id AS actor_id, e.distinct_id AS distinct_id, e.timestamp AS timestamp, e.uuid AS uuid, e.`$session_id` AS `$session_id`, e.`$window_id` AS `$window_id`
FROM events AS e
LEFT JOIN
(SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(groups.group_properties, 'name'), ''), 'null'), '^"|"$', ''), groups._timestamp) AS properties___name, groups.group_type_index AS index, groups.group_key AS key
FROM groups
WHERE and(equals(groups.team_id, 2), ifNull(equals(index, 2), 0))
GROUP BY groups.group_type_index, groups.group_key) AS e__group_2 ON equals(e.`$group_2`, e__group_2.key)
LEFT JOIN
(SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(groups.group_properties, 'industry'), ''), 'null'), '^"|"$', ''), groups._timestamp) AS properties___industry, groups.group_type_index AS index, groups.group_key AS key
FROM groups
WHERE and(equals(groups.team_id, 2), ifNull(equals(index, 0), 0))
GROUP BY groups.group_type_index, groups.group_key) AS e__group_0 ON equals(e.`$group_0`, e__group_0.key)
WHERE and(equals(e.team_id, 2), equals(e.event, 'sign up'), and(ifNull(equals(e__group_0.properties___industry, 'finance'), 0), ifNull(equals(e__group_2.properties___name, 'six'), 0)), greaterOrEquals(e.timestamp, toDateTime64('2020-01-02 00:00:00.000000', 6, 'UTC')), less(e.timestamp, toDateTime64('2020-01-03 00:00:00.000000', 6, 'UTC'))))
GROUP BY actor_id SETTINGS use_query_cache=1, query_cache_ttl=600)) AS person_ids)))
GROUP BY person.id
HAVING and(ifNull(equals(argMax(person.is_deleted, person.version), 0), 0), ifNull(less(argMax(person.created_at, person.version), plus(now64(6, 'UTC'), toIntervalDay(1))), 0)) SETTINGS optimize_aggregation_in_order=1) AS persons ON equals(persons.id, source.actor_id)
ORDER BY source.event_count DESC
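Two things changed in these snapshots: the events CTE now carries `SETTINGS use_query_cache=1, query_cache_ttl=600`, and the persons scan is pre-filtered with `in(person.id, (SELECT person_ids.person_id ...))`. A hedged sketch of how the cache settings end up on the CTE; the `HogQLQuerySettings` field names and import path are assumptions inferred from the rendered SQL, since the file that sets them is not fully shown in this diff:

```python
# Sketch (assumed field names): attach ClickHouse query-cache settings to the
# events CTE so it renders as "SETTINGS use_query_cache=1, query_cache_ttl=600".
from posthog.hogql import ast
from posthog.hogql.ast import HogQLQuerySettings  # import path assumed


def with_query_cache(cte_events_query: ast.SelectQuery) -> ast.SelectQuery:
    if cte_events_query.settings is None:
        cte_events_query.settings = HogQLQuerySettings()
    cte_events_query.settings.use_query_cache = True  # rendered as use_query_cache=1
    cte_events_query.settings.query_cache_ttl = 600
    return cte_events_query
```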
@@ -28,6 +28,8 @@
TrendsFilter,
TrendsQuery,
CompareFilter,
BreakdownType,
PersonPropertyFilter,
)
from posthog.settings import HOGQL_INCREASED_MAX_EXECUTION_TIME

@@ -166,7 +168,6 @@ def is_total_value(self) -> bool:
return self.trends_display.is_total_value()

def build_actors_query(self) -> ast.SelectQuery | ast.SelectUnionQuery:
# Insert CTE here
cte_events_query = self._cte_events_query()
if cte_events_query.settings is None:
cte_events_query.settings = HogQLQuerySettings()
@@ -215,13 +216,14 @@ def _get_events_query(self) -> ast.SelectQuery:

def _cte_events_query(self) -> ast.SelectQuery:
return ast.SelectQuery(
select=[ast.Field(chain=["*"])], # Filter this down to save space
# Could filter this down to what we actually use to save memory
select=[ast.Field(chain=["*"])],
select_from=ast.JoinExpr(
table=ast.Field(chain=["events"]),
alias="e",
sample=self._sample_expr(),
),
where=self._cte_events_where_expr(),
where=self._persons_cte_events_where_expr(),
)

def _get_actor_value_expr(self) -> ast.Expr:
@@ -248,13 +250,17 @@ def _events_where_expr(self, with_breakdown_expr: bool = True) -> ast.And:
]
)

def _cte_events_where_expr(self, with_breakdown_expr: bool = True) -> ast.And:
def _persons_cte_events_where_expr(self, with_breakdown_expr: bool = True) -> ast.And:
return ast.And(
exprs=[
*self._entity_where_expr(),
# *self._prop_where_expr(),
*self._date_where_expr(),
*(self._breakdown_where_expr() if with_breakdown_expr else []),
*(
self._breakdown_where_expr()
if with_breakdown_expr
and (
self.trends_query.breakdownFilter is None
or self.trends_query.breakdownFilter.breakdown_type != BreakdownType.PERSON
)
else []
),
*self._filter_empty_actors_expr(),
]
)
@@ -293,12 +299,13 @@ def _entity_where_expr(self) -> list[ast.Expr]:

return conditions

def _prop_where_expr(self) -> list[ast.Expr]:
def _prop_where_expr(self, exclude_person_props: bool = False) -> list[ast.Expr]:
conditions: list[ast.Expr] = []

# Filter Test Accounts
if (
self.trends_query.filterTestAccounts
not exclude_person_props
and self.trends_query.filterTestAccounts
and isinstance(self.team.test_account_filters, list)
and len(self.team.test_account_filters) > 0
):
@@ -307,7 +314,10 @@ def _prop_where_expr(self) -> list[ast.Expr]:

# Properties
if self.trends_query.properties is not None and self.trends_query.properties != []:
conditions.append(property_to_expr(self.trends_query.properties, self.team))
properties = self.trends_query.properties
if exclude_person_props:
properties = [x for x in properties if not isinstance(x, PersonPropertyFilter)]
conditions.append(property_to_expr(properties, self.team))

return conditions

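The intent of `exclude_person_props` is to drop person-property filters from the events-side WHERE clause, since the persons CTE applies them on the persons side (hence the `not isinstance` above). A standalone illustration with stand-in classes; the real filters are the pydantic models in `posthog.schema`:

```python
from dataclasses import dataclass


@dataclass
class PersonPropertyFilter:  # stand-in for posthog.schema.PersonPropertyFilter
    key: str


@dataclass
class EventPropertyFilter:  # stand-in for posthog.schema.EventPropertyFilter
    key: str


filters = [PersonPropertyFilter("email"), EventPropertyFilter("$browser")]

# exclude_person_props=True keeps only the non-person filters on the events side:
events_side = [f for f in filters if not isinstance(f, PersonPropertyFilter)]
assert [f.key for f in events_side] == ["$browser"]
```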
9 changes: 6 additions & 3 deletions posthog/test/base.py
@@ -885,10 +885,13 @@ class ClickhouseTestMixin(QueryMatchingTest):
snapshot: Any

def capture_select_queries(self):
return self.capture_queries(("SELECT", "WITH", "select", "with"))
return self.capture_queries_startswith(("SELECT", "WITH", "select", "with"))

def capture_queries_startswith(self, query_prefixes: Union[str, tuple[str, ...]]):
return self.capture_queries(lambda x: x.startswith(query_prefixes))

@contextmanager
def capture_queries(self, query_prefixes: Union[str, tuple[str, ...]]):
def capture_queries(self, query_filter: Callable[[str], bool]):
queries = []
original_get_client = ch_pool.get_client

@@ -901,7 +904,7 @@ def get_client():
original_client_execute = client.execute

def execute_wrapper(query, *args, **kwargs):
if sqlparse.format(query, strip_comments=True).strip().startswith(query_prefixes):
if query_filter(sqlparse.format(query, strip_comments=True).strip()):
queries.append(query)
return original_client_execute(query, *args, **kwargs)

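A condensed, self-contained version of the refactor above: `capture_queries` now takes an arbitrary predicate, and the old prefix behavior lives in a thin wrapper. The ClickHouse client monkeypatching is elided here; only the filtering shape is shown.

```python
from contextlib import contextmanager
from typing import Callable, Union


@contextmanager
def capture_queries(query_filter: Callable[[str], bool]):
    """Collect every executed query for which query_filter returns True."""
    queries: list[str] = []
    # In posthog/test/base.py this wraps ch_pool.get_client and intercepts
    # client.execute, appending each matching query; here we just yield the list.
    yield queries


def capture_queries_startswith(query_prefixes: Union[str, tuple[str, ...]]):
    """Backwards-compatible prefix matching on top of the predicate API."""
    return capture_queries(lambda query: query.startswith(query_prefixes))
```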
