diff --git a/posthog/api/test/__snapshots__/test_cohort.ambr b/posthog/api/test/__snapshots__/test_cohort.ambr index 071e7fff34d8f..d471209df9f64 100644 --- a/posthog/api/test/__snapshots__/test_cohort.ambr +++ b/posthog/api/test/__snapshots__/test_cohort.ambr @@ -81,7 +81,7 @@ cohort_id FROM cohortpeople WHERE (team_id = 2 - AND cohort_id = '1' + AND cohort_id = '2' AND version < '2') ''' # --- @@ -177,7 +177,7 @@ cohort_id FROM cohortpeople WHERE (team_id = 2 - AND cohort_id = '1' + AND cohort_id = '2' AND version < '2') ''' # --- @@ -187,7 +187,7 @@ DELETE FROM cohortpeople WHERE (team_id = 2 - AND cohort_id = '1' + AND cohort_id = '2' AND version < '2') ''' # --- diff --git a/posthog/api/test/test_cohort.py b/posthog/api/test/test_cohort.py index 043d786e25241..52dea5f41a9e0 100644 --- a/posthog/api/test/test_cohort.py +++ b/posthog/api/test/test_cohort.py @@ -1,4 +1,5 @@ import json +from ee.clickhouse.materialized_columns.analyze import materialize from datetime import datetime, timedelta from typing import Optional, Any from unittest import mock @@ -12,11 +13,11 @@ from posthog.api.test.test_exports import TestExportMixin from posthog.clickhouse.client.execute import sync_execute -from posthog.models import FeatureFlag, Person +from posthog.models import FeatureFlag, Person, Action from posthog.models.async_deletion.async_deletion import AsyncDeletion, DeletionType from posthog.models.cohort import Cohort from posthog.models.team.team import Team -from posthog.schema import PropertyOperator +from posthog.schema import PropertyOperator, PersonsOnEventsMode from posthog.tasks.calculate_cohort import calculate_cohort_ch, calculate_cohort_from_list from posthog.tasks.tasks import clickhouse_clear_removed_data from posthog.test.base import ( @@ -143,6 +144,89 @@ def test_creating_update_and_calculating(self, patch_sync_execute, patch_calcula }, ) + @patch("posthog.api.cohort.report_user_action") + @patch("posthog.tasks.calculate_cohort.calculate_cohort_ch.delay", side_effect=calculate_cohort_ch) + @patch("posthog.models.cohort.util.sync_execute", side_effect=sync_execute) + def test_action_persons_on_events(self, patch_sync_execute, patch_calculate_cohort, patch_capture): + materialize("events", "team_id", table_column="person_properties") + self.team.modifiers = {"personsOnEventsMode": PersonsOnEventsMode.PERSON_ID_OVERRIDE_PROPERTIES_ON_EVENTS} + self.team.save() + _create_person( + team=self.team, + distinct_ids=[f"person_1"], + properties={"team_id": 5}, + ) + _create_person( + team=self.team, + distinct_ids=[f"person_2"], + properties={"team_id": 6}, + ) + _create_event( + team=self.team, + event="$pageview", + distinct_id="person_1", + timestamp=datetime.now() - timedelta(hours=12), + ) + action = Action.objects.create( + team=self.team, + steps_json=[ + { + "event": "$pageview", + "properties": [{"key": "team_id", "type": "person", "value": 5}], + } + ], + ) + + # Make sure the endpoint works with and without the trailing slash + response = self.client.post( + f"/api/projects/{self.team.id}/cohorts", + data={ + "name": "whatever", + "filters": { + "properties": { + "type": "OR", + "values": [ + { + "type": "AND", + "values": [ + { + "key": action.pk, + "type": "behavioral", + "value": "performed_event", + "negation": False, + "event_type": "actions", + "time_value": 30, + "time_interval": "day", + "explicit_datetime": "-30d", + } + ], + } + ], + } + }, + }, + ) + self.assertEqual(response.status_code, 201, response.content) + self.assertEqual(response.json()["created_by"]["id"], self.user.pk) + self.assertEqual(patch_calculate_cohort.call_count, 1) + self.assertEqual(patch_capture.call_count, 1) + + with self.capture_queries_startswith("INSERT INTO cohortpeople") as insert_statements: + response = self.client.patch( + f"/api/projects/{self.team.id}/cohorts/{response.json()['id']}", + data={ + "name": "whatever2", + "description": "A great cohort!", + "groups": [{"properties": {"team_id": 6}}], + "created_by": "something something", + "last_calculation": "some random date", + "errors_calculating": 100, + "deleted": False, + }, + ) + + self.assertIn(f"mat_pp_team_id", insert_statements[0]) + @patch("posthog.api.cohort.report_user_action") @patch("posthog.tasks.calculate_cohort.calculate_cohort_ch.delay") def test_list_cohorts_is_not_nplus1(self, patch_calculate_cohort, patch_capture): diff --git a/posthog/models/cohort/util.py b/posthog/models/cohort/util.py index b6eeac84a8395..1d18d632bd030 100644 --- a/posthog/models/cohort/util.py +++ b/posthog/models/cohort/util.py @@ -7,7 +7,7 @@ from django.conf import settings from django.utils import timezone from rest_framework.exceptions import ValidationError - +from posthog.queries.util import PersonPropertiesMode from posthog.clickhouse.client.connection import Workload from posthog.clickhouse.query_tagging import tag_queries from posthog.client import sync_execute @@ -65,6 +65,7 @@ def format_person_query(cohort: Cohort, index: int, hogql_context: HogQLContext) ), cohort.team, cohort_pk=cohort.pk, + persons_on_events_mode=cohort.team.person_on_events_mode, ) query, params = query_builder.get_query() @@ -151,6 +152,7 @@ def get_entity_query( team_id: int, group_idx: Union[int, str], hogql_context: HogQLContext, + person_properties_mode: Optional[PersonPropertiesMode] = None, ) -> tuple[str, dict[str, str]]: if event_id: return f"event = %({f'event_{group_idx}'})s", {f"event_{group_idx}": event_id} @@ -161,6 +163,9 @@ def get_entity_query( action=action, prepend="_{}_action".format(group_idx), hogql_context=hogql_context, + person_properties_mode=person_properties_mode + if person_properties_mode + else PersonPropertiesMode.USING_SUBQUERY, ) return action_filter_query, action_params else: diff --git a/posthog/queries/foss_cohort_query.py b/posthog/queries/foss_cohort_query.py index d4925856afd94..a7e020158872e 100644 --- a/posthog/queries/foss_cohort_query.py +++ b/posthog/queries/foss_cohort_query.py @@ -613,6 +613,9 @@ def _get_entity( self._team_id, f"{prepend}_entity_{idx}", self._filter.hogql_context, + person_properties_mode=PersonPropertiesMode.DIRECT_ON_EVENTS + if self._person_on_events_mode != PersonsOnEventsMode.DISABLED + else None, ) elif event[0] == "events": self._add_event(str(event[1]))