Skip to content

Commit

Permalink
feat(data-warehouse): Support for Experiment-optimized JOINs (#26446)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
danielbachhuber and github-actions[bot] authored Dec 2, 2024
1 parent 5960ef8 commit dd9f827
Show file tree
Hide file tree
Showing 15 changed files with 374 additions and 157 deletions.
7 changes: 6 additions & 1 deletion frontend/src/lib/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2356,7 +2356,12 @@ const api = {
viewId: DataWarehouseViewLink['id'],
data: Pick<
DataWarehouseViewLink,
'source_table_name' | 'source_table_key' | 'joining_table_name' | 'joining_table_key' | 'field_name'
| 'source_table_name'
| 'source_table_key'
| 'joining_table_name'
| 'joining_table_key'
| 'field_name'
| 'configuration'
>
): Promise<DataWarehouseViewLink> {
return await new ApiRequest().dataWarehouseViewLink(viewId).update({ data })
Expand Down
36 changes: 36 additions & 0 deletions frontend/src/scenes/data-warehouse/ViewLinkModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import './ViewLinkModal.scss'
import { IconCollapse, IconExpand } from '@posthog/icons'
import {
LemonButton,
LemonCheckbox,
LemonDivider,
LemonDropdown,
LemonInput,
Expand Down Expand Up @@ -58,6 +59,8 @@ export function ViewLinkForm(): JSX.Element {
sourceIsUsingHogQLExpression,
joiningIsUsingHogQLExpression,
isViewLinkSubmitting,
experimentsOptimized,
experimentsTimestampKey,
} = useValues(viewLinkLogic)
const {
selectJoiningTable,
Expand All @@ -66,6 +69,8 @@ export function ViewLinkForm(): JSX.Element {
setFieldName,
selectSourceKey,
selectJoiningKey,
setExperimentsOptimized,
selectExperimentsTimestampKey,
} = useActions(viewLinkLogic)
const [advancedSettingsExpanded, setAdvancedSettingsExpanded] = useState(false)

Expand Down Expand Up @@ -151,6 +156,37 @@ export function ViewLinkForm(): JSX.Element {
</Field>
</div>
</div>
{'events' === selectedJoiningTableName && (
<div className="w-full mt-2">
<LemonDivider className="mt-4 mb-4" />
<div className="mt-4 flex flex-row justify-between w-full">
<div className="mr-4">
<span className="l4">Optimize for Experiments</span>
<Field name="experiments_optimized">
<LemonCheckbox
className="mt-2"
checked={experimentsOptimized}
onChange={(checked) => setExperimentsOptimized(checked)}
fullWidth
label="Limit join to most recent matching event based on&nbsp;timestamp"
/>
</Field>
</div>
<div className="w-60 shrink-0">
<span className="l4">Source Timestamp Key</span>
<Field name="experiments_timestamp_key">
<LemonSelect
fullWidth
onSelect={selectExperimentsTimestampKey}
value={experimentsTimestampKey ?? undefined}
options={sourceTableKeys}
placeholder="Select a key"
/>
</Field>
</div>
</div>
</div>
)}
{sqlCodeSnippet && (
<div className="w-full mt-2">
<LemonDivider className="mt-4 mb-4" />
Expand Down
36 changes: 36 additions & 0 deletions frontend/src/scenes/data-warehouse/viewLinkLogic.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
deleteViewLink: (table, column) => ({ table, column }),
setError: (error: string) => ({ error }),
setFieldName: (fieldName: string) => ({ fieldName }),
setExperimentsOptimized: (experimentsOptimized: boolean) => ({ experimentsOptimized }),
selectExperimentsTimestampKey: (experimentsTimestampKey: string | null) => ({ experimentsTimestampKey }),
clearModalFields: true,
})),
reducers({
Expand Down Expand Up @@ -101,6 +103,22 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
clearModalFields: () => '',
},
],
// Reducer entry tracking whether the join in the modal is flagged as experiments-optimized.
experimentsOptimized: [
// Default: not optimized.
false as boolean,
{
// Replace the state with the boolean carried by the action.
setExperimentsOptimized: (_, { experimentsOptimized }) => experimentsOptimized,
// Seed from the join passed to the edit modal; falls back to false when the join
// has no configuration or the flag is unset.
toggleEditJoinModal: (_, { join }) => join.configuration?.experiments_optimized ?? false,
// Reset to the default when the modal fields are cleared.
clearModalFields: () => false,
},
],
// Reducer entry holding the timestamp key chosen for an experiments-optimized join
// (null when none is selected).
experimentsTimestampKey: [
// Default: no key selected.
null as string | null,
{
// Replace the state with the key (or null) carried by the action.
selectExperimentsTimestampKey: (_, { experimentsTimestampKey }) => experimentsTimestampKey,
// Seed from the join passed to the edit modal; falls back to null when the join
// has no configuration or no timestamp key.
toggleEditJoinModal: (_, { join }) => join.configuration?.experiments_timestamp_key ?? null,
// Reset to the default when the modal fields are cleared.
clearModalFields: () => null,
},
],
isJoinTableModalOpen: [
false,
{
Expand Down Expand Up @@ -136,6 +154,10 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
joining_table_name,
joining_table_key: values.selectedJoiningKey ?? undefined,
field_name: values.fieldName,
configuration: {
experiments_optimized: values.experimentsOptimized,
experiments_timestamp_key: values.experimentsTimestampKey ?? undefined,
},
})

actions.toggleJoinTableModal()
Expand All @@ -156,6 +178,10 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
joining_table_name,
joining_table_key: values.selectedJoiningKey ?? undefined,
field_name: values.fieldName,
configuration: {
experiments_optimized: values.experimentsOptimized,
experiments_timestamp_key: values.experimentsTimestampKey ?? undefined,
},
})

actions.toggleJoinTableModal()
Expand All @@ -175,6 +201,16 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
toggleEditJoinModal: ({ join }) => {
actions.setViewLinkValues(join)
},
// Handler for setExperimentsOptimized (presumably a kea listener; the enclosing builder
// is outside this hunk, so confirm). Turning the flag off also clears the selected
// timestamp key by dispatching selectExperimentsTimestampKey(null).
setExperimentsOptimized: ({ experimentsOptimized }) => {
// Only act when the flag is being turned off; turning it on leaves the key untouched.
if (!experimentsOptimized) {
actions.selectExperimentsTimestampKey(null)
}
},
// Handler for selectExperimentsTimestampKey (presumably a kea listener; the enclosing
// builder is outside this hunk, so confirm). Choosing a key also turns the
// experiments-optimized flag on by dispatching setExperimentsOptimized(true).
selectExperimentsTimestampKey: ({ experimentsTimestampKey }) => {
// A null or empty key is ignored here, so clearing the key does not change the flag.
if (experimentsTimestampKey) {
actions.setExperimentsOptimized(true)
}
},
})),
selectors({
selectedSourceTable: [
Expand Down
4 changes: 4 additions & 0 deletions frontend/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4053,6 +4053,10 @@ export interface DataWarehouseViewLink {
field_name?: string
created_by?: UserBasicType | null
created_at?: string | null
configuration?: {
experiments_optimized?: boolean
experiments_timestamp_key?: string | null
}
}

export enum DataWarehouseSettingsTab {
Expand Down
4 changes: 3 additions & 1 deletion posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ def define_mappings(warehouse: dict[str, Table], get_table: Callable):
from_field=from_field,
to_field=to_field,
join_table=joining_table,
join_function=join.join_function(),
join_function=join.join_function_for_experiments()
if "events" == join.joining_table_name and join.configuration.get("experiments_optimized")
else join.join_function(),
)

if join.source_table_name == "persons":
Expand Down
7 changes: 2 additions & 5 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.models import Table, FunctionCallTable, SavedQuery
from posthog.hogql.database.database import Database, create_hogql_database
from posthog.hogql.database.database import create_hogql_database
from posthog.hogql.database.s3_table import S3Table
from posthog.hogql.errors import ImpossibleASTError, InternalHogQLError, QueryError, ResolutionError
from posthog.hogql.escape_sql import (
Expand Down Expand Up @@ -66,9 +66,7 @@ def team_id_guard_for_table(table_type: Union[ast.TableType, ast.TableAliasType]
)


def to_printed_hogql(
query: ast.Expr, team: Team, modifiers: Optional[HogQLQueryModifiers] = None, database: Optional["Database"] = None
) -> str:
def to_printed_hogql(query: ast.Expr, team: Team, modifiers: Optional[HogQLQueryModifiers] = None) -> str:
"""Prints the HogQL query without mutating the node"""
return print_ast(
clone_expr(query),
Expand All @@ -77,7 +75,6 @@ def to_printed_hogql(
team_id=team.pk,
enable_select_queries=True,
modifiers=create_default_modifiers_for_team(team, modifiers),
database=database,
),
pretty=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
from django.conf import settings
from posthog.constants import ExperimentNoResultsErrorKeys
from posthog.hogql import ast
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.database import create_hogql_database
from posthog.hogql.database.models import LazyJoin
from posthog.hogql_queries.experiments import CONTROL_VARIANT_KEY
from posthog.hogql_queries.experiments.trends_statistics import (
are_results_significant,
Expand Down Expand Up @@ -37,7 +34,7 @@
TrendsQuery,
TrendsQueryResponse,
)
from typing import Any, Optional, cast
from typing import Any, Optional
import threading


Expand Down Expand Up @@ -255,86 +252,7 @@ def calculate(self) -> ExperimentTrendsQueryResponse:

def run(query_runner: TrendsQueryRunner, result_key: str, is_parallel: bool):
try:
# Create a new database instance where we can attach our
# custom join to the events table. It will be passed through
# and used by the query runner.
database = create_hogql_database(team_id=self.team.pk)
if self._is_data_warehouse_query(query_runner.query):
series_node = cast(DataWarehouseNode, query_runner.query.series[0])
table = database.get_table(series_node.table_name)
table.fields["events"] = LazyJoin(
from_field=[series_node.distinct_id_field],
join_table=database.get_table("events"),
join_function=lambda join_to_add, context, node: (
ast.JoinExpr(
table=ast.SelectQuery(
select=[
ast.Alias(alias=name, expr=ast.Field(chain=["events", *chain]))
for name, chain in {
**join_to_add.fields_accessed,
"timestamp": ["timestamp"],
"distinct_id": ["distinct_id"],
"properties": ["properties"],
}.items()
],
select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
),
# ASOF JOIN finds the most recent matching event that occurred at or before each data warehouse timestamp.
#
# Why this matters:
# When a user performs an action (recorded in data warehouse), we want to know which
# experiment variant they were assigned at that moment. The most recent $feature_flag_called
# event before their action represents their active variant assignment.
#
# Example:
# Data Warehouse: timestamp=2024-01-03 12:00, distinct_id=user1
# Events:
# 2024-01-02: (user1, variant='control') <- This event will be joined
# 2024-01-03: (user1, variant='test') <- Ignored (occurs after data warehouse timestamp)
#
# This ensures we capture the correct causal relationship: which experiment variant
# was the user assigned to when they performed the action?
join_type="ASOF LEFT JOIN",
alias=join_to_add.to_table,
constraint=ast.JoinConstraint(
expr=ast.And(
exprs=[
ast.CompareOperation(
left=ast.Field(chain=[join_to_add.to_table, "event"]),
op=ast.CompareOperationOp.Eq,
right=ast.Constant(value="$feature_flag_called"),
),
ast.CompareOperation(
left=ast.Field(
chain=[
join_to_add.from_table,
series_node.distinct_id_field,
]
),
op=ast.CompareOperationOp.Eq,
right=ast.Field(chain=[join_to_add.to_table, "distinct_id"]),
),
ast.CompareOperation(
left=ast.Field(
chain=[
join_to_add.from_table,
series_node.timestamp_field,
]
),
op=ast.CompareOperationOp.GtEq,
right=ast.Field(chain=[join_to_add.to_table, "timestamp"]),
),
]
),
constraint_type="ON",
),
)
),
)

context = HogQLContext(team_id=self.team.pk, database=database)

result = query_runner.calculate(context=context)
result = query_runner.calculate()
shared_results[result_key] = result
except Exception as e:
errors.append(e)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from django.test import override_settings
from posthog.hogql.errors import QueryError
from posthog.hogql_queries.experiments.experiment_trends_query_runner import ExperimentTrendsQueryRunner
from posthog.models.experiment import Experiment, ExperimentHoldout
from posthog.models.feature_flag.feature_flag import FeatureFlag
Expand Down Expand Up @@ -34,6 +33,7 @@
from boto3 import resource
from botocore.config import Config
from posthog.warehouse.models.credential import DataWarehouseCredential
from posthog.warehouse.models.join import DataWarehouseJoin
from posthog.warehouse.models.table import DataWarehouseTable

TEST_BUCKET = "test_storage_bucket-posthog.hogql.datawarehouse.trendquery" + XDIST_SUFFIX
Expand Down Expand Up @@ -137,7 +137,7 @@ def create_data_warehouse_table_with_payments(self):
)
distinct_id = pa.array(["user_control_0", "user_test_1", "user_test_2", "user_test_3", "user_extra"])
amount = pa.array([100, 50, 75, 80, 90])
names = ["id", "timestamp", "distinct_id", "amount"]
names = ["id", "dw_timestamp", "dw_distinct_id", "amount"]

pq.write_to_dataset(
pa.Table.from_arrays([id, timestamp, distinct_id, amount], names=names),
Expand All @@ -163,12 +163,22 @@ def create_data_warehouse_table_with_payments(self):
team=self.team,
columns={
"id": "String",
"timestamp": "DateTime64(3, 'UTC')",
"distinct_id": "String",
"dw_timestamp": "DateTime64(3, 'UTC')",
"dw_distinct_id": "String",
"amount": "Int64",
},
credential=credential,
)

DataWarehouseJoin.objects.create(
team=self.team,
source_table_name=table_name,
source_table_key="dw_distinct_id",
joining_table_name="events",
joining_table_key="distinct_id",
field_name="events",
configuration={"experiments_optimized": True, "experiments_timestamp_key": "dw_timestamp"},
)
return table_name

@freeze_time("2020-01-01T12:00:00Z")
Expand Down Expand Up @@ -494,10 +504,10 @@ def test_query_runner_with_data_warehouse_series(self):
series=[
DataWarehouseNode(
id=table_name,
distinct_id_field="distinct_id",
id_field="distinct_id",
distinct_id_field="dw_distinct_id",
id_field="id",
table_name=table_name,
timestamp_field="timestamp",
timestamp_field="dw_timestamp",
)
]
)
Expand Down Expand Up @@ -587,10 +597,10 @@ def test_query_runner_with_invalid_data_warehouse_table_name(self):
series=[
DataWarehouseNode(
id=table_name,
distinct_id_field="distinct_id",
id_field="distinct_id",
distinct_id_field="dw_distinct_id",
id_field="id",
table_name=table_name,
timestamp_field="timestamp",
timestamp_field="dw_timestamp",
)
]
)
Expand All @@ -610,10 +620,10 @@ def test_query_runner_with_invalid_data_warehouse_table_name(self):
query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team
)
with freeze_time("2023-01-07"):
with self.assertRaises(QueryError) as context:
with self.assertRaises(KeyError) as context:
query_runner.calculate()

self.assertEqual(str(context.exception), 'Unknown table "invalid_table_name".')
self.assertEqual(str(context.exception), "'invalid_table_name'")

@freeze_time("2020-01-01T12:00:00Z")
def test_query_runner_with_avg_math(self):
Expand Down
Loading

0 comments on commit dd9f827

Please sign in to comment.