From e183735243780d18502dedbbf57193b44aa4bd4d Mon Sep 17 00:00:00 2001 From: Daniel Bachhuber Date: Fri, 6 Dec 2024 05:55:23 -0800 Subject: [PATCH 1/5] Support property value math for data warehouse experiments --- .../experiment_trends_query_runner.py | 81 +++++++------- .../test_experiment_trends_query_runner.py | 100 +++++++++++++++++- 2 files changed, 144 insertions(+), 37 deletions(-) diff --git a/posthog/hogql_queries/experiments/experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/experiment_trends_query_runner.py index 06619c4dfeee1..f577bbcce6aa1 100644 --- a/posthog/hogql_queries/experiments/experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/experiment_trends_query_runner.py @@ -104,6 +104,22 @@ def _get_data_warehouse_breakdown_filter(self) -> BreakdownFilter: breakdown_type="data_warehouse", ) + def _get_data_warehouse_properties(self) -> list[DataWarehousePropertyFilter]: + return [ + DataWarehousePropertyFilter( + key="events.event", + value="$feature_flag_called", + operator=PropertyOperator.EXACT, + type="data_warehouse", + ), + DataWarehousePropertyFilter( + key=f"events.properties.{self.breakdown_key}", + value=self.variants, + operator=PropertyOperator.EXACT, + type="data_warehouse", + ), + ] + def _prepare_count_query(self) -> TrendsQuery: """ This method takes the raw trend query and adapts it @@ -129,20 +145,7 @@ def _prepare_count_query(self) -> TrendsQuery: prepared_count_query.dateRange = self._get_insight_date_range() if self._is_data_warehouse_query(prepared_count_query): prepared_count_query.breakdownFilter = self._get_data_warehouse_breakdown_filter() - prepared_count_query.properties = [ - DataWarehousePropertyFilter( - key="events.event", - value="$feature_flag_called", - operator=PropertyOperator.EXACT, - type="data_warehouse", - ), - DataWarehousePropertyFilter( - key=f"events.properties.{self.breakdown_key}", - value=self.variants, - operator=PropertyOperator.EXACT, - type="data_warehouse", - ), - ] + prepared_count_query.properties = self._get_data_warehouse_properties() else: prepared_count_query.breakdownFilter = self._get_event_breakdown_filter() prepared_count_query.properties = [ @@ -174,30 +177,36 @@ def _prepare_exposure_query(self) -> TrendsQuery: if uses_math_aggregation: prepared_exposure_query = TrendsQuery(**self.query.count_query.model_dump()) - count_event = self.query.count_query.series[0] + prepared_exposure_query.dateRange = self._get_insight_date_range() + prepared_exposure_query.trendsFilter = TrendsFilter(display=ChartDisplayType.ACTIONS_LINE_GRAPH_CUMULATIVE) - if hasattr(count_event, "event"): - prepared_exposure_query.dateRange = self._get_insight_date_range() - prepared_exposure_query.breakdownFilter = self._get_event_breakdown_filter() - prepared_exposure_query.trendsFilter = TrendsFilter( - display=ChartDisplayType.ACTIONS_LINE_GRAPH_CUMULATIVE - ) - prepared_exposure_query.series = [ - EventsNode( - event=count_event.event, - math=BaseMathType.DAU, - ) - ] - prepared_exposure_query.properties = [ - EventPropertyFilter( - key=self.breakdown_key, - value=self.variants, - operator=PropertyOperator.EXACT, - type="event", - ) - ] + # For a data warehouse query, we can use the unique users for the series + if self._is_data_warehouse_query(prepared_exposure_query): + prepared_exposure_query.breakdownFilter = self._get_data_warehouse_breakdown_filter() + prepared_exposure_query.series[0].math = BaseMathType.DAU + prepared_exposure_query.series[0].math_property = None + 
prepared_exposure_query.series[0].math_property_type = None + prepared_exposure_query.properties = self._get_data_warehouse_properties() else: - raise ValueError("Expected first series item to have an 'event' attribute") + count_event = self.query.count_query.series[0] + if hasattr(count_event, "event"): + prepared_exposure_query.breakdownFilter = self._get_event_breakdown_filter() + prepared_exposure_query.series = [ + EventsNode( + event=count_event.event, + math=BaseMathType.DAU, + ) + ] + prepared_exposure_query.properties = [ + EventPropertyFilter( + key=self.breakdown_key, + value=self.variants, + operator=PropertyOperator.EXACT, + type="event", + ) + ] + else: + raise ValueError("Expected first series item to have an 'event' attribute") # 2. Otherwise, if an exposure query is provided, we use it as is, adapting the date range and breakdown elif self.query.exposure_query: diff --git a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py index 4402afde55eec..545925ab214b1 100644 --- a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py @@ -489,7 +489,7 @@ def test_query_runner_with_holdout(self): self.assertEqual(test_result.absolute_exposure, 9) self.assertEqual(holdout_result.absolute_exposure, 4) - def test_query_runner_with_data_warehouse_series(self): + def test_query_runner_with_data_warehouse_series_total_count(self): table_name = self.create_data_warehouse_table_with_payments() feature_flag = self.create_feature_flag() @@ -509,6 +509,7 @@ def test_query_runner_with_data_warehouse_series(self): id_field="id", table_name=table_name, timestamp_field="dw_timestamp", + math="total", ) ] ) @@ -583,6 +584,103 @@ def test_query_runner_with_data_warehouse_series(self): self.assertEqual(control_result.absolute_exposure, 9) self.assertEqual(test_result.absolute_exposure, 9) + def test_query_runner_with_data_warehouse_series_avg_amount(self): + table_name = self.create_data_warehouse_table_with_payments() + + feature_flag = self.create_feature_flag() + experiment = self.create_experiment( + feature_flag=feature_flag, + start_date=datetime(2023, 1, 1), + end_date=datetime(2023, 1, 10), + ) + + feature_flag_property = f"$feature/{feature_flag.key}" + + count_query = TrendsQuery( + series=[ + DataWarehouseNode( + id=table_name, + distinct_id_field="dw_distinct_id", + id_field="id", + table_name=table_name, + timestamp_field="dw_timestamp", + math="avg", + math_property="amount", + math_property_type="data_warehouse_properties", + ) + ] + ) + exposure_query = TrendsQuery(series=[EventsNode(event="$feature_flag_called")]) + + experiment_query = ExperimentTrendsQuery( + experiment_id=experiment.id, + kind="ExperimentTrendsQuery", + count_query=count_query, + exposure_query=exposure_query, + ) + + experiment.metrics = [{"type": "primary", "query": experiment_query.model_dump()}] + experiment.save() + + # Populate exposure events + for variant, count in [("control", 7), ("test", 9)]: + for i in range(count): + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id=f"user_{variant}_{i}", + properties={feature_flag_property: variant}, + timestamp=datetime(2023, 1, i + 1), + ) + + # "user_test_3" first exposure (feature_flag_property="control") is on 2023-01-03 + # "user_test_3" relevant exposure (feature_flag_property="test") is on 2023-01-04 + # "user_test_3" 
other event (feature_flag_property="control" is on 2023-01-05 + # "user_test_3" purchase is on 2023-01-06 + # "user_test_3" second exposure (feature_flag_property="control") is on 2023-01-09 + # "user_test_3" should fall into the "test" variant, not the "control" variant + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id="user_test_3", + properties={feature_flag_property: "control"}, + timestamp=datetime(2023, 1, 3), + ) + _create_event( + team=self.team, + event="Some other event", + distinct_id="user_test_3", + properties={feature_flag_property: "control"}, + timestamp=datetime(2023, 1, 5), + ) + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id="user_test_3", + properties={feature_flag_property: "control"}, + timestamp=datetime(2023, 1, 9), + ) + + flush_persons_and_events() + + query_runner = ExperimentTrendsQueryRunner( + query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team + ) + with freeze_time("2023-01-07"): + result = query_runner.calculate() + + trend_result = cast(ExperimentTrendsQueryResponse, result) + + self.assertEqual(len(result.variants), 2) + + control_result = next(variant for variant in trend_result.variants if variant.key == "control") + test_result = next(variant for variant in trend_result.variants if variant.key == "test") + + self.assertEqual(control_result.count, 100) + self.assertEqual(test_result.count, 205) + self.assertEqual(control_result.absolute_exposure, 1) + self.assertEqual(test_result.absolute_exposure, 3) + def test_query_runner_with_invalid_data_warehouse_table_name(self): # parquet file isn't created, so we'll get an error table_name = "invalid_table_name" From bd96af25f6320a2e62d55dc943c893bcf3c94480 Mon Sep 17 00:00:00 2001 From: Daniel Bachhuber Date: Fri, 6 Dec 2024 08:43:41 -0800 Subject: [PATCH 2/5] Avoid type issues by simply defining these --- .../experiment_trends_query_runner.py | 46 +++++++++++-------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/posthog/hogql_queries/experiments/experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/experiment_trends_query_runner.py index f577bbcce6aa1..65d4fd179d1ec 100644 --- a/posthog/hogql_queries/experiments/experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/experiment_trends_query_runner.py @@ -104,22 +104,6 @@ def _get_data_warehouse_breakdown_filter(self) -> BreakdownFilter: breakdown_type="data_warehouse", ) - def _get_data_warehouse_properties(self) -> list[DataWarehousePropertyFilter]: - return [ - DataWarehousePropertyFilter( - key="events.event", - value="$feature_flag_called", - operator=PropertyOperator.EXACT, - type="data_warehouse", - ), - DataWarehousePropertyFilter( - key=f"events.properties.{self.breakdown_key}", - value=self.variants, - operator=PropertyOperator.EXACT, - type="data_warehouse", - ), - ] - def _prepare_count_query(self) -> TrendsQuery: """ This method takes the raw trend query and adapts it @@ -145,7 +129,20 @@ def _prepare_count_query(self) -> TrendsQuery: prepared_count_query.dateRange = self._get_insight_date_range() if self._is_data_warehouse_query(prepared_count_query): prepared_count_query.breakdownFilter = self._get_data_warehouse_breakdown_filter() - prepared_count_query.properties = self._get_data_warehouse_properties() + prepared_count_query.properties = [ + DataWarehousePropertyFilter( + key="events.event", + value="$feature_flag_called", + operator=PropertyOperator.EXACT, + type="data_warehouse", + ), + 
DataWarehousePropertyFilter( + key=f"events.properties.{self.breakdown_key}", + value=self.variants, + operator=PropertyOperator.EXACT, + type="data_warehouse", + ), + ] else: prepared_count_query.breakdownFilter = self._get_event_breakdown_filter() prepared_count_query.properties = [ @@ -186,7 +183,20 @@ def _prepare_exposure_query(self) -> TrendsQuery: prepared_exposure_query.series[0].math = BaseMathType.DAU prepared_exposure_query.series[0].math_property = None prepared_exposure_query.series[0].math_property_type = None - prepared_exposure_query.properties = self._get_data_warehouse_properties() + prepared_exposure_query.properties = [ + DataWarehousePropertyFilter( + key="events.event", + value="$feature_flag_called", + operator=PropertyOperator.EXACT, + type="data_warehouse", + ), + DataWarehousePropertyFilter( + key=f"events.properties.{self.breakdown_key}", + value=self.variants, + operator=PropertyOperator.EXACT, + type="data_warehouse", + ), + ] else: count_event = self.query.count_query.series[0] if hasattr(count_event, "event"): From c06ff16a8672e579da3cb726acd06d9022c3c02c Mon Sep 17 00:00:00 2001 From: Daniel Bachhuber Date: Fri, 6 Dec 2024 12:11:33 -0800 Subject: [PATCH 3/5] Add an assertion for the data trend --- .../test/test_experiment_trends_query_runner.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py index 545925ab214b1..8378ee851349f 100644 --- a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py @@ -676,11 +676,23 @@ def test_query_runner_with_data_warehouse_series_avg_amount(self): control_result = next(variant for variant in trend_result.variants if variant.key == "control") test_result = next(variant for variant in trend_result.variants if variant.key == "test") + control_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "control") + test_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "test") + self.assertEqual(control_result.count, 100) self.assertEqual(test_result.count, 205) self.assertEqual(control_result.absolute_exposure, 1) self.assertEqual(test_result.absolute_exposure, 3) + self.assertEqual( + control_insight["data"], + [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0], + ) + self.assertEqual( + test_insight["data"], + [0.0, 50.0, 125.0, 125.0, 125.0, 205.0, 205.0, 205.0, 205.0, 205.0], + ) + def test_query_runner_with_invalid_data_warehouse_table_name(self): # parquet file isn't created, so we'll get an error table_name = "invalid_table_name" From 1810b73079c67ea5693c748f0d3213b0f6e999ad Mon Sep 17 00:00:00 2001 From: Daniel Bachhuber Date: Fri, 6 Dec 2024 12:15:28 -0800 Subject: [PATCH 4/5] Expand test to serve as a point of comparison --- .../test_experiment_trends_query_runner.py | 55 ++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py index 8378ee851349f..c9aa6753f2edd 100644 --- a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py @@ -741,7 +741,13 @@ def 
test_query_runner_with_avg_math(self): feature_flag = self.create_feature_flag() experiment = self.create_experiment(feature_flag=feature_flag) - count_query = TrendsQuery(series=[EventsNode(event="$pageview", math="avg")]) + feature_flag_property = f"$feature/{feature_flag.key}" + + count_query = TrendsQuery( + series=[ + EventsNode(event="purchase", math="avg", math_property="amount", math_property_type="event_properties") + ] + ) exposure_query = TrendsQuery(series=[EventsNode(event="$feature_flag_called")]) experiment_query = ExperimentTrendsQuery( @@ -758,9 +764,56 @@ def test_query_runner_with_avg_math(self): query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team ) + for variant, count in [("control", 7), ("test", 9)]: + for i in range(count): + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id=f"user_{variant}_{i}", + properties={feature_flag_property: variant}, + timestamp=datetime(2020, 1, i + 1), + ) + + for variant, count in [("control", 4), ("test", 2)]: + for i in range(count): + _create_event( + team=self.team, + event="purchase", + distinct_id=f"user_{variant}_{i}", + properties={feature_flag_property: variant, "amount": i * 10}, + timestamp=datetime(2020, 1, i + 2), + ) + + flush_persons_and_events() + prepared_count_query = query_runner.prepared_count_query self.assertEqual(prepared_count_query.series[0].math, "sum") + result = query_runner.calculate() + trend_result = cast(ExperimentTrendsQueryResponse, result) + + self.assertEqual(len(result.variants), 2) + + control_result = next(variant for variant in trend_result.variants if variant.key == "control") + test_result = next(variant for variant in trend_result.variants if variant.key == "test") + + control_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "control") + test_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "test") + + self.assertEqual(control_result.count, 60) + self.assertEqual(test_result.count, 10) + self.assertEqual(control_result.absolute_exposure, 4) + self.assertEqual(test_result.absolute_exposure, 2) + + self.assertEqual( + control_insight["data"], + [0.0, 0.0, 10.0, 30.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0], + ) + self.assertEqual( + test_insight["data"], + [0.0, 0.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0], + ) + @flaky(max_runs=10, min_passes=1) @freeze_time("2020-01-01T12:00:00Z") def test_query_runner_standard_flow(self): From f95dd7cd136372a3b94a6949d4a73b647f6564da Mon Sep 17 00:00:00 2001 From: Daniel Bachhuber Date: Mon, 9 Dec 2024 16:16:30 -0800 Subject: [PATCH 5/5] Use the same values between the two tests --- .../test_experiment_trends_query_runner.py | 57 +++++++++++++------ 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py index c9aa6753f2edd..c8596256254d5 100644 --- a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py @@ -736,6 +736,7 @@ def test_query_runner_with_invalid_data_warehouse_table_name(self): self.assertEqual(str(context.exception), "'invalid_table_name'") + # Uses the same values as test_query_runner_with_data_warehouse_series_avg_amount for easy comparison @freeze_time("2020-01-01T12:00:00Z") 
def test_query_runner_with_avg_math(self): feature_flag = self.create_feature_flag() @@ -764,7 +765,8 @@ def test_query_runner_with_avg_math(self): query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team ) - for variant, count in [("control", 7), ("test", 9)]: + # Populate exposure events - same as data warehouse test + for variant, count in [("control", 1), ("test", 3)]: for i in range(count): _create_event( team=self.team, @@ -774,15 +776,38 @@ def test_query_runner_with_avg_math(self): timestamp=datetime(2020, 1, i + 1), ) - for variant, count in [("control", 4), ("test", 2)]: - for i in range(count): - _create_event( - team=self.team, - event="purchase", - distinct_id=f"user_{variant}_{i}", - properties={feature_flag_property: variant, "amount": i * 10}, - timestamp=datetime(2020, 1, i + 2), - ) + # Create purchase events with same amounts as data warehouse test + # Control: 1 purchase of 100 + # Test: 3 purchases of 50, 75, and 80 + _create_event( + team=self.team, + event="purchase", + distinct_id="user_control_0", + properties={feature_flag_property: "control", "amount": 100}, + timestamp=datetime(2020, 1, 2), + ) + + _create_event( + team=self.team, + event="purchase", + distinct_id="user_test_1", + properties={feature_flag_property: "test", "amount": 50}, + timestamp=datetime(2020, 1, 2), + ) + _create_event( + team=self.team, + event="purchase", + distinct_id="user_test_2", + properties={feature_flag_property: "test", "amount": 75}, + timestamp=datetime(2020, 1, 3), + ) + _create_event( + team=self.team, + event="purchase", + distinct_id="user_test_3", + properties={feature_flag_property: "test", "amount": 80}, + timestamp=datetime(2020, 1, 6), + ) flush_persons_and_events() @@ -800,18 +825,18 @@ def test_query_runner_with_avg_math(self): control_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "control") test_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "test") - self.assertEqual(control_result.count, 60) - self.assertEqual(test_result.count, 10) - self.assertEqual(control_result.absolute_exposure, 4) - self.assertEqual(test_result.absolute_exposure, 2) + self.assertEqual(control_result.count, 100) + self.assertAlmostEqual(test_result.count, 205) + self.assertEqual(control_result.absolute_exposure, 1) + self.assertEqual(test_result.absolute_exposure, 3) self.assertEqual( control_insight["data"], - [0.0, 0.0, 10.0, 30.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0], + [0.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0], ) self.assertEqual( test_insight["data"], - [0.0, 0.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0], + [0.0, 50.0, 125.0, 125.0, 125.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0, 205.0], ) @flaky(max_runs=10, min_passes=1)
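
For reference, the metric shape this series enables — property value math on a data warehouse experiment series — is easiest to see outside the diff context. The sketch below is not part of the patch; it mirrors the new test_query_runner_with_data_warehouse_series_avg_amount test, assumes the schema models import from posthog.schema as in the test module, and assumes an existing `team` and `experiment` (the test fixtures provide these). The "payments" table and its field names are placeholders for whatever warehouse table is linked to the team.

    # Minimal sketch (assumptions noted above): avg of a warehouse column as an
    # experiment count metric, mirroring the new test in this patch series.
    from posthog.hogql_queries.experiments.experiment_trends_query_runner import (
        ExperimentTrendsQueryRunner,
    )
    from posthog.schema import (
        DataWarehouseNode,
        EventsNode,
        ExperimentTrendsQuery,
        TrendsQuery,
    )

    count_query = TrendsQuery(
        series=[
            DataWarehouseNode(
                id="payments",  # placeholder: a table linked to the team
                table_name="payments",
                distinct_id_field="dw_distinct_id",
                id_field="id",
                timestamp_field="dw_timestamp",
                # The new capability: aggregate a warehouse column instead of counting rows.
                math="avg",
                math_property="amount",
                math_property_type="data_warehouse_properties",
            )
        ]
    )
    exposure_query = TrendsQuery(series=[EventsNode(event="$feature_flag_called")])

    experiment_query = ExperimentTrendsQuery(
        kind="ExperimentTrendsQuery",
        experiment_id=experiment.id,  # assumed: an existing Experiment on the feature flag
        count_query=count_query,
        exposure_query=exposure_query,
    )

    # The runner adapts both queries as shown in the diff: the count query's avg math
    # becomes a sum per variant, and the exposure query switches the warehouse series
    # to unique users (DAU) with math_property cleared.
    runner = ExperimentTrendsQueryRunner(query=experiment_query, team=team)  # assumed: existing Team
    result = runner.calculate()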