Skip to content

Commit

Permalink
fix(insights): support cohort filters in lifecycle insight
Browse files (browse the repository at this point in the history)
  • Loading branch information
mariusandra committed Dec 22, 2023
1 parent 409a5b9 commit 0ac72e9
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 11 deletions.
4 changes: 2 additions & 2 deletions frontend/src/queries/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@
"type": "string"
},
"value": {
"type": "number"
"type": "integer"
}
},
"required": ["key", "type", "value"],
Expand Down Expand Up @@ -2607,7 +2607,7 @@
"type": "string"
},
"order": {
"type": "number"
"type": "integer"
},
"type": {
"$ref": "#/definitions/EntityType"
Expand Down
3 changes: 2 additions & 1 deletion frontend/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ export interface SessionPropertyFilter extends BasePropertyFilter {
export interface CohortPropertyFilter extends BasePropertyFilter {
type: PropertyFilterType.Cohort
key: 'id'
/** @asType integer */
value: number
}

Expand Down Expand Up @@ -1873,7 +1874,7 @@ export interface RetentionEntity {
kind?: NodeKind.ActionsNode | NodeKind.EventsNode
name?: string
type?: EntityType
// @asType integer
/** @asType integer */
order?: number
uuid?: string
custom_name?: string
Expand Down
8 changes: 6 additions & 2 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
escape_hogql_string,
)
from posthog.hogql.functions.mapping import ALL_EXPOSED_FUNCTION_NAMES, validate_function_args, HOGQL_COMPARISON_MAPPING
from posthog.hogql.modifiers import create_default_modifiers_for_team
from posthog.hogql.resolver import ResolverException, resolve_types
from posthog.hogql.resolver_utils import lookup_field_by_name
from posthog.hogql.transforms.in_cohort import resolve_in_cohorts
Expand All @@ -38,6 +39,7 @@
from posthog.hogql.visitor import Visitor, clone_expr
from posthog.models.property import PropertyName, TableColumn
from posthog.models.team.team import WeekStartDay
from posthog.models.team import Team
from posthog.models.utils import UUIDT
from posthog.schema import MaterializationMode
from posthog.utils import PersonOnEventsMode
Expand All @@ -56,12 +58,14 @@ def team_id_guard_for_table(table_type: Union[ast.TableType, ast.TableAliasType]
)


def to_printed_hogql(query: ast.Expr, team_id: int) -> str:
def to_printed_hogql(query: ast.Expr, team: Team) -> str:
"""Prints the HogQL query without mutating the node"""
return print_ast(
clone_expr(query),
dialect="hogql",
context=HogQLContext(team_id=team_id, enable_select_queries=True),
context=HogQLContext(
team_id=team.pk, enable_select_queries=True, modifiers=create_default_modifiers_for_team(team)
),
pretty=True,
)

Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql/test/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _pretty(self, query: str):

def test_to_printed_hogql(self):
expr = parse_select("select 1 + 2, 3 from events")
repsponse = to_printed_hogql(expr, self.team.pk)
repsponse = to_printed_hogql(expr, self.team)
self.assertEqual(repsponse, "SELECT\n plus(1, 2),\n 3\nFROM\n events\nLIMIT 10000")

def test_literals(self):
Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql_queries/insights/lifecycle_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def to_actors_query(

def calculate(self) -> LifecycleQueryResponse:
query = self.to_query()
hogql = to_printed_hogql(query, self.team.pk)
hogql = to_printed_hogql(query, self.team)

response = execute_hogql_query(
query_type="LifecycleQuery",
Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql_queries/insights/retention_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _refresh_frequency(self):

def calculate(self) -> RetentionQueryResponse:
query = self.to_query()
hogql = to_printed_hogql(query, self.team.pk)
hogql = to_printed_hogql(query, self.team)

response = execute_hogql_query(
query_type="RetentionQuery",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,99 @@
# name: TestLifecycleQueryRunner.test_cohort_filter
'

SELECT count(DISTINCT person_id)
FROM cohortpeople
WHERE team_id = 2
AND cohort_id = 2
AND version = NULL
'
---
# name: TestLifecycleQueryRunner.test_cohort_filter.1
'

SELECT count(DISTINCT person_id)
FROM cohortpeople
WHERE team_id = 2
AND cohort_id = 2
AND version = 0
'
---
# name: TestLifecycleQueryRunner.test_cohort_filter.2
'
SELECT groupArray(start_of_period) AS date,
groupArray(counts) AS total,
status AS status
FROM
(SELECT if(ifNull(equals(status, 'dormant'), 0), negate(sum(counts)), negate(negate(sum(counts)))) AS counts,
start_of_period AS start_of_period,
status AS status
FROM
(SELECT periods.start_of_period AS start_of_period,
0 AS counts,
sec.status AS status
FROM
(SELECT minus(toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-19 23:59:59', 6, 'UTC'))), toIntervalDay(numbers.number)) AS start_of_period
FROM numbers(dateDiff('day', toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-12 00:00:00', 6, 'UTC'))), toStartOfDay(plus(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-19 23:59:59', 6, 'UTC')), toIntervalDay(1))))) AS numbers) AS periods
CROSS JOIN
(SELECT status
FROM
(SELECT 1) ARRAY
JOIN ['new', 'returning', 'resurrecting', 'dormant'] AS status) AS sec
ORDER BY sec.status ASC, start_of_period ASC
UNION ALL SELECT start_of_period AS start_of_period,
count(DISTINCT person_id) AS counts,
status AS status
FROM
(SELECT events__pdi__person.id AS person_id,
min(toTimeZone(events__pdi__person.created_at, 'UTC')) AS created_at,
arraySort(groupUniqArray(toStartOfDay(toTimeZone(events.timestamp, 'UTC')))) AS all_activity,
arrayPopBack(arrayPushFront(all_activity, toStartOfDay(created_at))) AS previous_activity,
arrayPopFront(arrayPushBack(all_activity, toStartOfDay(parseDateTime64BestEffortOrNull('1970-01-01 00:00:00', 6, 'UTC')))) AS following_activity,
arrayMap((previous, current, index) -> if(ifNull(equals(previous, current), isNull(previous)
and isNull(current)), 'new', if(and(ifNull(equals(minus(current, toIntervalDay(1)), previous), isNull(minus(current, toIntervalDay(1)))
and isNull(previous)), ifNull(notEquals(index, 1), 1)), 'returning', 'resurrecting')), previous_activity, all_activity, arrayEnumerate(all_activity)) AS initial_status,
arrayMap((current, next) -> if(ifNull(equals(plus(current, toIntervalDay(1)), next), isNull(plus(current, toIntervalDay(1)))
and isNull(next)), '', 'dormant'), all_activity, following_activity) AS dormant_status,
arrayMap(x -> plus(x, toIntervalDay(1)), arrayFilter((current, is_dormant) -> ifNull(equals(is_dormant, 'dormant'), 0), all_activity, dormant_status)) AS dormant_periods,
arrayMap(x -> 'dormant', dormant_periods) AS dormant_label,
arrayConcat(arrayZip(all_activity, initial_status), arrayZip(dormant_periods, dormant_label)) AS temp_concat,
arrayJoin(temp_concat) AS period_status_pairs,
period_status_pairs.1 AS start_of_period,
period_status_pairs.2 AS status
FROM events
INNER JOIN
(SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id,
person_distinct_id2.distinct_id AS distinct_id
FROM person_distinct_id2
WHERE equals(person_distinct_id2.team_id, 2)
GROUP BY person_distinct_id2.distinct_id
HAVING ifNull(equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id)
INNER JOIN
(SELECT argMax(person.created_at, person.version) AS created_at,
person.id AS id
FROM person
WHERE equals(person.team_id, 2)
GROUP BY person.id
HAVING ifNull(equals(argMax(person.is_deleted, person.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS events__pdi__person ON equals(events__pdi.person_id, events__pdi__person.id)
WHERE and(equals(events.team_id, 2), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), minus(toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-12 00:00:00', 6, 'UTC'))), toIntervalDay(1))), less(toTimeZone(events.timestamp, 'UTC'), plus(toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-19 23:59:59', 6, 'UTC'))), toIntervalDay(1))), ifNull(in(person_id,
(SELECT cohortpeople.person_id AS person_id
FROM cohortpeople
WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 12))
GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version
HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0), equals(events.event, '$pageview'))
GROUP BY person_id)
GROUP BY start_of_period,
status)
WHERE and(ifNull(lessOrEquals(start_of_period, toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-19 23:59:59', 6, 'UTC')))), 0), ifNull(greaterOrEquals(start_of_period, toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-12 00:00:00', 6, 'UTC')))), 0))
GROUP BY start_of_period,
status
ORDER BY start_of_period ASC)
GROUP BY status
LIMIT 100 SETTINGS readonly=2,
max_execution_time=60,
allow_experimental_object_type=1
'
---
# name: TestLifecycleQueryRunner.test_sampling
'
SELECT groupArray(start_of_period) AS date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
PropertyOperator,
PersonPropertyFilter,
ActionsNode,
CohortPropertyFilter,
)
from posthog.test.base import (
APIBaseTest,
Expand All @@ -22,7 +23,7 @@
flush_persons_and_events,
snapshot_clickhouse_queries,
)
from posthog.models import Action, ActionStep
from posthog.models import Action, ActionStep, Cohort
from posthog.models.instance_setting import get_instance_setting


Expand Down Expand Up @@ -1252,6 +1253,55 @@ def test_sampling(self):
),
).calculate()

@snapshot_clickhouse_queries
def test_cohort_filter(self):
self._create_events(
data=[
(
"p1",
[
"2020-01-11T12:00:00Z",
"2020-01-12T12:00:00Z",
"2020-01-13T12:00:00Z",
"2020-01-15T12:00:00Z",
"2020-01-17T12:00:00Z",
"2020-01-19T12:00:00Z",
],
),
("p2", ["2020-01-09T12:00:00Z", "2020-01-12T12:00:00Z"]),
("p3", ["2020-01-12T12:00:00Z"]),
("p4", ["2020-01-15T12:00:00Z"]),
]
)
flush_persons_and_events()
cohort = Cohort.objects.create(
team=self.team,
groups=[
{
"properties": [
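{
"key": "email",
"value": ["[email protected]"],  # NOTE: address obfuscated by the page's email protection during capture; the original literal is not recoverable from this page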
"type": "person",
"operator": "exact",
}
]
}
],
)
cohort.calculate_people_ch(pending_version=0)
response = LifecycleQueryRunner(
team=self.team,
query=LifecycleQuery(
dateRange=DateRange(date_from="2020-01-12T00:00:00Z", date_to="2020-01-19T00:00:00Z"),
interval=IntervalType.day,
series=[EventsNode(event="$pageview")],
properties=[CohortPropertyFilter(value=cohort.pk)],
),
).calculate()
counts = [r["count"] for r in response.results]
assert counts == [0, 2, 3, -3]


def assertLifecycleResults(results, expected):
sorted_results = [{"status": r["status"], "data": r["data"]} for r in sorted(results, key=lambda r: r["status"])]
Expand Down
4 changes: 2 additions & 2 deletions posthog/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class CohortPropertyFilter(BaseModel):
key: Literal["id"] = "id"
label: Optional[str] = None
type: Literal["cohort"] = "cohort"
value: float
value: int


class CountPerActorMathType(str, Enum):
Expand Down Expand Up @@ -505,7 +505,7 @@ class RetentionEntity(BaseModel):
id: Optional[Union[str, float]] = None
kind: Optional[Kind] = None
name: Optional[str] = None
order: Optional[float] = None
order: Optional[int] = None
type: Optional[EntityType] = None
uuid: Optional[str] = None

Expand Down

0 comments on commit 0ac72e9

Please sign in to comment.