Skip to content

Commit

Permalink
feat(cohorts): cohorts from HogQLQuery (#25311)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
mariusandra and github-actions[bot] authored Oct 3, 2024
1 parent 35d015c commit 869504a
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 14 deletions.
7 changes: 4 additions & 3 deletions frontend/src/lib/components/ExportButton/exportsLogic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { urls } from 'scenes/urls'

import { sidePanelStateLogic } from '~/layout/navigation-3000/sidepanel/sidePanelStateLogic'
import { cohortsModel } from '~/models/cohortsModel'
import { DataNode } from '~/queries/schema'
import { AnyDataNode } from '~/queries/schema'
import { CohortType, ExportContext, ExportedAssetType, ExporterFormat, LocalExportContext, SidePanelTab } from '~/types'

import type { exportsLogicType } from './exportsLogicType'
Expand All @@ -33,7 +33,7 @@ export const exportsLogic = kea<exportsLogicType>([
pollExportStatus: (exportedAsset: ExportedAssetType) => ({ exportedAsset }),
addFresh: (exportedAsset: ExportedAssetType) => ({ exportedAsset }),
removeFresh: (exportedAsset: ExportedAssetType) => ({ exportedAsset }),
createStaticCohort: (name: string, query: DataNode) => ({ query, name }),
createStaticCohort: (name: string, query: AnyDataNode) => ({ query, name }),
}),

connect({
Expand Down Expand Up @@ -140,8 +140,8 @@ export const exportsLogic = kea<exportsLogicType>([
})
},
createStaticCohort: async ({ query, name }) => {
const toastId = 'toast-' + Math.random()
try {
const toastId = 'toast-' + Math.random()
lemonToast.info('Saving cohort...', { toastId, autoClose: false })
const cohort: CohortType = await api.create('api/cohort', {
is_static: true,
Expand All @@ -159,6 +159,7 @@ export const exportsLogic = kea<exportsLogicType>([
},
})
} catch (e) {
lemonToast.dismiss(toastId)
lemonToast.error('Cohort save failed')
}
},
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/queries/nodes/DataTable/DataTableExport.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ export function DataTableExport({ query }: DataTableExportProps): JSX.Element |
onClick: () => {
LemonDialog.openForm({
title: 'Save as static cohort',
description: 'This will create a cohort with the current results of the query.',
description: 'This will create a cohort with the current list of people.',
initialValues: {
name: '',
},
Expand All @@ -289,7 +289,7 @@ export function DataTableExport({ query }: DataTableExportProps): JSX.Element |
<LemonInput
type="text"
data-attr="insight-name"
placeholder="Please enter a name for the cohort"
placeholder="Name of the new cohort"
autoFocus
/>
</LemonField>
Expand Down
54 changes: 53 additions & 1 deletion frontend/src/scenes/insights/InsightPageHeader.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@ import { EditAlertModal } from 'lib/components/Alerts/views/EditAlertModal'
import { ManageAlertsModal } from 'lib/components/Alerts/views/ManageAlertsModal'
import { EditableField } from 'lib/components/EditableField/EditableField'
import { ExportButton } from 'lib/components/ExportButton/ExportButton'
import { exportsLogic } from 'lib/components/ExportButton/exportsLogic'
import { ObjectTags } from 'lib/components/ObjectTags/ObjectTags'
import { PageHeader } from 'lib/components/PageHeader'
import { SharingModal } from 'lib/components/Sharing/SharingModal'
import { SubscribeButton, SubscriptionsModal } from 'lib/components/Subscriptions/SubscriptionsModal'
import { UserActivityIndicator } from 'lib/components/UserActivityIndicator/UserActivityIndicator'
import { LemonButton } from 'lib/lemon-ui/LemonButton'
import { More } from 'lib/lemon-ui/LemonButton/More'
import { LemonDialog } from 'lib/lemon-ui/LemonDialog'
import { LemonDivider } from 'lib/lemon-ui/LemonDivider'
import { LemonField } from 'lib/lemon-ui/LemonField'
import { LemonInput } from 'lib/lemon-ui/LemonInput'
import { LemonSwitch } from 'lib/lemon-ui/LemonSwitch'
import { deleteInsightWithUndo } from 'lib/utils/deleteWithUndo'
import { useState } from 'react'
Expand All @@ -34,6 +38,7 @@ import { userLogic } from 'scenes/userLogic'

import { tagsModel } from '~/models/tagsModel'
import { DataTableNode, NodeKind } from '~/queries/schema'
import { isDataTableNode, isDataVisualizationNode, isEventsQuery, isHogQLQuery } from '~/queries/utils'
import {
ExporterFormat,
InsightLogicProps,
Expand Down Expand Up @@ -68,10 +73,11 @@ export function InsightPageHeader({ insightLogicProps }: { insightLogicProps: In
const { duplicateInsight, loadInsights } = useActions(savedInsightsLogic)

// insightDataLogic
const { queryChanged, showQueryEditor, showDebugPanel, hogQL, exportContext } = useValues(
const { query, queryChanged, showQueryEditor, showDebugPanel, hogQL, exportContext } = useValues(
insightDataLogic(insightProps)
)
const { toggleQueryEditorPanel, toggleDebugPanel } = useActions(insightDataLogic(insightProps))
const { createStaticCohort } = useActions(exportsLogic)

// other logics
useMountedLogic(insightCommandLogic(insightProps))
Expand All @@ -83,6 +89,9 @@ export function InsightPageHeader({ insightLogicProps }: { insightLogicProps: In

const [addToDashboardModalOpen, setAddToDashboardModalOpenModal] = useState<boolean>(false)

const showCohortButton =
isDataTableNode(query) || isDataVisualizationNode(query) || isHogQLQuery(query) || isEventsQuery(query)

return (
<>
{hasDashboardItemId && (
Expand Down Expand Up @@ -256,6 +265,49 @@ export function InsightPageHeader({ insightLogicProps }: { insightLogicProps: In
>
Edit SQL directly
</LemonButton>
{/*
    "Save as static cohort": opens a form dialog asking for a cohort name and, on submit,
    hands the insight's HogQL to exportsLogic.createStaticCohort wrapped in a HogQLQuery node.
    Only rendered when `showCohortButton` is true.
*/}
{showCohortButton && (
    <LemonButton
        // Dedicated data-attr: this button saves a cohort, it does not edit SQL, so it
        // must not reuse the "edit-insight-sql" identifier.
        data-attr="insight-save-as-static-cohort"
        onClick={() => {
            LemonDialog.openForm({
                title: 'Save as static cohort',
                description: (
                    <div className="mt-2">
                        Your query must export a <code>person_id</code>,{' '}
                        <code>actor_id</code> or <code>id</code> column,
                        which must match the <code>id</code> of the{' '}
                        <code>persons</code> table
                    </div>
                ),
                initialValues: {
                    name: '',
                },
                content: (
                    <LemonField name="name">
                        <LemonInput
                            data-attr="insight-name"
                            placeholder="Name of the new cohort"
                            autoFocus
                        />
                    </LemonField>
                ),
                errors: {
                    // A cohort cannot be saved without a name.
                    name: (name) =>
                        !name ? 'You must enter a name' : undefined,
                },
                onSubmit: async ({ name }) => {
                    createStaticCohort(name, {
                        kind: NodeKind.HogQLQuery,
                        query: hogQL,
                    })
                },
            })
        }}
        fullWidth
    >
        Save as static cohort
    </LemonButton>
)}
</>
)}
{hasDashboardItemId && (
Expand Down
13 changes: 9 additions & 4 deletions posthog/api/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from posthog.queries.trends.trends_actors import TrendsActors
from posthog.queries.trends.lifecycle_actors import LifecycleActors
from posthog.queries.util import get_earliest_timestamp
from posthog.schema import ActorsQuery
from posthog.schema import ActorsQuery, HogQLQuery
from posthog.tasks.calculate_cohort import (
calculate_cohort_from_list,
insert_cohort_from_feature_flag,
Expand Down Expand Up @@ -181,9 +181,12 @@ def validate_query(self, query: Optional[dict]) -> Optional[dict]:
return None
if not isinstance(query, dict):
raise ValidationError("Query must be a dictionary.")
if query.get("kind") != "ActorsQuery":
raise ValidationError(f"Query must be a ActorsQuery. Got: {query.get('kind')}")
ActorsQuery.model_validate(query)
if query.get("kind") == "ActorsQuery":
ActorsQuery.model_validate(query)
elif query.get("kind") == "HogQLQuery":
HogQLQuery.model_validate(query)
else:
raise ValidationError(f"Query must be an ActorsQuery or HogQLQuery. Got: {query.get('kind')}")
return query

def validate_filters(self, request_filters: dict):
Expand Down Expand Up @@ -268,6 +271,8 @@ def update(self, cohort: Cohort, validated_data: dict, *args: Any, **kwargs: Any
# You can't update a static cohort using the trend/stickiness thing
if request.FILES.get("csv"):
self._calculate_static_by_csv(request.FILES["csv"], cohort)
else:
update_cohort(cohort, initiating_user=request.user)
else:
update_cohort(cohort, initiating_user=request.user)

Expand Down
70 changes: 70 additions & 0 deletions posthog/api/test/test_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,76 @@ def test_creating_update_and_calculating_with_new_cohort_query_dynamic_error(sel
)
self.assertEqual(response.status_code, 400, response.content)

@patch("posthog.api.cohort.report_user_action")
def test_creating_with_query_and_fields(self, patch_capture):
    """Static cohorts can be created from a HogQLQuery.

    Creates four persons (two of them with ``$some_prop == "not it"``) and two
    ``$pageview`` events for ``p4``, then checks how many people end up in a
    static cohort built from various HogQL queries. Finally checks that a query
    against an unsupported table (``groups``) is rejected.
    """
    _create_person(
        distinct_ids=["p1"],
        team_id=self.team.pk,
        properties={"$some_prop": "something"},
    )
    _create_person(
        distinct_ids=["p2"],
        team_id=self.team.pk,
        properties={"$some_prop": "not it"},
    )
    _create_person(
        distinct_ids=["p3"],
        team_id=self.team.pk,
        properties={"$some_prop": "not it"},
    )
    _create_person(distinct_ids=["p4"], team_id=self.team.pk, properties={})
    _create_event(team=self.team, event="$pageview", distinct_id="p4", timestamp=datetime.now())
    _create_event(team=self.team, event="$pageview", distinct_id="p4", timestamp=datetime.now())
    flush_persons_and_events()

    # Upper bound on status polls so a calculation that never finishes fails
    # the test instead of hanging the whole run.
    max_polls = 50

    def _calc(query: str) -> int:
        """Create a static cohort from ``query`` and return how many persons it contains."""
        response = self.client.post(
            f"/api/projects/{self.team.id}/cohorts",
            data={
                "name": "cohort A",
                "is_static": True,
                "query": {
                    "kind": "HogQLQuery",
                    "query": query,
                },
            },
        )
        # Fail with the response body rather than an opaque KeyError on "id".
        self.assertEqual(response.status_code, 201, response.content)
        cohort_id = response.json()["id"]

        polls = 0
        while response.json()["is_calculating"]:
            if polls >= max_polls:
                self.fail(f"Cohort {cohort_id} was still calculating after {max_polls} polls")
            polls += 1
            response = self.client.get(f"/api/projects/{self.team.id}/cohorts/{cohort_id}")
            self.assertEqual(response.status_code, 200, response.content)

        response = self.client.get(f"/api/projects/{self.team.id}/cohorts/{cohort_id}/persons/?cohort={cohort_id}")
        return len(response.json()["results"])

    # works with "actor_id"
    self.assertEqual(2, _calc("select id as actor_id from persons where properties.$some_prop='not it'"))

    # works with "person_id"
    self.assertEqual(2, _calc("select id as person_id from persons where properties.$some_prop='not it'"))

    # works with "id"
    self.assertEqual(2, _calc("select id from persons where properties.$some_prop='not it'"))

    # only "p4" had events
    self.assertEqual(1, _calc("select person_id from events"))

    # works with selecting anything from persons and events
    self.assertEqual(4, _calc("select 1 from persons"))
    self.assertEqual(1, _calc("select 1 from events"))

    # raises on all other cases
    response = self.client.post(
        f"/api/projects/{self.team.id}/cohorts",
        data={
            "name": "cohort A",
            "is_static": True,
            "query": {
                "kind": "HogQLQuery",
                "query": "select 1 from groups",
            },
        },
    )
    self.assertEqual(response.status_code, 500, response.content)

@patch("posthog.api.cohort.report_user_action")
def test_cohort_with_is_set_filter_missing_value(self, patch_capture):
# regression test: Removing `value` was silently failing
Expand Down
31 changes: 28 additions & 3 deletions posthog/models/cohort/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from posthog.clickhouse.query_tagging import tag_queries
from posthog.client import sync_execute
from posthog.constants import PropertyOperatorType
from posthog.hogql import ast
from posthog.hogql.constants import LimitContext
from posthog.hogql.hogql import HogQLContext
from posthog.hogql.modifiers import create_default_modifiers_for_team
Expand Down Expand Up @@ -74,12 +75,36 @@ def format_person_query(cohort: Cohort, index: int, hogql_context: HogQLContext)
def print_cohort_hogql_query(cohort: Cohort, hogql_context: HogQLContext) -> str:
from posthog.hogql_queries.query_runner import get_query_runner

persons_query = cast(dict, cohort.query)
persons_query["select"] = ["id as actor_id"]
if not cohort.query:
raise ValueError("Cohort has no query")

query = get_query_runner(
persons_query, team=cast(Team, cohort.team), limit_context=LimitContext.COHORT_CALCULATION
cast(dict, cohort.query), team=cast(Team, cohort.team), limit_context=LimitContext.COHORT_CALCULATION
).to_query()

select_queries: list[ast.SelectQuery] = [query] if isinstance(query, ast.SelectQuery) else query.select_queries
for select_query in select_queries:
columns: dict[str, ast.Expr] = {}
for expr in select_query.select:
if isinstance(expr, ast.Alias):
columns[expr.alias] = expr.expr
elif isinstance(expr, ast.Field):
columns[str(expr.chain[-1])] = expr
column: ast.Expr | None = columns.get("person_id") or columns.get("actor_id") or columns.get("id")
if isinstance(column, ast.Alias):
select_query.select = [ast.Alias(expr=column.expr, alias="actor_id")]
elif isinstance(column, ast.Field):
select_query.select = [ast.Alias(expr=column, alias="actor_id")]
else:
# Support the most common use cases
table = select_query.select_from.table if select_query.select_from else None
if isinstance(table, ast.Field) and table.chain[-1] == "events":
select_query.select = [ast.Alias(expr=ast.Field(chain=["person", "id"]), alias="actor_id")]
elif isinstance(table, ast.Field) and table.chain[-1] == "persons":
select_query.select = [ast.Alias(expr=ast.Field(chain=["id"]), alias="actor_id")]
else:
raise ValueError("Could not find a person_id, actor_id, or id column in the query")

hogql_context.enable_select_queries = True
hogql_context.limit_top_select = False
create_default_modifiers_for_team(cohort.team, hogql_context.modifiers)
Expand Down
2 changes: 1 addition & 1 deletion posthog/models/person/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@

INSERT_COHORT_ALL_PEOPLE_THROUGH_PERSON_ID = """
INSERT INTO {cohort_table} SELECT generateUUIDv4(), actor_id, %(cohort_id)s, %(team_id)s, %(_timestamp)s, 0 FROM (
SELECT actor_id FROM ({query})
SELECT DISTINCT actor_id FROM ({query})
)
"""

Expand Down

0 comments on commit 869504a

Please sign in to comment.