From 869504a29fcab75f08838cbaf9e0ca458bb0e6f9 Mon Sep 17 00:00:00 2001 From: Marius Andra Date: Thu, 3 Oct 2024 10:01:57 +0200 Subject: [PATCH] feat(cohorts): cohorts from HogQLQuery (#25311) Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com> --- .../components/ExportButton/exportsLogic.ts | 7 +- .../nodes/DataTable/DataTableExport.tsx | 4 +- .../src/scenes/insights/InsightPageHeader.tsx | 54 +++++++++++++- posthog/api/cohort.py | 13 ++-- posthog/api/test/test_cohort.py | 70 +++++++++++++++++++ posthog/models/cohort/util.py | 31 +++++++- posthog/models/person/sql.py | 2 +- 7 files changed, 167 insertions(+), 14 deletions(-) diff --git a/frontend/src/lib/components/ExportButton/exportsLogic.ts b/frontend/src/lib/components/ExportButton/exportsLogic.ts index 77e7a32b1ca13..e633513874f02 100644 --- a/frontend/src/lib/components/ExportButton/exportsLogic.ts +++ b/frontend/src/lib/components/ExportButton/exportsLogic.ts @@ -11,7 +11,7 @@ import { urls } from 'scenes/urls' import { sidePanelStateLogic } from '~/layout/navigation-3000/sidepanel/sidePanelStateLogic' import { cohortsModel } from '~/models/cohortsModel' -import { DataNode } from '~/queries/schema' +import { AnyDataNode } from '~/queries/schema' import { CohortType, ExportContext, ExportedAssetType, ExporterFormat, LocalExportContext, SidePanelTab } from '~/types' import type { exportsLogicType } from './exportsLogicType' @@ -33,7 +33,7 @@ export const exportsLogic = kea([ pollExportStatus: (exportedAsset: ExportedAssetType) => ({ exportedAsset }), addFresh: (exportedAsset: ExportedAssetType) => ({ exportedAsset }), removeFresh: (exportedAsset: ExportedAssetType) => ({ exportedAsset }), - createStaticCohort: (name: string, query: DataNode) => ({ query, name }), + createStaticCohort: (name: string, query: AnyDataNode) => ({ query, name }), }), connect({ @@ -140,8 +140,8 @@ export const exportsLogic = kea([ }) }, createStaticCohort: async ({ query, name }) => { + const toastId = 'toast-' + Math.random() try { - const toastId = 'toast-' + Math.random() lemonToast.info('Saving cohort...', { toastId, autoClose: false }) const cohort: CohortType = await api.create('api/cohort', { is_static: true, @@ -159,6 +159,7 @@ export const exportsLogic = kea([ }, }) } catch (e) { + lemonToast.dismiss(toastId) lemonToast.error('Cohort save failed') } }, diff --git a/frontend/src/queries/nodes/DataTable/DataTableExport.tsx b/frontend/src/queries/nodes/DataTable/DataTableExport.tsx index 87710b8bde41e..29f202822f132 100644 --- a/frontend/src/queries/nodes/DataTable/DataTableExport.tsx +++ b/frontend/src/queries/nodes/DataTable/DataTableExport.tsx @@ -280,7 +280,7 @@ export function DataTableExport({ query }: DataTableExportProps): JSX.Element | onClick: () => { LemonDialog.openForm({ title: 'Save as static cohort', - description: 'This will create a cohort with the current results of the query.', + description: 'This will create a cohort with the current list of people.', initialValues: { name: '', }, @@ -289,7 +289,7 @@ export function DataTableExport({ query }: DataTableExportProps): JSX.Element | diff --git a/frontend/src/scenes/insights/InsightPageHeader.tsx b/frontend/src/scenes/insights/InsightPageHeader.tsx index 28602b24be8ea..705bca46c2266 100644 --- a/frontend/src/scenes/insights/InsightPageHeader.tsx +++ b/frontend/src/scenes/insights/InsightPageHeader.tsx @@ -8,6 +8,7 @@ import { EditAlertModal } from 'lib/components/Alerts/views/EditAlertModal' import { ManageAlertsModal } from 'lib/components/Alerts/views/ManageAlertsModal' import { EditableField } from 'lib/components/EditableField/EditableField' import { ExportButton } from 'lib/components/ExportButton/ExportButton' +import { exportsLogic } from 'lib/components/ExportButton/exportsLogic' import { ObjectTags } from 'lib/components/ObjectTags/ObjectTags' import { PageHeader } from 'lib/components/PageHeader' import { SharingModal } from 'lib/components/Sharing/SharingModal' @@ -15,7 +16,10 @@ import { SubscribeButton, SubscriptionsModal } from 'lib/components/Subscription import { UserActivityIndicator } from 'lib/components/UserActivityIndicator/UserActivityIndicator' import { LemonButton } from 'lib/lemon-ui/LemonButton' import { More } from 'lib/lemon-ui/LemonButton/More' +import { LemonDialog } from 'lib/lemon-ui/LemonDialog' import { LemonDivider } from 'lib/lemon-ui/LemonDivider' +import { LemonField } from 'lib/lemon-ui/LemonField' +import { LemonInput } from 'lib/lemon-ui/LemonInput' import { LemonSwitch } from 'lib/lemon-ui/LemonSwitch' import { deleteInsightWithUndo } from 'lib/utils/deleteWithUndo' import { useState } from 'react' @@ -34,6 +38,7 @@ import { userLogic } from 'scenes/userLogic' import { tagsModel } from '~/models/tagsModel' import { DataTableNode, NodeKind } from '~/queries/schema' +import { isDataTableNode, isDataVisualizationNode, isEventsQuery, isHogQLQuery } from '~/queries/utils' import { ExporterFormat, InsightLogicProps, @@ -68,10 +73,11 @@ export function InsightPageHeader({ insightLogicProps }: { insightLogicProps: In const { duplicateInsight, loadInsights } = useActions(savedInsightsLogic) // insightDataLogic - const { queryChanged, showQueryEditor, showDebugPanel, hogQL, exportContext } = useValues( + const { query, queryChanged, showQueryEditor, showDebugPanel, hogQL, exportContext } = useValues( insightDataLogic(insightProps) ) const { toggleQueryEditorPanel, toggleDebugPanel } = useActions(insightDataLogic(insightProps)) + const { createStaticCohort } = useActions(exportsLogic) // other logics useMountedLogic(insightCommandLogic(insightProps)) @@ -83,6 +89,9 @@ export function InsightPageHeader({ insightLogicProps }: { insightLogicProps: In const [addToDashboardModalOpen, setAddToDashboardModalOpenModal] = useState(false) + const showCohortButton = + isDataTableNode(query) || isDataVisualizationNode(query) || isHogQLQuery(query) || isEventsQuery(query) + return ( <> {hasDashboardItemId && ( @@ -256,6 +265,49 @@ export function InsightPageHeader({ insightLogicProps }: { insightLogicProps: In > Edit SQL directly + {showCohortButton && ( + { + LemonDialog.openForm({ + title: 'Save as static cohort', + description: ( +
+ Your query must export a person_id,{' '} + actor_id or id column, + which must match the id of the{' '} + persons table +
+ ), + initialValues: { + name: '', + }, + content: ( + + + + ), + errors: { + name: (name) => + !name ? 'You must enter a name' : undefined, + }, + onSubmit: async ({ name }) => { + createStaticCohort(name, { + kind: NodeKind.HogQLQuery, + query: hogQL, + }) + }, + }) + }} + fullWidth + > + Save as static cohort +
+ )} )} {hasDashboardItemId && ( diff --git a/posthog/api/cohort.py b/posthog/api/cohort.py index 543ef9825a749..f6e92d257b40a 100644 --- a/posthog/api/cohort.py +++ b/posthog/api/cohort.py @@ -75,7 +75,7 @@ from posthog.queries.trends.trends_actors import TrendsActors from posthog.queries.trends.lifecycle_actors import LifecycleActors from posthog.queries.util import get_earliest_timestamp -from posthog.schema import ActorsQuery +from posthog.schema import ActorsQuery, HogQLQuery from posthog.tasks.calculate_cohort import ( calculate_cohort_from_list, insert_cohort_from_feature_flag, @@ -181,9 +181,12 @@ def validate_query(self, query: Optional[dict]) -> Optional[dict]: return None if not isinstance(query, dict): raise ValidationError("Query must be a dictionary.") - if query.get("kind") != "ActorsQuery": - raise ValidationError(f"Query must be a ActorsQuery. Got: {query.get('kind')}") - ActorsQuery.model_validate(query) + if query.get("kind") == "ActorsQuery": + ActorsQuery.model_validate(query) + elif query.get("kind") == "HogQLQuery": + HogQLQuery.model_validate(query) + else: + raise ValidationError(f"Query must be an ActorsQuery or HogQLQuery. Got: {query.get('kind')}") return query def validate_filters(self, request_filters: dict): @@ -268,6 +271,8 @@ def update(self, cohort: Cohort, validated_data: dict, *args: Any, **kwargs: Any # You can't update a static cohort using the trend/stickiness thing if request.FILES.get("csv"): self._calculate_static_by_csv(request.FILES["csv"], cohort) + else: + update_cohort(cohort, initiating_user=request.user) else: update_cohort(cohort, initiating_user=request.user) diff --git a/posthog/api/test/test_cohort.py b/posthog/api/test/test_cohort.py index eaf2c4a3c6f41..043d786e25241 100644 --- a/posthog/api/test/test_cohort.py +++ b/posthog/api/test/test_cohort.py @@ -970,6 +970,76 @@ def test_creating_update_and_calculating_with_new_cohort_query_dynamic_error(sel ) self.assertEqual(response.status_code, 400, response.content) + @patch("posthog.api.cohort.report_user_action") + def test_creating_with_query_and_fields(self, patch_capture): + _create_person( + distinct_ids=["p1"], + team_id=self.team.pk, + properties={"$some_prop": "something"}, + ) + _create_person( + distinct_ids=["p2"], + team_id=self.team.pk, + properties={"$some_prop": "not it"}, + ) + _create_person( + distinct_ids=["p3"], + team_id=self.team.pk, + properties={"$some_prop": "not it"}, + ) + _create_person(distinct_ids=["p4"], team_id=self.team.pk, properties={}) + _create_event(team=self.team, event="$pageview", distinct_id="p4", timestamp=datetime.now()) + _create_event(team=self.team, event="$pageview", distinct_id="p4", timestamp=datetime.now()) + flush_persons_and_events() + + def _calc(query: str) -> int: + response = self.client.post( + f"/api/projects/{self.team.id}/cohorts", + data={ + "name": "cohort A", + "is_static": True, + "query": { + "kind": "HogQLQuery", + "query": query, + }, + }, + ) + cohort_id = response.json()["id"] + while response.json()["is_calculating"]: + response = self.client.get(f"/api/projects/{self.team.id}/cohorts/{cohort_id}") + response = self.client.get(f"/api/projects/{self.team.id}/cohorts/{cohort_id}/persons/?cohort={cohort_id}") + return len(response.json()["results"]) + + # works with "actor_id" + self.assertEqual(2, _calc("select id as actor_id from persons where properties.$some_prop='not it'")) + + # works with "person_id" + self.assertEqual(2, _calc("select id as person_id from persons where properties.$some_prop='not it'")) + + # works with "id" + self.assertEqual(2, _calc("select id from persons where properties.$some_prop='not it'")) + + # only "p4" had events + self.assertEqual(1, _calc("select person_id from events")) + + # works with selecting anything from persons and events + self.assertEqual(4, _calc("select 1 from persons")) + self.assertEqual(1, _calc("select 1 from events")) + + # raises on all other cases + response = self.client.post( + f"/api/projects/{self.team.id}/cohorts", + data={ + "name": "cohort A", + "is_static": True, + "query": { + "kind": "HogQLQuery", + "query": "select 1 from groups", + }, + }, + ) + self.assertEqual(response.status_code, 500, response.content) + @patch("posthog.api.cohort.report_user_action") def test_cohort_with_is_set_filter_missing_value(self, patch_capture): # regression test: Removing `value` was silently failing diff --git a/posthog/models/cohort/util.py b/posthog/models/cohort/util.py index 98f04b2a16808..b3290d82dad9e 100644 --- a/posthog/models/cohort/util.py +++ b/posthog/models/cohort/util.py @@ -12,6 +12,7 @@ from posthog.clickhouse.query_tagging import tag_queries from posthog.client import sync_execute from posthog.constants import PropertyOperatorType +from posthog.hogql import ast from posthog.hogql.constants import LimitContext from posthog.hogql.hogql import HogQLContext from posthog.hogql.modifiers import create_default_modifiers_for_team @@ -74,12 +75,36 @@ def format_person_query(cohort: Cohort, index: int, hogql_context: HogQLContext) def print_cohort_hogql_query(cohort: Cohort, hogql_context: HogQLContext) -> str: from posthog.hogql_queries.query_runner import get_query_runner - persons_query = cast(dict, cohort.query) - persons_query["select"] = ["id as actor_id"] + if not cohort.query: + raise ValueError("Cohort has no query") + query = get_query_runner( - persons_query, team=cast(Team, cohort.team), limit_context=LimitContext.COHORT_CALCULATION + cast(dict, cohort.query), team=cast(Team, cohort.team), limit_context=LimitContext.COHORT_CALCULATION ).to_query() + select_queries: list[ast.SelectQuery] = [query] if isinstance(query, ast.SelectQuery) else query.select_queries + for select_query in select_queries: + columns: dict[str, ast.Expr] = {} + for expr in select_query.select: + if isinstance(expr, ast.Alias): + columns[expr.alias] = expr.expr + elif isinstance(expr, ast.Field): + columns[str(expr.chain[-1])] = expr + column: ast.Expr | None = columns.get("person_id") or columns.get("actor_id") or columns.get("id") + if isinstance(column, ast.Alias): + select_query.select = [ast.Alias(expr=column.expr, alias="actor_id")] + elif isinstance(column, ast.Field): + select_query.select = [ast.Alias(expr=column, alias="actor_id")] + else: + # Support the most common use cases + table = select_query.select_from.table if select_query.select_from else None + if isinstance(table, ast.Field) and table.chain[-1] == "events": + select_query.select = [ast.Alias(expr=ast.Field(chain=["person", "id"]), alias="actor_id")] + elif isinstance(table, ast.Field) and table.chain[-1] == "persons": + select_query.select = [ast.Alias(expr=ast.Field(chain=["id"]), alias="actor_id")] + else: + raise ValueError("Could not find a person_id, actor_id, or id column in the query") + hogql_context.enable_select_queries = True hogql_context.limit_top_select = False create_default_modifiers_for_team(cohort.team, hogql_context.modifiers) diff --git a/posthog/models/person/sql.py b/posthog/models/person/sql.py index f08a4144b9ec1..551d34412fa15 100644 --- a/posthog/models/person/sql.py +++ b/posthog/models/person/sql.py @@ -405,7 +405,7 @@ INSERT_COHORT_ALL_PEOPLE_THROUGH_PERSON_ID = """ INSERT INTO {cohort_table} SELECT generateUUIDv4(), actor_id, %(cohort_id)s, %(team_id)s, %(_timestamp)s, 0 FROM ( - SELECT actor_id FROM ({query}) + SELECT DISTINCT actor_id FROM ({query}) ) """