diff --git a/posthog/hogql_queries/experiments/experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/experiment_trends_query_runner.py index 06619c4dfeee1..65d4fd179d1ec 100644 --- a/posthog/hogql_queries/experiments/experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/experiment_trends_query_runner.py @@ -174,30 +174,49 @@ def _prepare_exposure_query(self) -> TrendsQuery: if uses_math_aggregation: prepared_exposure_query = TrendsQuery(**self.query.count_query.model_dump()) - count_event = self.query.count_query.series[0] + prepared_exposure_query.dateRange = self._get_insight_date_range() + prepared_exposure_query.trendsFilter = TrendsFilter(display=ChartDisplayType.ACTIONS_LINE_GRAPH_CUMULATIVE) - if hasattr(count_event, "event"): - prepared_exposure_query.dateRange = self._get_insight_date_range() - prepared_exposure_query.breakdownFilter = self._get_event_breakdown_filter() - prepared_exposure_query.trendsFilter = TrendsFilter( - display=ChartDisplayType.ACTIONS_LINE_GRAPH_CUMULATIVE - ) - prepared_exposure_query.series = [ - EventsNode( - event=count_event.event, - math=BaseMathType.DAU, - ) - ] + # For a data warehouse query, we can use the unique users for the series + if self._is_data_warehouse_query(prepared_exposure_query): + prepared_exposure_query.breakdownFilter = self._get_data_warehouse_breakdown_filter() + prepared_exposure_query.series[0].math = BaseMathType.DAU + prepared_exposure_query.series[0].math_property = None + prepared_exposure_query.series[0].math_property_type = None prepared_exposure_query.properties = [ - EventPropertyFilter( - key=self.breakdown_key, + DataWarehousePropertyFilter( + key="events.event", + value="$feature_flag_called", + operator=PropertyOperator.EXACT, + type="data_warehouse", + ), + DataWarehousePropertyFilter( + key=f"events.properties.{self.breakdown_key}", value=self.variants, operator=PropertyOperator.EXACT, - type="event", - ) + type="data_warehouse", + ), ] else: - raise 
ValueError("Expected first series item to have an 'event' attribute") + count_event = self.query.count_query.series[0] + if hasattr(count_event, "event"): + prepared_exposure_query.breakdownFilter = self._get_event_breakdown_filter() + prepared_exposure_query.series = [ + EventsNode( + event=count_event.event, + math=BaseMathType.DAU, + ) + ] + prepared_exposure_query.properties = [ + EventPropertyFilter( + key=self.breakdown_key, + value=self.variants, + operator=PropertyOperator.EXACT, + type="event", + ) + ] + else: + raise ValueError("Expected first series item to have an 'event' attribute") # 2. Otherwise, if an exposure query is provided, we use it as is, adapting the date range and breakdown elif self.query.exposure_query: diff --git a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py index 4402afde55eec..c8596256254d5 100644 --- a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py @@ -489,7 +489,7 @@ def test_query_runner_with_holdout(self): self.assertEqual(test_result.absolute_exposure, 9) self.assertEqual(holdout_result.absolute_exposure, 4) - def test_query_runner_with_data_warehouse_series(self): + def test_query_runner_with_data_warehouse_series_total_count(self): table_name = self.create_data_warehouse_table_with_payments() feature_flag = self.create_feature_flag() @@ -509,6 +509,7 @@ def test_query_runner_with_data_warehouse_series(self): id_field="id", table_name=table_name, timestamp_field="dw_timestamp", + math="total", ) ] ) @@ -583,6 +584,115 @@ def test_query_runner_with_data_warehouse_series(self): self.assertEqual(control_result.absolute_exposure, 9) self.assertEqual(test_result.absolute_exposure, 9) + def test_query_runner_with_data_warehouse_series_avg_amount(self): + table_name = 
self.create_data_warehouse_table_with_payments() + + feature_flag = self.create_feature_flag() + experiment = self.create_experiment( + feature_flag=feature_flag, + start_date=datetime(2023, 1, 1), + end_date=datetime(2023, 1, 10), + ) + + feature_flag_property = f"$feature/{feature_flag.key}" + + count_query = TrendsQuery( + series=[ + DataWarehouseNode( + id=table_name, + distinct_id_field="dw_distinct_id", + id_field="id", + table_name=table_name, + timestamp_field="dw_timestamp", + math="avg", + math_property="amount", + math_property_type="data_warehouse_properties", + ) + ] + ) + exposure_query = TrendsQuery(series=[EventsNode(event="$feature_flag_called")]) + + experiment_query = ExperimentTrendsQuery( + experiment_id=experiment.id, + kind="ExperimentTrendsQuery", + count_query=count_query, + exposure_query=exposure_query, + ) + + experiment.metrics = [{"type": "primary", "query": experiment_query.model_dump()}] + experiment.save() + + # Populate exposure events + for variant, count in [("control", 7), ("test", 9)]: + for i in range(count): + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id=f"user_{variant}_{i}", + properties={feature_flag_property: variant}, + timestamp=datetime(2023, 1, i + 1), + ) + + # "user_test_3" first exposure (feature_flag_property="control") is on 2023-01-03 + # "user_test_3" relevant exposure (feature_flag_property="test") is on 2023-01-04 + # "user_test_3" other event (feature_flag_property="control") is on 2023-01-05 + # "user_test_3" purchase is on 2023-01-06 + # "user_test_3" second exposure (feature_flag_property="control") is on 2023-01-09 + # "user_test_3" should fall into the "test" variant, not the "control" variant + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id="user_test_3", + properties={feature_flag_property: "control"}, + timestamp=datetime(2023, 1, 3), + ) + _create_event( + team=self.team, + event="Some other event", + distinct_id="user_test_3", + 
properties={feature_flag_property: "control"}, + timestamp=datetime(2023, 1, 5), + ) + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id="user_test_3", + properties={feature_flag_property: "control"}, + timestamp=datetime(2023, 1, 9), + ) + + flush_persons_and_events() + + query_runner = ExperimentTrendsQueryRunner( + query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team + ) + with freeze_time("2023-01-07"): + result = query_runner.calculate() + + trend_result = cast(ExperimentTrendsQueryResponse, result) + + self.assertEqual(len(result.variants), 2) + + control_result = next(variant for variant in trend_result.variants if variant.key == "control") + test_result = next(variant for variant in trend_result.variants if variant.key == "test") + + control_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "control") + test_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "test") + + self.assertEqual(control_result.count, 100) + self.assertEqual(test_result.count, 205) + self.assertEqual(control_result.absolute_exposure, 1) + self.assertEqual(test_result.absolute_exposure, 3) + + self.assertEqual( + control_insight["data"], + [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0], + ) + self.assertEqual( + test_insight["data"], + [0.0, 50.0, 125.0, 125.0, 125.0, 205.0, 205.0, 205.0, 205.0, 205.0], + ) + def test_query_runner_with_invalid_data_warehouse_table_name(self): # parquet file isn't created, so we'll get an error table_name = "invalid_table_name" @@ -626,12 +736,19 @@ def test_query_runner_with_invalid_data_warehouse_table_name(self): self.assertEqual(str(context.exception), "'invalid_table_name'") + # Uses the same values as test_query_runner_with_data_warehouse_series_avg_amount for easy comparison @freeze_time("2020-01-01T12:00:00Z") def test_query_runner_with_avg_math(self): feature_flag = 
self.create_feature_flag() experiment = self.create_experiment(feature_flag=feature_flag) - count_query = TrendsQuery(series=[EventsNode(event="$pageview", math="avg")]) + feature_flag_property = f"$feature/{feature_flag.key}" + + count_query = TrendsQuery( + series=[ + EventsNode(event="purchase", math="avg", math_property="amount", math_property_type="event_properties") + ] + ) exposure_query = TrendsQuery(series=[EventsNode(event="$feature_flag_called")]) experiment_query = ExperimentTrendsQuery( @@ -648,9 +765,80 @@ def test_query_runner_with_avg_math(self): query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team ) + # Populate exposure events - same as data warehouse test + for variant, count in [("control", 1), ("test", 3)]: + for i in range(count): + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id=f"user_{variant}_{i}", + properties={feature_flag_property: variant}, + timestamp=datetime(2020, 1, i + 1), + ) + + # Create purchase events with same amounts as data warehouse test + # Control: 1 purchase of 100 + # Test: 3 purchases of 50, 75, and 80 + _create_event( + team=self.team, + event="purchase", + distinct_id="user_control_0", + properties={feature_flag_property: "control", "amount": 100}, + timestamp=datetime(2020, 1, 2), + ) + + _create_event( + team=self.team, + event="purchase", + distinct_id="user_test_1", + properties={feature_flag_property: "test", "amount": 50}, + timestamp=datetime(2020, 1, 2), + ) + _create_event( + team=self.team, + event="purchase", + distinct_id="user_test_2", + properties={feature_flag_property: "test", "amount": 75}, + timestamp=datetime(2020, 1, 3), + ) + _create_event( + team=self.team, + event="purchase", + distinct_id="user_test_3", + properties={feature_flag_property: "test", "amount": 80}, + timestamp=datetime(2020, 1, 6), + ) + + flush_persons_and_events() + prepared_count_query = query_runner.prepared_count_query 
self.assertEqual(prepared_count_query.series[0].math, "sum") + result = query_runner.calculate() + trend_result = cast(ExperimentTrendsQueryResponse, result) + + self.assertEqual(len(result.variants), 2) + + control_result = next(variant for variant in trend_result.variants if variant.key == "control") + test_result = next(variant for variant in trend_result.variants if variant.key == "test") + + control_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "control") + test_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "test") + + self.assertEqual(control_result.count, 100) + self.assertAlmostEqual(test_result.count, 205) + self.assertEqual(control_result.absolute_exposure, 1) + self.assertEqual(test_result.absolute_exposure, 3) + + self.assertEqual( + control_insight["data"], + [0.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0], + ) + self.assertEqual( + test_insight["data"], + [0.0, 50.0, 125.0, 125.0, 125.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0], + ) + @flaky(max_runs=10, min_passes=1) @freeze_time("2020-01-01T12:00:00Z") def test_query_runner_standard_flow(self):