From d4cc10ed520bdc1513466ded0bde95902514c0a8 Mon Sep 17 00:00:00 2001 From: Daniel Bachhuber Date: Thu, 12 Dec 2024 16:43:23 -0800 Subject: [PATCH] fix(experiments): Fix a couple of issues in the ASOF LEFT JOIN (#26886) --- .../test_experiment_trends_query_runner.py | 185 ++++++++++++++++++ posthog/warehouse/models/join.py | 5 +- 2 files changed, 188 insertions(+), 2 deletions(-) diff --git a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py index 4d0813c4393d6..b6be3eb76d065 100644 --- a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py @@ -183,6 +183,69 @@ def create_data_warehouse_table_with_payments(self): ) return table_name + def create_data_warehouse_table_with_usage(self): + if not OBJECT_STORAGE_ACCESS_KEY_ID or not OBJECT_STORAGE_SECRET_ACCESS_KEY: + raise Exception("Missing vars") + + fs = s3fs.S3FileSystem( + client_kwargs={ + "region_name": "us-east-1", + "endpoint_url": OBJECT_STORAGE_ENDPOINT, + "aws_access_key_id": OBJECT_STORAGE_ACCESS_KEY_ID, + "aws_secret_access_key": OBJECT_STORAGE_SECRET_ACCESS_KEY, + }, + ) + + path_to_s3_object = "s3://" + OBJECT_STORAGE_BUCKET + f"/{TEST_BUCKET}" + + id = pa.array(["1", "2", "3", "4", "5"]) + date = pa.array(["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-06", "2023-01-07"]) + user_id = pa.array(["user_control_0", "user_test_1", "user_test_2", "user_test_3", "user_extra"]) + usage = pa.array([1000, 500, 750, 800, 900]) + names = ["id", "ds", "userid", "usage"] + + pq.write_to_dataset( + pa.Table.from_arrays([id, date, user_id, usage], names=names), + path_to_s3_object, + filesystem=fs, + use_dictionary=True, + compression="snappy", + version="2.0", + ) + + table_name = "usage" + + credential = DataWarehouseCredential.objects.create( + access_key=OBJECT_STORAGE_ACCESS_KEY_ID, + 
access_secret=OBJECT_STORAGE_SECRET_ACCESS_KEY, + team=self.team, + ) + + DataWarehouseTable.objects.create( + name=table_name, + url_pattern=f"http://host.docker.internal:19000/{OBJECT_STORAGE_BUCKET}/{TEST_BUCKET}/*.parquet", + format=DataWarehouseTable.TableFormat.Parquet, + team=self.team, + columns={ + "id": "String", + "ds": "Date", + "userid": "String", + "usage": "Int64", + }, + credential=credential, + ) + + DataWarehouseJoin.objects.create( + team=self.team, + source_table_name=table_name, + source_table_key="userid", + joining_table_name="events", + joining_table_key="properties.$user_id", + field_name="events", + configuration={"experiments_optimized": True, "experiments_timestamp_key": "ds"}, + ) + return table_name + @freeze_time("2020-01-01T12:00:00Z") def test_query_runner(self): feature_flag = self.create_feature_flag() @@ -694,6 +757,128 @@ def test_query_runner_with_data_warehouse_series_avg_amount(self): [0.0, 50.0, 125.0, 125.0, 125.0, 205.0, 205.0, 205.0, 205.0, 205.0], ) + def test_query_runner_with_data_warehouse_series_no_end_date_and_nested_id(self): + table_name = self.create_data_warehouse_table_with_usage() + + feature_flag = self.create_feature_flag() + experiment = self.create_experiment( + feature_flag=feature_flag, + start_date=datetime(2023, 1, 1), + ) + + feature_flag_property = f"$feature/{feature_flag.key}" + + count_query = TrendsQuery( + series=[ + DataWarehouseNode( + id=table_name, + distinct_id_field="userid", + id_field="id", + table_name=table_name, + timestamp_field="ds", + math="avg", + math_property="usage", + math_property_type="data_warehouse_properties", + ) + ] + ) + exposure_query = TrendsQuery(series=[EventsNode(event="$feature_flag_called")]) + + experiment_query = ExperimentTrendsQuery( + experiment_id=experiment.id, + kind="ExperimentTrendsQuery", + count_query=count_query, + exposure_query=exposure_query, + ) + + experiment.metrics = [{"type": "primary", "query": experiment_query.model_dump()}] + 
experiment.save() + + # Populate exposure events + for variant, count in [("control", 7), ("test", 9)]: + for i in range(count): + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id=f"distinct_{variant}_{i}", + properties={feature_flag_property: variant, "$user_id": f"user_{variant}_{i}"}, + timestamp=datetime(2023, 1, i + 1), + ) + + # "user_test_3" first exposure (feature_flag_property="control") is on 2023-01-03 + # "user_test_3" relevant exposure (feature_flag_property="test") is on 2023-01-04 + # "user_test_3" other event (feature_flag_property="control") is on 2023-01-05 + # "user_test_3" usage record is on 2023-01-06 + # "user_test_3" second exposure (feature_flag_property="control") is on 2023-01-09 + # "user_test_3" should fall into the "test" variant, not the "control" variant + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id="distinct_test_3", + properties={feature_flag_property: "control", "$user_id": "user_test_3"}, + timestamp=datetime(2023, 1, 3), + ) + _create_event( + team=self.team, + event="Some other event", + distinct_id="distinct_test_3", + properties={feature_flag_property: "control", "$user_id": "user_test_3"}, + timestamp=datetime(2023, 1, 5), + ) + _create_event( + team=self.team, + event="$feature_flag_called", + distinct_id="distinct_test_3", + properties={feature_flag_property: "control", "$user_id": "user_test_3"}, + timestamp=datetime(2023, 1, 9), + ) + + flush_persons_and_events() + + query_runner = ExperimentTrendsQueryRunner( + query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team + ) + with freeze_time("2023-01-07"): + # Build and execute the query to get the ClickHouse SQL + queries = query_runner.count_query_runner.to_queries() + response = execute_hogql_query( + query_type="TrendsQuery", + query=queries[0], + team=query_runner.count_query_runner.team, + modifiers=query_runner.count_query_runner.modifiers, + 
limit_context=query_runner.count_query_runner.limit_context, + ) + + # Assert the expected join condition in the clickhouse SQL + expected_join_condition = f"and(equals(events.team_id, {query_runner.count_query_runner.team.id}), equals(event, %(hogql_val_8)s), greaterOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_9)s, 6, %(hogql_val_10)s))), lessOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_11)s, 6, %(hogql_val_12)s))))) AS e__events ON" + self.assertIn(expected_join_condition, str(response.clickhouse)) + + result = query_runner.calculate() + + trend_result = cast(ExperimentTrendsQueryResponse, result) + + self.assertEqual(len(result.variants), 2) + + control_result = next(variant for variant in trend_result.variants if variant.key == "control") + test_result = next(variant for variant in trend_result.variants if variant.key == "test") + + control_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "control") + test_insight = next(variant for variant in trend_result.insight if variant["breakdown_value"] == "test") + + self.assertEqual(control_result.count, 1000) + self.assertEqual(test_result.count, 2050) + self.assertEqual(control_result.absolute_exposure, 1) + self.assertEqual(test_result.absolute_exposure, 3) + + self.assertEqual( + control_insight["data"][:10], + [1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0], + ) + self.assertEqual( + test_insight["data"][:10], + [0.0, 500.0, 1250.0, 1250.0, 1250.0, 2050.0, 2050.0, 2050.0, 2050.0, 2050.0], + ) + def test_query_runner_with_data_warehouse_series_expected_query(self): table_name = self.create_data_warehouse_table_with_payments() diff --git a/posthog/warehouse/models/join.py b/posthog/warehouse/models/join.py index 36940434476ce..51108c58e578c 100644 --- a/posthog/warehouse/models/join.py +++ b/posthog/warehouse/models/join.py @@ -126,7 +126,8 @@ def 
_join_function_for_experiments( for expr in node.where.exprs: if isinstance(expr, ast.CompareOperation): if expr.op == ast.CompareOperationOp.GtEq or expr.op == ast.CompareOperationOp.LtEq: - if isinstance(expr.left, ast.Alias) and expr.left.expr.to_hogql() == timestamp_key: + # Match within hogql string because it could be 'toDateTime(timestamp)' + if isinstance(expr.left, ast.Alias) and timestamp_key in expr.left.expr.to_hogql(): whereExpr.append( ast.CompareOperation( op=expr.op, left=ast.Field(chain=["timestamp"]), right=expr.right @@ -183,7 +184,7 @@ def _join_function_for_experiments( ] ), op=ast.CompareOperationOp.Eq, - right=ast.Field(chain=[join_to_add.to_table, "distinct_id"]), + right=ast.Field(chain=[join_to_add.to_table, *self.joining_table_key.split(".")]), ), ast.CompareOperation( left=ast.Field(