diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json
index 7dd0764bc9cb2..35d5d5e6949eb 100644
--- a/frontend/src/queries/schema.json
+++ b/frontend/src/queries/schema.json
@@ -1235,6 +1235,10 @@
"additionalProperties": false,
"description": "HogQL Query Options are automatically set per team. However, they can be overriden in the query.",
"properties": {
+ "inCohortVia": {
+ "enum": ["leftjoin", "subquery"],
+ "type": "string"
+ },
"personsArgMaxVersion": {
"enum": ["auto", "v1", "v2"],
"type": "string"
diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts
index 215aad0f7a0c2..e16ccceb9b0db 100644
--- a/frontend/src/queries/schema.ts
+++ b/frontend/src/queries/schema.ts
@@ -133,6 +133,7 @@ export interface DataNode extends Node {
export interface HogQLQueryModifiers {
personsOnEventsMode?: 'disabled' | 'v1_enabled' | 'v2_enabled'
personsArgMaxVersion?: 'auto' | 'v1' | 'v2'
+ inCohortVia?: 'leftjoin' | 'subquery'
}
export interface HogQLQueryResponse {
diff --git a/frontend/src/scenes/debug/HogQLDebug.tsx b/frontend/src/scenes/debug/HogQLDebug.tsx
index d143a2196327f..d1d32f98643b1 100644
--- a/frontend/src/scenes/debug/HogQLDebug.tsx
+++ b/frontend/src/scenes/debug/HogQLDebug.tsx
@@ -32,7 +32,7 @@ export function HogQLDebug({ query, setQuery, queryKey }: HogQLDebugProps): JSX.
- POE Version:
+ POE:
- Persons ArgMax Version
+ Persons ArgMax:
+
+ In Cohort Via:
+
+ setQuery({
+ ...query,
+ modifiers: { ...query.modifiers, inCohortVia: value },
+ } as HogQLQuery)
+ }
+ value={query.modifiers?.inCohortVia ?? response?.modifiers?.inCohortVia}
+ />
+ {' '}
{dataLoading ? (
<>
diff --git a/posthog/hogql/database/schema/cohort_people.py b/posthog/hogql/database/schema/cohort_people.py
index 7aa94704e2c96..097e74856f410 100644
--- a/posthog/hogql/database/schema/cohort_people.py
+++ b/posthog/hogql/database/schema/cohort_people.py
@@ -51,7 +51,7 @@ def to_printed_clickhouse(self, context):
return "cohortpeople"
def to_printed_hogql(self):
- return "cohort_people"
+ return "raw_cohort_people"
class CohortPeople(LazyTable):
diff --git a/posthog/hogql/database/schema/groups.py b/posthog/hogql/database/schema/groups.py
index 0619bf1b5ad3d..9b3fc1f28c176 100644
--- a/posthog/hogql/database/schema/groups.py
+++ b/posthog/hogql/database/schema/groups.py
@@ -69,7 +69,7 @@ def to_printed_clickhouse(self, context):
return "groups"
def to_printed_hogql(self):
- return "groups"
+ return "raw_groups"
class GroupsTable(LazyTable):
diff --git a/posthog/hogql/database/schema/person_overrides.py b/posthog/hogql/database/schema/person_overrides.py
index c4576d0a58b83..9e2e92656867c 100644
--- a/posthog/hogql/database/schema/person_overrides.py
+++ b/posthog/hogql/database/schema/person_overrides.py
@@ -62,7 +62,7 @@ def to_printed_clickhouse(self, context):
return "person_overrides"
def to_printed_hogql(self):
- return "person_overrides"
+ return "raw_person_overrides"
class PersonOverridesTable(Table):
diff --git a/posthog/hogql/functions/test/test_cohort.py b/posthog/hogql/functions/test/test_cohort.py
index fad1bed1dfc86..c9adaffbba8a0 100644
--- a/posthog/hogql/functions/test/test_cohort.py
+++ b/posthog/hogql/functions/test/test_cohort.py
@@ -7,6 +7,7 @@
from posthog.models import Cohort
from posthog.models.cohort.util import recalculate_cohortpeople
from posthog.models.utils import UUIDT
+from posthog.schema import HogQLQueryModifiers
from posthog.test.base import BaseTest, _create_person, _create_event, flush_persons_and_events
elements_chain_match = lambda x: parse_expr("match(elements_chain, {regex})", {"regex": ast.Constant(value=str(x))})
@@ -38,6 +39,7 @@ def test_in_cohort_dynamic(self):
response = execute_hogql_query(
f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk} AND event='{random_uuid}'",
self.team,
+ modifiers=HogQLQueryModifiers(inCohortVia="subquery"),
)
self.assertEqual(
response.clickhouse,
@@ -45,7 +47,7 @@ def test_in_cohort_dynamic(self):
)
self.assertEqual(
response.hogql,
- f"SELECT event FROM events WHERE and(in(person_id, (SELECT person_id FROM cohort_people WHERE equals(cohort_id, {cohort.pk}) GROUP BY person_id, cohort_id, version HAVING greater(sum(sign), 0))), equals(event, '{random_uuid}')) LIMIT 100",
+ f"SELECT event FROM events WHERE and(in(person_id, (SELECT person_id FROM raw_cohort_people WHERE equals(cohort_id, {cohort.pk}) GROUP BY person_id, cohort_id, version HAVING greater(sum(sign), 0))), equals(event, '{random_uuid}')) LIMIT 100",
)
self.assertEqual(len(response.results), 1)
self.assertEqual(response.results[0][0], random_uuid)
@@ -59,6 +61,7 @@ def test_in_cohort_static(self):
response = execute_hogql_query(
f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk}",
self.team,
+ modifiers=HogQLQueryModifiers(inCohortVia="subquery"),
)
self.assertEqual(
response.clickhouse,
@@ -79,6 +82,7 @@ def test_in_cohort_strings(self):
response = execute_hogql_query(
f"SELECT event FROM events WHERE person_id IN COHORT 'my cohort'",
self.team,
+ modifiers=HogQLQueryModifiers(inCohortVia="subquery"),
)
self.assertEqual(
response.clickhouse,
diff --git a/posthog/hogql/modifiers.py b/posthog/hogql/modifiers.py
index 3f3cd86b5f8f0..0643deefcc6fa 100644
--- a/posthog/hogql/modifiers.py
+++ b/posthog/hogql/modifiers.py
@@ -19,4 +19,7 @@ def create_default_modifiers_for_team(
if modifiers.personsArgMaxVersion is None:
modifiers.personsArgMaxVersion = "auto"
+ if modifiers.inCohortVia is None:
+ modifiers.inCohortVia = "subquery"
+
return modifiers
diff --git a/posthog/hogql/parser.py b/posthog/hogql/parser.py
index deb5799620937..3624127a9ea62 100644
--- a/posthog/hogql/parser.py
+++ b/posthog/hogql/parser.py
@@ -46,7 +46,7 @@ def parse_expr(
node = RULE_TO_PARSE_FUNCTION[backend]["expr"](expr, start)
if placeholders:
with timings.measure("replace_placeholders"):
- return replace_placeholders(node, placeholders)
+ node = replace_placeholders(node, placeholders)
return node
@@ -63,7 +63,7 @@ def parse_order_expr(
node = RULE_TO_PARSE_FUNCTION[backend]["order_expr"](order_expr)
if placeholders:
with timings.measure("replace_placeholders"):
- return replace_placeholders(node, placeholders)
+ node = replace_placeholders(node, placeholders)
return node
diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py
index 69fe54f6ad913..800b649760285 100644
--- a/posthog/hogql/printer.py
+++ b/posthog/hogql/printer.py
@@ -31,6 +31,7 @@
)
from posthog.hogql.functions.mapping import validate_function_args
from posthog.hogql.resolver import ResolverException, lookup_field_by_name, resolve_types
+from posthog.hogql.transforms.in_cohort import resolve_in_cohorts
from posthog.hogql.transforms.lazy_tables import resolve_lazy_tables
from posthog.hogql.transforms.property_types import resolve_property_types
from posthog.hogql.visitor import Visitor, clone_expr
@@ -83,6 +84,9 @@ def prepare_ast_for_printing(
with context.timings.measure("resolve_types"):
node = resolve_types(node, context, scopes=[node.type for node in stack] if stack else None)
+ if context.modifiers.inCohortVia == "leftjoin":
+ with context.timings.measure("resolve_in_cohorts"):
+ resolve_in_cohorts(node, stack, context)
if dialect == "clickhouse":
with context.timings.measure("resolve_property_types"):
node = resolve_property_types(node, context)
@@ -489,7 +493,7 @@ def visit_compare_operation(self, node: ast.CompareOperation):
lambda left_op, right_op: left_op <= right_op if left_op is not None and right_op is not None else False
)
else:
- raise HogQLException(f"Unknown CompareOperationOp: {type(node.op).__name__}")
+ raise HogQLException(f"Unknown CompareOperationOp: {node.op.name}")
# Try to see if we can take shortcuts
diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py
index 6df62049c8cd1..652229a41d010 100644
--- a/posthog/hogql/resolver.py
+++ b/posthog/hogql/resolver.py
@@ -4,11 +4,10 @@
from posthog.hogql import ast
from posthog.hogql.ast import FieldTraverserType, ConstantType
-from posthog.hogql.functions import HOGQL_POSTHOG_FUNCTIONS
+from posthog.hogql.functions import HOGQL_POSTHOG_FUNCTIONS, cohort
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.models import StringJSONDatabaseField, FunctionCallTable, LazyTable, SavedQuery
from posthog.hogql.errors import ResolverException
-from posthog.hogql.functions.cohort import cohort
from posthog.hogql.functions.mapping import validate_function_args
from posthog.hogql.functions.sparkline import sparkline
from posthog.hogql.parser import parse_select
@@ -508,22 +507,23 @@ def visit_compare_operation(self, node: ast.CompareOperation):
)
)
- if node.op == ast.CompareOperationOp.InCohort:
- return self.visit(
- ast.CompareOperation(
- op=ast.CompareOperationOp.In,
- left=node.left,
- right=cohort(node=node.right, args=[node.right], context=self.context),
+ if self.context.modifiers.inCohortVia != "leftjoin":
+ if node.op == ast.CompareOperationOp.InCohort:
+ return self.visit(
+ ast.CompareOperation(
+ op=ast.CompareOperationOp.In,
+ left=node.left,
+ right=cohort(node=node.right, args=[node.right], context=self.context),
+ )
)
- )
- elif node.op == ast.CompareOperationOp.NotInCohort:
- return self.visit(
- ast.CompareOperation(
- op=ast.CompareOperationOp.NotIn,
- left=node.left,
- right=cohort(node=node.right, args=[node.right], context=self.context),
+ elif node.op == ast.CompareOperationOp.NotInCohort:
+ return self.visit(
+ ast.CompareOperation(
+ op=ast.CompareOperationOp.NotIn,
+ left=node.left,
+ right=cohort(node=node.right, args=[node.right], context=self.context),
+ )
)
- )
node = super().visit_compare_operation(node)
node.type = ast.BooleanType()
diff --git a/posthog/hogql/test/test_modifiers.py b/posthog/hogql/test/test_modifiers.py
index d6d0f0e64d101..a4a674a801ff7 100644
--- a/posthog/hogql/test/test_modifiers.py
+++ b/posthog/hogql/test/test_modifiers.py
@@ -1,5 +1,6 @@
from posthog.hogql.modifiers import create_default_modifiers_for_team
from posthog.hogql.query import execute_hogql_query
+from posthog.models import Cohort
from posthog.schema import HogQLQueryModifiers
from posthog.test.base import BaseTest
from django.test import override_settings
@@ -71,3 +72,20 @@ def test_modifiers_persons_argmax_version_auto(self):
modifiers=HogQLQueryModifiers(personsArgMaxVersion="auto"),
)
assert "in(tuple(person.id, person.version)" not in response.clickhouse
+
+ def test_modifiers_in_cohort_join(self):
+ cohort = Cohort.objects.create(team=self.team, name="test")
+ response = execute_hogql_query(
+ f"select * from persons where id in cohort {cohort.pk}",
+ team=self.team,
+ modifiers=HogQLQueryModifiers(inCohortVia="subquery"),
+ )
+ assert "LEFT JOIN" not in response.clickhouse
+
+ # Use the v1 query when not selecting any properties
+ response = execute_hogql_query(
+ f"select * from persons where id in cohort {cohort.pk}",
+ team=self.team,
+ modifiers=HogQLQueryModifiers(inCohortVia="leftjoin"),
+ )
+ assert "LEFT JOIN" in response.clickhouse
diff --git a/posthog/hogql/transforms/in_cohort.py b/posthog/hogql/transforms/in_cohort.py
new file mode 100644
index 0000000000000..aa1fe0e3a23ee
--- /dev/null
+++ b/posthog/hogql/transforms/in_cohort.py
@@ -0,0 +1,132 @@
+from typing import List, Optional, cast
+
+from posthog.hogql import ast
+from posthog.hogql.context import HogQLContext
+from posthog.hogql.errors import HogQLException
+from posthog.hogql.escape_sql import escape_clickhouse_string
+from posthog.hogql.parser import parse_expr
+from posthog.hogql.resolver import resolve_types
+from posthog.hogql.visitor import TraversingVisitor, clone_expr
+
+
+def resolve_in_cohorts(node: ast.Expr, stack: Optional[List[ast.SelectQuery]] = None, context: HogQLContext = None):
+ InCohortResolver(stack=stack, context=context).visit(node)
+
+
+class InCohortResolver(TraversingVisitor):
+ def __init__(self, stack: Optional[List[ast.SelectQuery]] = None, context: HogQLContext = None):
+ super().__init__()
+ self.stack: List[ast.SelectQuery] = stack or []
+ self.context = context
+
+ def visit_select_query(self, node: ast.SelectQuery):
+ self.stack.append(node)
+ super().visit_select_query(node)
+ self.stack.pop()
+
+ def visit_compare_operation(self, node: ast.CompareOperation):
+ if node.op == ast.CompareOperationOp.InCohort or node.op == ast.CompareOperationOp.NotInCohort:
+ arg = node.right
+ if not isinstance(arg, ast.Constant):
+ raise HogQLException("IN COHORT only works with constant arguments", node=arg)
+
+ from posthog.models import Cohort
+
+ if isinstance(arg.value, int) and not isinstance(arg.value, bool):
+ cohorts = Cohort.objects.filter(id=arg.value, team_id=self.context.team_id).values_list(
+ "id", "is_static", "name"
+ )
+ if len(cohorts) == 1:
+ self.context.add_notice(
+ start=arg.start,
+ end=arg.end,
+ message=f"Cohort #{cohorts[0][0]} can also be specified as {escape_clickhouse_string(cohorts[0][2])}",
+ fix=escape_clickhouse_string(cohorts[0][2]),
+ )
+ self._add_join_for_cohort(
+ cohort_id=cohorts[0][0],
+ is_static=cohorts[0][1],
+ compare=node,
+ select=self.stack[-1],
+ negative=node.op == ast.CompareOperationOp.NotInCohort,
+ )
+ return
+ raise HogQLException(f"Could not find cohort with id {arg.value}", node=arg)
+
+ if isinstance(arg.value, str):
+ cohorts = Cohort.objects.filter(name=arg.value, team_id=self.context.team_id).values_list(
+ "id", "is_static"
+ )
+ if len(cohorts) == 1:
+ self.context.add_notice(
+ start=arg.start,
+ end=arg.end,
+ message=f"Searching for cohort by name. Replace with numeric ID {cohorts[0][0]} to protect against renaming.",
+ fix=str(cohorts[0][0]),
+ )
+ self._add_join_for_cohort(
+ cohort_id=cohorts[0][0],
+ is_static=cohorts[0][1],
+ compare=node,
+ select=self.stack[-1],
+ negative=node.op == ast.CompareOperationOp.NotInCohort,
+ )
+ return
+ elif len(cohorts) > 1:
+ raise HogQLException(f"Found multiple cohorts with name '{arg.value}'", node=arg)
+ raise HogQLException(f"Could not find a cohort with the name '{arg.value}'", node=arg)
+ else:
+ self.visit(node.left)
+ self.visit(node.right)
+
+ def _add_join_for_cohort(
+ self, cohort_id: int, is_static: bool, select: ast.SelectQuery, compare: ast.CompareOperation, negative: bool
+ ):
+ must_add_join = True
+ last_join = select.select_from
+ while last_join:
+ if isinstance(last_join.table, ast.Field) and last_join.table.chain[0] == f"in_cohort__{cohort_id}":
+ must_add_join = False
+ break
+ if last_join.next_join:
+ last_join = last_join.next_join
+ else:
+ break
+
+ if must_add_join:
+ if is_static:
+ sql = "(SELECT person_id, 1 as matched FROM static_cohort_people WHERE cohort_id = {cohort_id})"
+ else:
+ sql = "(SELECT person_id, 1 as matched FROM raw_cohort_people WHERE cohort_id = {cohort_id} GROUP BY person_id, cohort_id, version HAVING sum(sign) > 0)"
+ subquery = parse_expr(
+ sql, {"cohort_id": ast.Constant(value=cohort_id)}, start=None
+ ) # clear the source start position
+
+ new_join = ast.JoinExpr(
+ alias=f"in_cohort__{cohort_id}",
+ table=subquery,
+ join_type="LEFT JOIN",
+ next_join=None,
+ constraint=ast.JoinConstraint(
+ expr=ast.CompareOperation(
+ op=ast.CompareOperationOp.Eq,
+ left=ast.Constant(value=1),
+ right=ast.Constant(value=1),
+ )
+ ),
+ )
+ new_join = cast(ast.JoinExpr, resolve_types(new_join, self.context, [self.stack[-1].type]))
+ new_join.constraint.expr.left = resolve_types(
+ ast.Field(chain=[f"in_cohort__{cohort_id}", "person_id"]), self.context, [self.stack[-1].type]
+ )
+ new_join.constraint.expr.right = clone_expr(compare.left)
+ if last_join:
+ last_join.next_join = new_join
+ else:
+ select.select_from = new_join
+
+ compare.op = ast.CompareOperationOp.NotEq if negative else ast.CompareOperationOp.Eq
+ compare.left = resolve_types(
+ ast.Field(chain=[f"in_cohort__{cohort_id}", "matched"]), self.context, [self.stack[-1].type]
+ )
+ compare.right = resolve_types(ast.Constant(value=1), self.context, [self.stack[-1].type])
diff --git a/posthog/hogql/transforms/test/test_in_cohort.py b/posthog/hogql/transforms/test/test_in_cohort.py
new file mode 100644
index 0000000000000..dbef0b685aadf
--- /dev/null
+++ b/posthog/hogql/transforms/test/test_in_cohort.py
@@ -0,0 +1,104 @@
+from django.test import override_settings
+
+from posthog.hogql import ast
+from posthog.hogql.errors import HogQLException
+from posthog.hogql.parser import parse_expr
+from posthog.hogql.query import execute_hogql_query
+from posthog.models import Cohort
+from posthog.models.cohort.util import recalculate_cohortpeople
+from posthog.models.utils import UUIDT
+from posthog.schema import HogQLQueryModifiers
+from posthog.test.base import BaseTest, _create_person, _create_event, flush_persons_and_events
+
+elements_chain_match = lambda x: parse_expr("match(elements_chain, {regex})", {"regex": ast.Constant(value=str(x))})
+not_call = lambda x: ast.Call(name="not", args=[x])
+
+
+class TestInCohort(BaseTest):
+ maxDiff = None
+
+ def _create_random_events(self) -> str:
+ random_uuid = str(UUIDT())
+ _create_person(
+ properties={"$os": "Chrome", "random_uuid": random_uuid},
+ team=self.team,
+ distinct_ids=["bla"],
+ is_identified=True,
+ )
+ _create_event(distinct_id="bla", event=random_uuid, team=self.team)
+ flush_persons_and_events()
+ return random_uuid
+
+ @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=True)
+ def test_in_cohort_dynamic(self):
+ random_uuid = self._create_random_events()
+ cohort = Cohort.objects.create(
+ team=self.team, groups=[{"properties": [{"key": "$os", "value": "Chrome", "type": "person"}]}]
+ )
+ recalculate_cohortpeople(cohort, pending_version=0)
+ response = execute_hogql_query(
+ f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk} AND event='{random_uuid}'",
+ self.team,
+ modifiers=HogQLQueryModifiers(inCohortVia="leftjoin"),
+ )
+ self.assertEqual(
+ response.clickhouse,
+ f"SELECT events.event FROM events LEFT JOIN (SELECT cohortpeople.person_id, 1 AS matched FROM cohortpeople WHERE and(equals(cohortpeople.team_id, {self.team.pk}), equals(cohortpeople.cohort_id, {cohort.pk})) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0)) AS in_cohort__{cohort.pk} ON equals(in_cohort__{cohort.pk}.person_id, events.person_id) WHERE and(equals(events.team_id, {self.team.pk}), ifNull(equals(in_cohort__{cohort.pk}.matched, 1), 0), equals(events.event, %(hogql_val_0)s)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1",
+ )
+ self.assertEqual(
+ response.hogql,
+ f"SELECT event FROM events LEFT JOIN (SELECT person_id, 1 AS matched FROM raw_cohort_people WHERE equals(cohort_id, {cohort.pk}) GROUP BY person_id, cohort_id, version HAVING greater(sum(sign), 0)) AS in_cohort__{cohort.pk} ON equals(in_cohort__{cohort.pk}.person_id, person_id) WHERE and(equals(in_cohort__{cohort.pk}.matched, 1), equals(event, '{random_uuid}')) LIMIT 100",
+ )
+ self.assertEqual(len(response.results), 1)
+ self.assertEqual(response.results[0][0], random_uuid)
+
+ @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=True)
+ def test_in_cohort_static(self):
+ cohort = Cohort.objects.create(
+ team=self.team,
+ is_static=True,
+ )
+ response = execute_hogql_query(
+ f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk}",
+ self.team,
+ modifiers=HogQLQueryModifiers(inCohortVia="leftjoin"),
+ )
+ self.assertEqual(
+ response.clickhouse,
+ f"SELECT events.event FROM events LEFT JOIN (SELECT person_static_cohort.person_id, 1 AS matched FROM person_static_cohort WHERE and(equals(person_static_cohort.team_id, {self.team.pk}), equals(person_static_cohort.cohort_id, {cohort.pk}))) AS in_cohort__{cohort.pk} ON equals(in_cohort__{cohort.pk}.person_id, events.person_id) WHERE and(equals(events.team_id, {self.team.pk}), ifNull(equals(in_cohort__{cohort.pk}.matched, 1), 0)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1",
+ )
+ self.assertEqual(
+ response.hogql,
+ f"SELECT event FROM events LEFT JOIN (SELECT person_id, 1 AS matched FROM static_cohort_people WHERE equals(cohort_id, {cohort.pk})) AS in_cohort__{cohort.pk} ON equals(in_cohort__{cohort.pk}.person_id, person_id) WHERE equals(in_cohort__{cohort.pk}.matched, 1) LIMIT 100",
+ )
+
+ @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=True)
+ def test_in_cohort_strings(self):
+ cohort = Cohort.objects.create(
+ team=self.team,
+ name="my cohort",
+ is_static=True,
+ )
+ response = execute_hogql_query(
+ f"SELECT event FROM events WHERE person_id IN COHORT 'my cohort'",
+ self.team,
+ modifiers=HogQLQueryModifiers(inCohortVia="leftjoin"),
+ )
+ self.assertEqual(
+ response.clickhouse,
+ f"SELECT events.event FROM events LEFT JOIN (SELECT person_static_cohort.person_id, 1 AS matched FROM person_static_cohort WHERE and(equals(person_static_cohort.team_id, {self.team.pk}), equals(person_static_cohort.cohort_id, {cohort.pk}))) AS in_cohort__{cohort.pk} ON equals(in_cohort__{cohort.pk}.person_id, events.person_id) WHERE and(equals(events.team_id, {self.team.pk}), ifNull(equals(in_cohort__{cohort.pk}.matched, 1), 0)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1",
+ )
+ self.assertEqual(
+ response.hogql,
+ f"SELECT event FROM events LEFT JOIN (SELECT person_id, 1 AS matched FROM static_cohort_people WHERE equals(cohort_id, {cohort.pk})) AS in_cohort__{cohort.pk} ON equals(in_cohort__{cohort.pk}.person_id, person_id) WHERE equals(in_cohort__{cohort.pk}.matched, 1) LIMIT 100",
+ )
+
+ @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=True)
+ def test_in_cohort_error(self):
+ with self.assertRaises(HogQLException) as e:
+ execute_hogql_query(f"SELECT event FROM events WHERE person_id IN COHORT true", self.team)
+ self.assertEqual(str(e.exception), "cohort() takes exactly one string or integer argument")
+
+ with self.assertRaises(HogQLException) as e:
+ execute_hogql_query(f"SELECT event FROM events WHERE person_id IN COHORT 'blabla'", self.team)
+ self.assertEqual(str(e.exception), "Could not find a cohort with the name 'blabla'")
diff --git a/posthog/schema.py b/posthog/schema.py
index 6b49b7a7363ea..9d08bd77c7d3c 100644
--- a/posthog/schema.py
+++ b/posthog/schema.py
@@ -228,6 +228,11 @@ class HogQLNotice(BaseModel):
start: Optional[float] = None
+class InCohortVia(str, Enum):
+ leftjoin = "leftjoin"
+ subquery = "subquery"
+
+
class PersonsArgMaxVersion(str, Enum):
auto = "auto"
v1 = "v1"
@@ -244,6 +249,7 @@ class HogQLQueryModifiers(BaseModel):
model_config = ConfigDict(
extra="forbid",
)
+ inCohortVia: Optional[InCohortVia] = None
personsArgMaxVersion: Optional[PersonsArgMaxVersion] = None
personsOnEventsMode: Optional[PersonsOnEventsMode] = None