Skip to content

Commit

Permalink
feat(hogql): use join for "in cohort" operations instead of subquery (P…
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra authored and Justicea83 committed Oct 25, 2023
1 parent 3f9638b commit 73e2f0a
Show file tree
Hide file tree
Showing 15 changed files with 317 additions and 25 deletions.
4 changes: 4 additions & 0 deletions frontend/src/queries/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,10 @@
"additionalProperties": false,
"description": "HogQL Query Options are automatically set per team. However, they can be overriden in the query.",
"properties": {
"inCohortVia": {
"enum": ["leftjoin", "subquery"],
"type": "string"
},
"personsArgMaxVersion": {
"enum": ["auto", "v1", "v2"],
"type": "string"
Expand Down
1 change: 1 addition & 0 deletions frontend/src/queries/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ export interface DataNode extends Node {
export interface HogQLQueryModifiers {
personsOnEventsMode?: 'disabled' | 'v1_enabled' | 'v2_enabled'
personsArgMaxVersion?: 'auto' | 'v1' | 'v2'
inCohortVia?: 'leftjoin' | 'subquery'
}

export interface HogQLQueryResponse {
Expand Down
20 changes: 18 additions & 2 deletions frontend/src/scenes/debug/HogQLDebug.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export function HogQLDebug({ query, setQuery, queryKey }: HogQLDebugProps): JSX.
</div>
<div className="flex gap-2">
<LemonLabel>
POE Version:
POE:
<LemonSelect
options={[
{ value: 'disabled', label: 'Disabled' },
Expand All @@ -49,7 +49,7 @@ export function HogQLDebug({ query, setQuery, queryKey }: HogQLDebugProps): JSX.
/>
</LemonLabel>
<LemonLabel>
Persons ArgMax Version
Persons ArgMax:
<LemonSelect
options={[
{ value: 'v1', label: 'V1' },
Expand All @@ -64,6 +64,22 @@ export function HogQLDebug({ query, setQuery, queryKey }: HogQLDebugProps): JSX.
value={query.modifiers?.personsArgMaxVersion ?? response?.modifiers?.personsArgMaxVersion}
/>
</LemonLabel>
<LemonLabel>
In Cohort Via:
<LemonSelect
options={[
{ value: 'join', label: 'join' },
{ value: 'subquery', label: 'subquery' },
]}
onChange={(value) =>
setQuery({
...query,
modifiers: { ...query.modifiers, inCohortVia: value },
} as HogQLQuery)
}
value={query.modifiers?.inCohortVia ?? response?.modifiers?.inCohortVia}
/>
</LemonLabel>{' '}
</div>
{dataLoading ? (
<>
Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql/database/schema/cohort_people.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def to_printed_clickhouse(self, context):
return "cohortpeople"

def to_printed_hogql(self):
return "cohort_people"
return "raw_cohort_people"


class CohortPeople(LazyTable):
Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql/database/schema/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def to_printed_clickhouse(self, context):
return "groups"

def to_printed_hogql(self):
return "groups"
return "raw_groups"


class GroupsTable(LazyTable):
Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql/database/schema/person_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def to_printed_clickhouse(self, context):
return "person_overrides"

def to_printed_hogql(self):
return "person_overrides"
return "raw_person_overrides"


class PersonOverridesTable(Table):
Expand Down
6 changes: 5 additions & 1 deletion posthog/hogql/functions/test/test_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from posthog.models import Cohort
from posthog.models.cohort.util import recalculate_cohortpeople
from posthog.models.utils import UUIDT
from posthog.schema import HogQLQueryModifiers
from posthog.test.base import BaseTest, _create_person, _create_event, flush_persons_and_events

elements_chain_match = lambda x: parse_expr("match(elements_chain, {regex})", {"regex": ast.Constant(value=str(x))})
Expand Down Expand Up @@ -38,14 +39,15 @@ def test_in_cohort_dynamic(self):
response = execute_hogql_query(
f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk} AND event='{random_uuid}'",
self.team,
modifiers=HogQLQueryModifiers(inCohortVia="subquery"),
)
self.assertEqual(
response.clickhouse,
f"SELECT events.event FROM events WHERE and(equals(events.team_id, {self.team.pk}), in(events.person_id, (SELECT cohortpeople.person_id FROM cohortpeople WHERE and(equals(cohortpeople.team_id, {self.team.pk}), equals(cohortpeople.cohort_id, {cohort.pk})) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), equals(events.event, %(hogql_val_0)s)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1",
)
self.assertEqual(
response.hogql,
f"SELECT event FROM events WHERE and(in(person_id, (SELECT person_id FROM cohort_people WHERE equals(cohort_id, {cohort.pk}) GROUP BY person_id, cohort_id, version HAVING greater(sum(sign), 0))), equals(event, '{random_uuid}')) LIMIT 100",
f"SELECT event FROM events WHERE and(in(person_id, (SELECT person_id FROM raw_cohort_people WHERE equals(cohort_id, {cohort.pk}) GROUP BY person_id, cohort_id, version HAVING greater(sum(sign), 0))), equals(event, '{random_uuid}')) LIMIT 100",
)
self.assertEqual(len(response.results), 1)
self.assertEqual(response.results[0][0], random_uuid)
Expand All @@ -59,6 +61,7 @@ def test_in_cohort_static(self):
response = execute_hogql_query(
f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk}",
self.team,
modifiers=HogQLQueryModifiers(inCohortVia="subquery"),
)
self.assertEqual(
response.clickhouse,
Expand All @@ -79,6 +82,7 @@ def test_in_cohort_strings(self):
response = execute_hogql_query(
f"SELECT event FROM events WHERE person_id IN COHORT 'my cohort'",
self.team,
modifiers=HogQLQueryModifiers(inCohortVia="subquery"),
)
self.assertEqual(
response.clickhouse,
Expand Down
3 changes: 3 additions & 0 deletions posthog/hogql/modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ def create_default_modifiers_for_team(
if modifiers.personsArgMaxVersion is None:
modifiers.personsArgMaxVersion = "auto"

if modifiers.inCohortVia is None:
modifiers.inCohortVia = "subquery"

return modifiers
4 changes: 2 additions & 2 deletions posthog/hogql/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def parse_expr(
node = RULE_TO_PARSE_FUNCTION[backend]["expr"](expr, start)
if placeholders:
with timings.measure("replace_placeholders"):
return replace_placeholders(node, placeholders)
node = replace_placeholders(node, placeholders)
return node


Expand All @@ -63,7 +63,7 @@ def parse_order_expr(
node = RULE_TO_PARSE_FUNCTION[backend]["order_expr"](order_expr)
if placeholders:
with timings.measure("replace_placeholders"):
return replace_placeholders(node, placeholders)
node = replace_placeholders(node, placeholders)
return node


Expand Down
6 changes: 5 additions & 1 deletion posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from posthog.hogql.functions.mapping import validate_function_args
from posthog.hogql.resolver import ResolverException, lookup_field_by_name, resolve_types
from posthog.hogql.transforms.in_cohort import resolve_in_cohorts
from posthog.hogql.transforms.lazy_tables import resolve_lazy_tables
from posthog.hogql.transforms.property_types import resolve_property_types
from posthog.hogql.visitor import Visitor, clone_expr
Expand Down Expand Up @@ -83,6 +84,9 @@ def prepare_ast_for_printing(

with context.timings.measure("resolve_types"):
node = resolve_types(node, context, scopes=[node.type for node in stack] if stack else None)
if context.modifiers.inCohortVia == "leftjoin":
with context.timings.measure("resolve_in_cohorts"):
resolve_in_cohorts(node, stack, context)
if dialect == "clickhouse":
with context.timings.measure("resolve_property_types"):
node = resolve_property_types(node, context)
Expand Down Expand Up @@ -489,7 +493,7 @@ def visit_compare_operation(self, node: ast.CompareOperation):
lambda left_op, right_op: left_op <= right_op if left_op is not None and right_op is not None else False
)
else:
raise HogQLException(f"Unknown CompareOperationOp: {type(node.op).__name__}")
raise HogQLException(f"Unknown CompareOperationOp: {node.op.name}")

# Try to see if we can take shortcuts

Expand Down
32 changes: 16 additions & 16 deletions posthog/hogql/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

from posthog.hogql import ast
from posthog.hogql.ast import FieldTraverserType, ConstantType
from posthog.hogql.functions import HOGQL_POSTHOG_FUNCTIONS
from posthog.hogql.functions import HOGQL_POSTHOG_FUNCTIONS, cohort
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.models import StringJSONDatabaseField, FunctionCallTable, LazyTable, SavedQuery
from posthog.hogql.errors import ResolverException
from posthog.hogql.functions.cohort import cohort
from posthog.hogql.functions.mapping import validate_function_args
from posthog.hogql.functions.sparkline import sparkline
from posthog.hogql.parser import parse_select
Expand Down Expand Up @@ -508,22 +507,23 @@ def visit_compare_operation(self, node: ast.CompareOperation):
)
)

if node.op == ast.CompareOperationOp.InCohort:
return self.visit(
ast.CompareOperation(
op=ast.CompareOperationOp.In,
left=node.left,
right=cohort(node=node.right, args=[node.right], context=self.context),
if self.context.modifiers.inCohortVia != "leftjoin":
if node.op == ast.CompareOperationOp.InCohort:
return self.visit(
ast.CompareOperation(
op=ast.CompareOperationOp.In,
left=node.left,
right=cohort(node=node.right, args=[node.right], context=self.context),
)
)
)
elif node.op == ast.CompareOperationOp.NotInCohort:
return self.visit(
ast.CompareOperation(
op=ast.CompareOperationOp.NotIn,
left=node.left,
right=cohort(node=node.right, args=[node.right], context=self.context),
elif node.op == ast.CompareOperationOp.NotInCohort:
return self.visit(
ast.CompareOperation(
op=ast.CompareOperationOp.NotIn,
left=node.left,
right=cohort(node=node.right, args=[node.right], context=self.context),
)
)
)

node = super().visit_compare_operation(node)
node.type = ast.BooleanType()
Expand Down
18 changes: 18 additions & 0 deletions posthog/hogql/test/test_modifiers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from posthog.hogql.modifiers import create_default_modifiers_for_team
from posthog.hogql.query import execute_hogql_query
from posthog.models import Cohort
from posthog.schema import HogQLQueryModifiers
from posthog.test.base import BaseTest
from django.test import override_settings
Expand Down Expand Up @@ -71,3 +72,20 @@ def test_modifiers_persons_argmax_version_auto(self):
modifiers=HogQLQueryModifiers(personsArgMaxVersion="auto"),
)
assert "in(tuple(person.id, person.version)" not in response.clickhouse

def test_modifiers_in_cohort_join(self):
cohort = Cohort.objects.create(team=self.team, name="test")
response = execute_hogql_query(
f"select * from persons where id in cohort {cohort.pk}",
team=self.team,
modifiers=HogQLQueryModifiers(inCohortVia="subquery"),
)
assert "LEFT JOIN" not in response.clickhouse

# Use the v1 query when not selecting any properties
response = execute_hogql_query(
f"select * from persons where id in cohort {cohort.pk}",
team=self.team,
modifiers=HogQLQueryModifiers(inCohortVia="leftjoin"),
)
assert "LEFT JOIN" in response.clickhouse
132 changes: 132 additions & 0 deletions posthog/hogql/transforms/in_cohort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from typing import List, Optional, cast

from posthog.hogql import ast
from posthog.hogql.context import HogQLContext
from posthog.hogql.errors import HogQLException
from posthog.hogql.escape_sql import escape_clickhouse_string
from posthog.hogql.parser import parse_expr
from posthog.hogql.resolver import resolve_types
from posthog.hogql.visitor import TraversingVisitor, clone_expr


def resolve_in_cohorts(node: ast.Expr, stack: Optional[List[ast.SelectQuery]] = None, context: HogQLContext = None):
InCohortResolver(stack=stack, context=context).visit(node)


class InCohortResolver(TraversingVisitor):
def __init__(self, stack: Optional[List[ast.SelectQuery]] = None, context: HogQLContext = None):
super().__init__()
self.stack: List[ast.SelectQuery] = stack or []
self.context = context

def visit_select_query(self, node: ast.SelectQuery):
self.stack.append(node)
super().visit_select_query(node)
self.stack.pop()

def visit_compare_operation(self, node: ast.CompareOperation):
if node.op == ast.CompareOperationOp.InCohort or node.op == ast.CompareOperationOp.NotInCohort:
arg = node.right
if not isinstance(arg, ast.Constant):
raise HogQLException("IN COHORT only works with constant arguments", node=arg)

from posthog.models import Cohort

if isinstance(arg.value, int) and not isinstance(arg.value, bool):
cohorts = Cohort.objects.filter(id=arg.value, team_id=self.context.team_id).values_list(
"id", "is_static", "name"
)
if len(cohorts) == 1:
self.context.add_notice(
start=arg.start,
end=arg.end,
message=f"Cohort #{cohorts[0][0]} can also be specified as {escape_clickhouse_string(cohorts[0][2])}",
fix=escape_clickhouse_string(cohorts[0][2]),
)
self._add_join_for_cohort(
cohort_id=cohorts[0][0],
is_static=cohorts[0][1],
compare=node,
select=self.stack[-1],
negative=node.op == ast.CompareOperationOp.NotInCohort,
)
return
raise HogQLException(f"Could not find cohort with id {arg.value}", node=arg)

if isinstance(arg.value, str):
cohorts = Cohort.objects.filter(name=arg.value, team_id=self.context.team_id).values_list(
"id", "is_static"
)
if len(cohorts) == 1:
self.context.add_notice(
start=arg.start,
end=arg.end,
message=f"Searching for cohort by name. Replace with numeric ID {cohorts[0][0]} to protect against renaming.",
fix=str(cohorts[0][0]),
)
self._add_join_for_cohort(
cohort_id=cohorts[0][0],
is_static=cohorts[0][1],
compare=node,
select=self.stack[-1],
negative=node.op == ast.CompareOperationOp.NotInCohort,
)
return
elif len(cohorts) > 1:
raise HogQLException(f"Found multiple cohorts with name '{arg.value}'", node=arg)
raise HogQLException(f"Could not find a cohort with the name '{arg.value}'", node=arg)
else:
self.visit(node.left)
self.visit(node.right)

def _add_join_for_cohort(
self, cohort_id: int, is_static: bool, select: ast.SelectQuery, compare: ast.CompareOperation, negative: bool
):
must_add_join = True
last_join = select.select_from
while last_join:
if isinstance(last_join.table, ast.Field) and last_join.table.chain[0] == f"in_cohort__{cohort_id}":
must_add_join = False
break
if last_join.next_join:
last_join = last_join.next_join
else:
break

if must_add_join:
if is_static:
sql = "(SELECT person_id, 1 as matched FROM static_cohort_people WHERE cohort_id = {cohort_id})"
else:
sql = "(SELECT person_id, 1 as matched FROM raw_cohort_people WHERE cohort_id = {cohort_id} GROUP BY person_id, cohort_id, version HAVING sum(sign) > 0)"
subquery = parse_expr(
sql, {"cohort_id": ast.Constant(value=cohort_id)}, start=None
) # clear the source start position

new_join = ast.JoinExpr(
alias=f"in_cohort__{cohort_id}",
table=subquery,
join_type="LEFT JOIN",
next_join=None,
constraint=ast.JoinConstraint(
expr=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Constant(value=1),
right=ast.Constant(value=1),
)
),
)
new_join = cast(ast.JoinExpr, resolve_types(new_join, self.context, [self.stack[-1].type]))
new_join.constraint.expr.left = resolve_types(
ast.Field(chain=[f"in_cohort__{cohort_id}", "person_id"]), self.context, [self.stack[-1].type]
)
new_join.constraint.expr.right = clone_expr(compare.left)
if last_join:
last_join.next_join = new_join
else:
select.select_from = new_join

compare.op = ast.CompareOperationOp.NotEq if negative else ast.CompareOperationOp.Eq
compare.left = resolve_types(
ast.Field(chain=[f"in_cohort__{cohort_id}", "matched"]), self.context, [self.stack[-1].type]
)
compare.right = resolve_types(ast.Constant(value=1), self.context, [self.stack[-1].type])
Loading

0 comments on commit 73e2f0a

Please sign in to comment.