Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(data-warehouse): Support for Experiment-optimized JOINs #26446

Merged
merged 23 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
92a305f
First pass at backend for optimized join tables
danielbachhuber Nov 26, 2024
855178f
Simplify key name
danielbachhuber Nov 26, 2024
1003d50
Add a checkbox for indicating the JOIN should be optimized
danielbachhuber Nov 26, 2024
3e28530
Update exception catching
danielbachhuber Nov 26, 2024
d088921
Update query snapshots
github-actions[bot] Nov 26, 2024
7a91ba4
Recreate the migration
danielbachhuber Nov 30, 2024
57d5ff5
Revert changes to test_session_recordings
danielbachhuber Nov 30, 2024
76fa3aa
Update query snapshots
github-actions[bot] Nov 30, 2024
d4373de
Optional
danielbachhuber Nov 30, 2024
baf1c99
Add some tests
danielbachhuber Nov 30, 2024
37e72d5
Ensure modal state is set when loaded
danielbachhuber Nov 30, 2024
1e62360
Allow setting experiment timestamp field
danielbachhuber Nov 30, 2024
c3bbd85
Actually need the source table keys
danielbachhuber Nov 30, 2024
3cb7aea
Toggle checkbox if column is selected
danielbachhuber Nov 30, 2024
16d6c09
Use stored fields in the actual JOIN
danielbachhuber Nov 30, 2024
917719e
Update query snapshots
github-actions[bot] Nov 30, 2024
8f80be7
Resolve mypy error
danielbachhuber Nov 30, 2024
d190a12
Uppercase
danielbachhuber Nov 30, 2024
413a63f
Rename to `experiments_timestamp_key`
danielbachhuber Nov 30, 2024
e9c8d8d
More safe-guarding
danielbachhuber Nov 30, 2024
d634847
Merge branch 'master' into experiments/optimized-dw-configuration
danielbachhuber Nov 30, 2024
7b7d89a
Merge branch 'master' into experiments/optimized-dw-configuration
danielbachhuber Dec 2, 2024
bcd7200
Update query snapshots
github-actions[bot] Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion frontend/src/lib/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2356,7 +2356,12 @@ const api = {
viewId: DataWarehouseViewLink['id'],
data: Pick<
DataWarehouseViewLink,
'source_table_name' | 'source_table_key' | 'joining_table_name' | 'joining_table_key' | 'field_name'
| 'source_table_name'
| 'source_table_key'
| 'joining_table_name'
| 'joining_table_key'
| 'field_name'
| 'configuration'
>
): Promise<DataWarehouseViewLink> {
return await new ApiRequest().dataWarehouseViewLink(viewId).update({ data })
Expand Down
36 changes: 36 additions & 0 deletions frontend/src/scenes/data-warehouse/ViewLinkModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import './ViewLinkModal.scss'
import { IconCollapse, IconExpand } from '@posthog/icons'
import {
LemonButton,
LemonCheckbox,
LemonDivider,
LemonDropdown,
LemonInput,
Expand Down Expand Up @@ -58,6 +59,8 @@ export function ViewLinkForm(): JSX.Element {
sourceIsUsingHogQLExpression,
joiningIsUsingHogQLExpression,
isViewLinkSubmitting,
experimentsOptimized,
experimentsTimestampKey,
} = useValues(viewLinkLogic)
const {
selectJoiningTable,
Expand All @@ -66,6 +69,8 @@ export function ViewLinkForm(): JSX.Element {
setFieldName,
selectSourceKey,
selectJoiningKey,
setExperimentsOptimized,
selectExperimentsTimestampKey,
} = useActions(viewLinkLogic)
const [advancedSettingsExpanded, setAdvancedSettingsExpanded] = useState(false)

Expand Down Expand Up @@ -151,6 +156,37 @@ export function ViewLinkForm(): JSX.Element {
</Field>
</div>
</div>
{'events' === selectedJoiningTableName && (
<div className="w-full mt-2">
<LemonDivider className="mt-4 mb-4" />
<div className="mt-4 flex flex-row justify-between w-full">
<div className="mr-4">
<span className="l4">Optimize for Experiments</span>
<Field name="experiments_optimized">
<LemonCheckbox
className="mt-2"
checked={experimentsOptimized}
onChange={(checked) => setExperimentsOptimized(checked)}
fullWidth
label="Limit join to most recent matching event based on&nbsp;timestamp"
/>
</Field>
</div>
<div className="w-60 shrink-0">
<span className="l4">Source Timestamp Key</span>
<Field name="experiments_timestamp_key">
<LemonSelect
fullWidth
onSelect={selectExperimentsTimestampKey}
value={experimentsTimestampKey ?? undefined}
options={sourceTableKeys}
placeholder="Select a key"
/>
</Field>
</div>
</div>
</div>
)}
{sqlCodeSnippet && (
<div className="w-full mt-2">
<LemonDivider className="mt-4 mb-4" />
Expand Down
36 changes: 36 additions & 0 deletions frontend/src/scenes/data-warehouse/viewLinkLogic.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
deleteViewLink: (table, column) => ({ table, column }),
setError: (error: string) => ({ error }),
setFieldName: (fieldName: string) => ({ fieldName }),
setExperimentsOptimized: (experimentsOptimized: boolean) => ({ experimentsOptimized }),
selectExperimentsTimestampKey: (experimentsTimestampKey: string | null) => ({ experimentsTimestampKey }),
clearModalFields: true,
})),
reducers({
Expand Down Expand Up @@ -101,6 +103,22 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
clearModalFields: () => '',
},
],
experimentsOptimized: [
false as boolean,
{
setExperimentsOptimized: (_, { experimentsOptimized }) => experimentsOptimized,
toggleEditJoinModal: (_, { join }) => join.configuration?.experiments_optimized ?? false,
clearModalFields: () => false,
},
],
experimentsTimestampKey: [
null as string | null,
{
selectExperimentsTimestampKey: (_, { experimentsTimestampKey }) => experimentsTimestampKey,
toggleEditJoinModal: (_, { join }) => join.configuration?.experiments_timestamp_key ?? null,
clearModalFields: () => null,
},
],
isJoinTableModalOpen: [
false,
{
Expand Down Expand Up @@ -136,6 +154,10 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
joining_table_name,
joining_table_key: values.selectedJoiningKey ?? undefined,
field_name: values.fieldName,
configuration: {
experiments_optimized: values.experimentsOptimized,
experiments_timestamp_key: values.experimentsTimestampKey ?? undefined,
},
})

actions.toggleJoinTableModal()
Expand All @@ -156,6 +178,10 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
joining_table_name,
joining_table_key: values.selectedJoiningKey ?? undefined,
field_name: values.fieldName,
configuration: {
experiments_optimized: values.experimentsOptimized,
experiments_timestamp_key: values.experimentsTimestampKey ?? undefined,
},
})

actions.toggleJoinTableModal()
Expand All @@ -175,6 +201,16 @@ export const viewLinkLogic = kea<viewLinkLogicType>([
toggleEditJoinModal: ({ join }) => {
actions.setViewLinkValues(join)
},
setExperimentsOptimized: ({ experimentsOptimized }) => {
if (!experimentsOptimized) {
actions.selectExperimentsTimestampKey(null)
}
},
selectExperimentsTimestampKey: ({ experimentsTimestampKey }) => {
if (experimentsTimestampKey) {
actions.setExperimentsOptimized(true)
}
},
})),
selectors({
selectedSourceTable: [
Expand Down
4 changes: 4 additions & 0 deletions frontend/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4053,6 +4053,10 @@ export interface DataWarehouseViewLink {
field_name?: string
created_by?: UserBasicType | null
created_at?: string | null
configuration?: {
experiments_optimized?: boolean
experiments_timestamp_key?: string | null
}
}

export enum DataWarehouseSettingsTab {
Expand Down
4 changes: 3 additions & 1 deletion posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ def define_mappings(warehouse: dict[str, Table], get_table: Callable):
from_field=from_field,
to_field=to_field,
join_table=joining_table,
join_function=join.join_function(),
join_function=join.join_function_for_experiments()
if "events" == join.joining_table_name and join.configuration.get("experiments_optimized")
else join.join_function(),
)

if join.source_table_name == "persons":
Expand Down
7 changes: 2 additions & 5 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.models import Table, FunctionCallTable, SavedQuery
from posthog.hogql.database.database import Database, create_hogql_database
from posthog.hogql.database.database import create_hogql_database
from posthog.hogql.database.s3_table import S3Table
from posthog.hogql.errors import ImpossibleASTError, InternalHogQLError, QueryError, ResolutionError
from posthog.hogql.escape_sql import (
Expand Down Expand Up @@ -66,9 +66,7 @@ def team_id_guard_for_table(table_type: Union[ast.TableType, ast.TableAliasType]
)


def to_printed_hogql(
query: ast.Expr, team: Team, modifiers: Optional[HogQLQueryModifiers] = None, database: Optional["Database"] = None
) -> str:
def to_printed_hogql(query: ast.Expr, team: Team, modifiers: Optional[HogQLQueryModifiers] = None) -> str:
"""Prints the HogQL query without mutating the node"""
return print_ast(
clone_expr(query),
Expand All @@ -77,7 +75,6 @@ def to_printed_hogql(
team_id=team.pk,
enable_select_queries=True,
modifiers=create_default_modifiers_for_team(team, modifiers),
database=database,
),
pretty=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
from django.conf import settings
from posthog.constants import ExperimentNoResultsErrorKeys
from posthog.hogql import ast
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.database import create_hogql_database
from posthog.hogql.database.models import LazyJoin
from posthog.hogql_queries.experiments import CONTROL_VARIANT_KEY
from posthog.hogql_queries.experiments.trends_statistics import (
are_results_significant,
Expand Down Expand Up @@ -37,7 +34,7 @@
TrendsQuery,
TrendsQueryResponse,
)
from typing import Any, Optional, cast
from typing import Any, Optional
import threading


Expand Down Expand Up @@ -255,86 +252,7 @@ def calculate(self) -> ExperimentTrendsQueryResponse:

def run(query_runner: TrendsQueryRunner, result_key: str, is_parallel: bool):
try:
# Create a new database instance where we can attach our
# custom join to the events table. It will be passed through
# and used by the query runner.
database = create_hogql_database(team_id=self.team.pk)
if self._is_data_warehouse_query(query_runner.query):
series_node = cast(DataWarehouseNode, query_runner.query.series[0])
table = database.get_table(series_node.table_name)
table.fields["events"] = LazyJoin(
from_field=[series_node.distinct_id_field],
join_table=database.get_table("events"),
join_function=lambda join_to_add, context, node: (
ast.JoinExpr(
table=ast.SelectQuery(
select=[
ast.Alias(alias=name, expr=ast.Field(chain=["events", *chain]))
for name, chain in {
**join_to_add.fields_accessed,
"timestamp": ["timestamp"],
"distinct_id": ["distinct_id"],
"properties": ["properties"],
}.items()
],
select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
),
# ASOF JOIN finds the most recent matching event that occurred at or before each data warehouse timestamp.
#
# Why this matters:
# When a user performs an action (recorded in data warehouse), we want to know which
# experiment variant they were assigned at that moment. The most recent $feature_flag_called
# event before their action represents their active variant assignment.
#
# Example:
# Data Warehouse: timestamp=2024-01-03 12:00, distinct_id=user1
# Events:
# 2024-01-02: (user1, variant='control') <- This event will be joined
# 2024-01-03: (user1, variant='test') <- Ignored (occurs after data warehouse timestamp)
#
# This ensures we capture the correct causal relationship: which experiment variant
# was the user assigned to when they performed the action?
join_type="ASOF LEFT JOIN",
alias=join_to_add.to_table,
constraint=ast.JoinConstraint(
expr=ast.And(
exprs=[
ast.CompareOperation(
left=ast.Field(chain=[join_to_add.to_table, "event"]),
op=ast.CompareOperationOp.Eq,
right=ast.Constant(value="$feature_flag_called"),
),
ast.CompareOperation(
left=ast.Field(
chain=[
join_to_add.from_table,
series_node.distinct_id_field,
]
),
op=ast.CompareOperationOp.Eq,
right=ast.Field(chain=[join_to_add.to_table, "distinct_id"]),
),
ast.CompareOperation(
left=ast.Field(
chain=[
join_to_add.from_table,
series_node.timestamp_field,
]
),
op=ast.CompareOperationOp.GtEq,
right=ast.Field(chain=[join_to_add.to_table, "timestamp"]),
),
]
),
constraint_type="ON",
),
)
),
)

context = HogQLContext(team_id=self.team.pk, database=database)

result = query_runner.calculate(context=context)
result = query_runner.calculate()
shared_results[result_key] = result
except Exception as e:
errors.append(e)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from django.test import override_settings
from posthog.hogql.errors import QueryError
from posthog.hogql_queries.experiments.experiment_trends_query_runner import ExperimentTrendsQueryRunner
from posthog.models.experiment import Experiment, ExperimentHoldout
from posthog.models.feature_flag.feature_flag import FeatureFlag
Expand Down Expand Up @@ -34,6 +33,7 @@
from boto3 import resource
from botocore.config import Config
from posthog.warehouse.models.credential import DataWarehouseCredential
from posthog.warehouse.models.join import DataWarehouseJoin
from posthog.warehouse.models.table import DataWarehouseTable

TEST_BUCKET = "test_storage_bucket-posthog.hogql.datawarehouse.trendquery" + XDIST_SUFFIX
Expand Down Expand Up @@ -137,7 +137,7 @@ def create_data_warehouse_table_with_payments(self):
)
distinct_id = pa.array(["user_control_0", "user_test_1", "user_test_2", "user_test_3", "user_extra"])
amount = pa.array([100, 50, 75, 80, 90])
names = ["id", "timestamp", "distinct_id", "amount"]
names = ["id", "dw_timestamp", "dw_distinct_id", "amount"]

pq.write_to_dataset(
pa.Table.from_arrays([id, timestamp, distinct_id, amount], names=names),
Expand All @@ -163,12 +163,22 @@ def create_data_warehouse_table_with_payments(self):
team=self.team,
columns={
"id": "String",
"timestamp": "DateTime64(3, 'UTC')",
"distinct_id": "String",
"dw_timestamp": "DateTime64(3, 'UTC')",
"dw_distinct_id": "String",
"amount": "Int64",
},
credential=credential,
)

DataWarehouseJoin.objects.create(
team=self.team,
source_table_name=table_name,
source_table_key="dw_distinct_id",
joining_table_name="events",
joining_table_key="distinct_id",
field_name="events",
configuration={"experiments_optimized": True, "experiments_timestamp_key": "dw_timestamp"},
)
return table_name

@freeze_time("2020-01-01T12:00:00Z")
Expand Down Expand Up @@ -494,10 +504,10 @@ def test_query_runner_with_data_warehouse_series(self):
series=[
DataWarehouseNode(
id=table_name,
distinct_id_field="distinct_id",
id_field="distinct_id",
distinct_id_field="dw_distinct_id",
id_field="id",
table_name=table_name,
timestamp_field="timestamp",
timestamp_field="dw_timestamp",
)
]
)
Expand Down Expand Up @@ -587,10 +597,10 @@ def test_query_runner_with_invalid_data_warehouse_table_name(self):
series=[
DataWarehouseNode(
id=table_name,
distinct_id_field="distinct_id",
id_field="distinct_id",
distinct_id_field="dw_distinct_id",
id_field="id",
table_name=table_name,
timestamp_field="timestamp",
timestamp_field="dw_timestamp",
)
]
)
Expand All @@ -610,10 +620,10 @@ def test_query_runner_with_invalid_data_warehouse_table_name(self):
query=ExperimentTrendsQuery(**experiment.metrics[0]["query"]), team=self.team
)
with freeze_time("2023-01-07"):
with self.assertRaises(QueryError) as context:
with self.assertRaises(KeyError) as context:
query_runner.calculate()

self.assertEqual(str(context.exception), 'Unknown table "invalid_table_name".')
self.assertEqual(str(context.exception), "'invalid_table_name'")

@freeze_time("2020-01-01T12:00:00Z")
def test_query_runner_with_avg_math(self):
Expand Down
Loading
Loading