Skip to content

Commit

Permalink
feat(experiments HogQL rewrite): prepare trend queries (#25177)
Browse files Browse the repository at this point in the history
  • Loading branch information
jurajmajerik authored Sep 30, 2024
1 parent 6448c6e commit fafd152
Show file tree
Hide file tree
Showing 5 changed files with 456 additions and 65 deletions.
6 changes: 3 additions & 3 deletions frontend/src/queries/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -4618,13 +4618,13 @@
"ExperimentTrendQuery": {
"additionalProperties": false,
"properties": {
"count_source": {
"count_query": {
"$ref": "#/definitions/TrendsQuery"
},
"experiment_id": {
"type": "integer"
},
"exposure_source": {
"exposure_query": {
"$ref": "#/definitions/TrendsQuery"
},
"kind": {
Expand All @@ -4639,7 +4639,7 @@
"$ref": "#/definitions/ExperimentTrendQueryResponse"
}
},
"required": ["count_source", "experiment_id", "exposure_source", "kind"],
"required": ["count_query", "experiment_id", "kind"],
"type": "object"
},
"ExperimentTrendQueryResponse": {
Expand Down
6 changes: 4 additions & 2 deletions frontend/src/queries/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1602,8 +1602,10 @@ export interface ExperimentFunnelQuery extends DataNode<ExperimentFunnelQueryRes

export interface ExperimentTrendQuery extends DataNode<ExperimentTrendQueryResponse> {
kind: NodeKind.ExperimentTrendQuery
count_source: TrendsQuery
exposure_source: TrendsQuery
count_query: TrendsQuery
// Defaults to $feature_flag_called if not specified
// https://github.com/PostHog/posthog/blob/master/posthog/hogql_queries/experiments/experiment_trend_query_runner.py
exposure_query?: TrendsQuery
experiment_id: integer
}

Expand Down
199 changes: 182 additions & 17 deletions posthog/hogql_queries/experiments/experiment_trend_query_runner.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
from zoneinfo import ZoneInfo
from django.conf import settings
from posthog.hogql import ast
from posthog.hogql_queries.insights.trends.trends_query_runner import TrendsQueryRunner
from posthog.hogql_queries.query_runner import QueryRunner
from posthog.models.experiment import Experiment
from posthog.queries.trends.util import ALL_SUPPORTED_MATH_FUNCTIONS
from posthog.schema import (
BaseMathType,
BreakdownFilter,
ChartDisplayType,
EventPropertyFilter,
EventsNode,
ExperimentTrendQuery,
ExperimentTrendQueryResponse,
ExperimentVariantTrendResult,
InsightDateRange,
PropertyMathType,
TrendsFilter,
TrendsQuery,
)
from typing import Any
from typing import Any, Optional
import threading


Expand All @@ -19,22 +30,176 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.experiment = Experiment.objects.get(id=self.query.experiment_id)
self.feature_flag = self.experiment.feature_flag
self.breakdown_key = f"$feature/{self.feature_flag.key}"

self.query_runner = TrendsQueryRunner(
query=self.query.count_source, team=self.team, timings=self.timings, limit_context=self.limit_context
self.prepared_count_query = self._prepare_count_query()
self.prepared_exposure_query = self._prepare_exposure_query()

self.count_query_runner = TrendsQueryRunner(
query=self.prepared_count_query, team=self.team, timings=self.timings, limit_context=self.limit_context
)
self.exposure_query_runner = TrendsQueryRunner(
query=self.query.exposure_source, team=self.team, timings=self.timings, limit_context=self.limit_context
query=self.prepared_exposure_query, team=self.team, timings=self.timings, limit_context=self.limit_context
)

def _uses_math_aggregation_by_user_or_property_value(self, query: TrendsQuery):
math_keys = ALL_SUPPORTED_MATH_FUNCTIONS
# "sum" doesn't need special handling, we *can* have custom exposure for sum filters
if "sum" in math_keys:
math_keys.remove("sum")
return any(entity.math in math_keys for entity in query.series)

def _get_insight_date_range(self) -> InsightDateRange:
"""
Returns an InsightDateRange object based on the experiment's start and end dates,
adjusted for the team's timezone if applicable.
"""
if self.team.timezone:
tz = ZoneInfo(self.team.timezone)
start_date = self.experiment.start_date.astimezone(tz) if self.experiment.start_date else None
end_date = self.experiment.end_date.astimezone(tz) if self.experiment.end_date else None
else:
start_date = self.experiment.start_date
end_date = self.experiment.end_date

return InsightDateRange(
date_from=start_date.isoformat() if start_date else None,
date_to=end_date.isoformat() if end_date else None,
explicitDate=True,
)

def _get_breakdown_filter(self) -> BreakdownFilter:
return BreakdownFilter(
breakdown=self.breakdown_key,
breakdown_type="event",
)

def _prepare_count_query(self) -> TrendsQuery:
"""
This method takes the raw trend query and adapts it
for the needs of experiment analysis:
1. Set the trend display type based on whether math aggregation is used
2. Set the date range to match the experiment's duration, using the project's timezone.
3. Configure the breakdown to use the feature flag key, which allows us
to separate results for different experiment variants.
"""
prepared_count_query = TrendsQuery(**self.query.count_query.model_dump())

uses_math_aggregation = self._uses_math_aggregation_by_user_or_property_value(prepared_count_query)

# :TRICKY: for `avg` aggregation, use `sum` data as an approximation
if prepared_count_query.series[0].math == PropertyMathType.AVG:
prepared_count_query.series[0].math = PropertyMathType.SUM
prepared_count_query.trendsFilter = TrendsFilter(display=ChartDisplayType.ACTIONS_LINE_GRAPH_CUMULATIVE)
# TODO: revisit this; using the count data for the remaining aggregation types is likely wrong
elif uses_math_aggregation:
prepared_count_query.series[0].math = None
prepared_count_query.trendsFilter = TrendsFilter(display=ChartDisplayType.ACTIONS_LINE_GRAPH_CUMULATIVE)

prepared_count_query.dateRange = self._get_insight_date_range()
prepared_count_query.breakdownFilter = self._get_breakdown_filter()
prepared_count_query.properties = [
EventPropertyFilter(
key=self.breakdown_key,
value=[variant["key"] for variant in self.feature_flag.variants],
operator="exact",
type="event",
)
]

return prepared_count_query

def _prepare_exposure_query(self) -> TrendsQuery:
"""
This method prepares the exposure query for the experiment analysis.
Exposure is the count of users who have seen the experiment. This is necessary to calculate the statistical
significance of the experiment.
There are 3 possible cases for the exposure query:
1. If math aggregation is used, we construct an implicit exposure query
2. Otherwise, if an exposure query is provided, we use it as is, adapting it to the experiment's duration and breakdown
3. Otherwise, we construct a default exposure query (the count of $feature_flag_called events)
"""

# 1. If math aggregation is used, we construct an implicit exposure query: unique users for the count event
uses_math_aggregation = self._uses_math_aggregation_by_user_or_property_value(self.query.count_query)

if uses_math_aggregation:
prepared_exposure_query = TrendsQuery(**self.query.count_query.model_dump())
count_event = self.query.count_query.series[0]

if hasattr(count_event, "event"):
prepared_exposure_query.dateRange = self._get_insight_date_range()
prepared_exposure_query.breakdownFilter = self._get_breakdown_filter()
prepared_exposure_query.series = [
EventsNode(
event=count_event.event,
math=BaseMathType.DAU,
)
]
prepared_exposure_query.properties = [
EventPropertyFilter(
key=self.breakdown_key,
value=[variant["key"] for variant in self.feature_flag.variants],
operator="exact",
type="event",
)
]
else:
raise ValueError("Expected first series item to have an 'event' attribute")

# 2. Otherwise, if an exposure query is provided, we use it as is, adapting the date range and breakdown
elif self.query.exposure_query:
prepared_exposure_query = TrendsQuery(**self.query.exposure_query.model_dump())
prepared_exposure_query.dateRange = self._get_insight_date_range()
prepared_exposure_query.breakdownFilter = self._get_breakdown_filter()
prepared_exposure_query.properties = [
EventPropertyFilter(
key=self.breakdown_key,
value=[variant["key"] for variant in self.feature_flag.variants],
operator="exact",
type="event",
)
]
# 3. Otherwise, we construct a default exposure query: unique users for the $feature_flag_called event
else:
prepared_exposure_query = TrendsQuery(
dateRange=self._get_insight_date_range(),
breakdownFilter=self._get_breakdown_filter(),
series=[
EventsNode(
event="$feature_flag_called",
math=BaseMathType.DAU, # TODO sync with frontend!!!
)
],
properties=[
EventPropertyFilter(
key=self.breakdown_key,
value=[variant["key"] for variant in self.feature_flag.variants],
operator="exact",
type="event",
),
EventPropertyFilter(
key="$feature_flag",
value=[self.feature_flag.key],
operator="exact",
type="event",
),
],
)

return prepared_exposure_query

def calculate(self) -> ExperimentTrendQueryResponse:
count_response = None
exposure_response = None
shared_results: dict[str, Optional[Any]] = {"count_response": None, "exposure_response": None}
errors = []

def run(query_runner: TrendsQueryRunner, is_parallel: bool):
def run(query_runner: TrendsQueryRunner, result_key: str, is_parallel: bool):
try:
return query_runner.calculate()
result = query_runner.calculate()
shared_results[result_key] = result
except Exception as e:
errors.append(e)
finally:
Expand All @@ -46,23 +211,23 @@ def run(query_runner: TrendsQueryRunner, is_parallel: bool):

# This exists so that we're not spawning threads during unit tests
if settings.IN_UNIT_TESTING:
count_response = run(self.query_runner, False)
exposure_response = run(self.exposure_query_runner, False)
run(self.count_query_runner, "count_response", False)
run(self.exposure_query_runner, "exposure_response", False)
else:
jobs = [
threading.Thread(target=run, args=(self.query_runner, True)),
threading.Thread(target=run, args=(self.exposure_query_runner, True)),
threading.Thread(target=run, args=(self.count_query_runner, "count_response", True)),
threading.Thread(target=run, args=(self.exposure_query_runner, "exposure_response", True)),
]
[j.start() for j in jobs] # type: ignore
[j.join() for j in jobs] # type: ignore

count_response = getattr(jobs[0], "result", None)
exposure_response = getattr(jobs[1], "result", None)

# Raise any errors raised in a separate thread
if len(errors) > 0:
if errors:
raise errors[0]

count_response = shared_results["count_response"]
exposure_response = shared_results["exposure_response"]

if count_response is None or exposure_response is None:
raise ValueError("One or both query runners failed to produce a response")

Expand All @@ -88,4 +253,4 @@ def _process_results(
return processed_results

def to_query(self) -> ast.SelectQuery:
raise ValueError(f"Cannot convert source query of type {self.query.count_source.kind} to query")
raise ValueError(f"Cannot convert source query of type {self.query.count_query.kind} to query")
Loading

0 comments on commit fafd152

Please sign in to comment.