diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index e41bf1a843171..40df13d525209 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -1709,7 +1709,7 @@ "description": "HogQL Query Options are automatically set per team. However, they can be overriden in the query.", "properties": { "inCohortVia": { - "enum": ["leftjoin", "subquery"], + "enum": ["leftjoin", "subquery", "leftjoin_conjoined"], "type": "string" }, "materializationMode": { diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index d769a559e7394..d705cc839240e 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -161,7 +161,7 @@ export interface DataNode extends Node { export interface HogQLQueryModifiers { personsOnEventsMode?: 'disabled' | 'v1_enabled' | 'v1_mixed' | 'v2_enabled' personsArgMaxVersion?: 'auto' | 'v1' | 'v2' - inCohortVia?: 'leftjoin' | 'subquery' + inCohortVia?: 'leftjoin' | 'subquery' | 'leftjoin_conjoined' materializationMode?: 'auto' | 'legacy_null_as_string' | 'legacy_null_as_null' | 'disabled' } diff --git a/mypy-baseline.txt b/mypy-baseline.txt index aa9cccb3e809a..7747af86b3112 100644 --- a/mypy-baseline.txt +++ b/mypy-baseline.txt @@ -163,7 +163,6 @@ posthog/hogql/transforms/property_types.py:0: error: Statement is unreachable [ posthog/hogql/transforms/property_types.py:0: error: Argument 2 to "_get_materialized_column" of "PropertySwapper" has incompatible type "str | int"; expected "str" [arg-type] posthog/hogql/modifiers.py:0: error: Incompatible types in assignment (expression has type "PersonOnEventsMode", variable has type "PersonsOnEventsMode | None") [assignment] posthog/hogql/modifiers.py:0: error: Incompatible types in assignment (expression has type "str", variable has type "PersonsArgMaxVersion | None") [assignment] -posthog/hogql/modifiers.py:0: error: Incompatible types in assignment (expression has type "str", variable has type "InCohortVia | None") [assignment] posthog/hogql/functions/cohort.py:0: error: Argument 1 to "escape_clickhouse_string" has incompatible type "str | None"; expected "float | int | str | list[Any] | tuple[Any, ...] | date | datetime | UUID | UUIDT" [arg-type] posthog/hogql/functions/cohort.py:0: error: Argument 1 to "escape_clickhouse_string" has incompatible type "str | None"; expected "float | int | str | list[Any] | tuple[Any, ...] | date | datetime | UUID | UUIDT" [arg-type] posthog/hogql/functions/cohort.py:0: error: Incompatible types in assignment (expression has type "ValuesQuerySet[Cohort, tuple[int, bool | None]]", variable has type "ValuesQuerySet[Cohort, tuple[int, bool | None, str | None]]") [assignment] @@ -366,11 +365,6 @@ posthog/hogql/query.py:0: error: Argument 1 to "get_default_limit_for_context" h posthog/hogql/query.py:0: error: "SelectQuery" has no attribute "select_queries" [attr-defined] posthog/hogql/query.py:0: error: Subclass of "SelectQuery" and "SelectUnionQuery" cannot exist: would have incompatible method signatures [unreachable] posthog/hogql_queries/query_runner.py:0: error: Incompatible types in assignment (expression has type "HogQLQuery | TrendsQuery | LifecycleQuery | InsightActorsQuery | EventsQuery | ActorsQuery | RetentionQuery | SessionsTimelineQuery | WebOverviewQuery | WebTopClicksQuery | WebStatsTableQuery | StickinessQuery | BaseModel | dict[str, Any]", variable has type "HogQLQuery | TrendsQuery | LifecycleQuery | InsightActorsQuery | EventsQuery | ActorsQuery | RetentionQuery | SessionsTimelineQuery | WebOverviewQuery | WebTopClicksQuery | WebStatsTableQuery | StickinessQuery") [assignment] -posthog/hogql_queries/insights/trends/breakdown_values.py:0: error: Argument "chain" to "Field" has incompatible type "list[str]"; expected "list[str | int]" [arg-type] -posthog/hogql_queries/insights/trends/breakdown_values.py:0: note: "List" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance -posthog/hogql_queries/insights/trends/breakdown_values.py:0: note: Consider using "Sequence" instead, which is covariant -posthog/hogql_queries/insights/trends/breakdown_values.py:0: error: Argument "breakdown_type" to "get_properties_chain" has incompatible type "str"; expected "Literal['person', 'session', 'group', 'event']" [arg-type] -posthog/hogql_queries/insights/trends/breakdown_values.py:0: error: Argument "breakdown_field" to "get_properties_chain" has incompatible type "str | float"; expected "str" [arg-type] posthog/hogql_queries/insights/trends/breakdown_values.py:0: error: Incompatible types in assignment (expression has type "float | int", variable has type "int") [assignment] posthog/hogql_queries/insights/trends/breakdown_values.py:0: error: Item "SelectUnionQuery" of "SelectQuery | SelectUnionQuery" has no attribute "select" [union-attr] posthog/hogql_queries/insights/trends/breakdown_values.py:0: error: Value of type "list[Any] | None" is not indexable [index] @@ -380,19 +374,11 @@ posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "Bre posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown_type" [union-attr] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown" [union-attr] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument 1 to "parse_expr" has incompatible type "str | float | list[str | float] | Any | None"; expected "str" [arg-type] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown" [union-attr] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument 1 to "int" has incompatible type "str | float | list[str | float] | Any | None"; expected "str | Buffer | SupportsInt | SupportsIndex | SupportsTrunc" [arg-type] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown_type" [union-attr] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown" [union-attr] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument 1 to "parse_expr" has incompatible type "str | float | list[str | float] | Any | None"; expected "str" [arg-type] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Statement is unreachable [unreachable] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown_type" [union-attr] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown" [union-attr] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown" [union-attr] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument 1 to "int" has incompatible type "str | float | list[str | float] | Any | None"; expected "str | Buffer | SupportsInt | SupportsIndex | SupportsTrunc" [arg-type] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown_type" [union-attr] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown" [union-attr] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument 1 to "parse_expr" has incompatible type "str | float | list[str | float] | Any | None"; expected "str" [arg-type] +posthog/hogql_queries/insights/trends/breakdown.py:0: error: Statement is unreachable [unreachable] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument "exprs" to "Or" has incompatible type "list[CompareOperation]"; expected "list[Expr]" [arg-type] posthog/hogql_queries/insights/trends/breakdown.py:0: note: "List" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance posthog/hogql_queries/insights/trends/breakdown.py:0: note: Consider using "Sequence" instead, which is covariant @@ -400,9 +386,6 @@ posthog/hogql_queries/insights/trends/breakdown.py:0: error: Incompatible types posthog/hogql_queries/insights/trends/breakdown.py:0: error: Incompatible types in assignment (expression has type "float", variable has type "int") [assignment] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Incompatible types in assignment (expression has type "str", variable has type "int") [assignment] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Incompatible types in assignment (expression has type "str", variable has type "int") [assignment] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument "event_name" to "BreakdownValues" has incompatible type "str | None"; expected "str" [arg-type] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown" [union-attr] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument "breakdown_field" to "BreakdownValues" has incompatible type "str | float | list[str | float] | Any | None"; expected "str | float" [arg-type] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown_type" [union-attr] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument "breakdown_type" to "BreakdownValues" has incompatible type "BreakdownType | Any | None"; expected "str" [arg-type] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown_histogram_bin_count" [union-attr] @@ -417,7 +400,6 @@ posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument "breakdown posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown" [union-attr] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Argument "breakdown_field" to "get_properties_chain" has incompatible type "str | float | list[str | float] | Any | None"; expected "str" [arg-type] posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown_group_type_index" [union-attr] -posthog/hogql_queries/insights/trends/breakdown.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown_type" [union-attr] posthog/hogql_queries/hogql_query_runner.py:0: error: Statement is unreachable [unreachable] posthog/hogql_queries/hogql_query_runner.py:0: error: Argument "placeholders" to "parse_select" has incompatible type "dict[str, Constant] | None"; expected "dict[str, Expr] | None" [arg-type] posthog/hogql_queries/hogql_query_runner.py:0: error: Incompatible types in assignment (expression has type "Expr", variable has type "SelectQuery | SelectUnionQuery") [assignment] @@ -439,13 +421,9 @@ posthog/api/person.py:0: error: Argument 1 to "loads" has incompatible type "str posthog/api/person.py:0: error: Argument "user" to "log_activity" has incompatible type "User | AnonymousUser"; expected "User | None" [arg-type] posthog/api/person.py:0: error: Argument "user" to "log_activity" has incompatible type "User | AnonymousUser"; expected "User | None" [arg-type] posthog/hogql_queries/web_analytics/web_analytics_query_runner.py:0: error: Argument 1 to "append" of "list" has incompatible type "EventPropertyFilter"; expected "Expr" [arg-type] -posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Return type "list[SelectQuery]" of "to_query" incompatible with return type "SelectQuery | SelectUnionQuery" in supertype "QueryRunner" [override] -posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Incompatible return value type (got "list[SelectQuery | SelectUnionQuery]", expected "list[SelectQuery]") [return-value] posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Need type annotation for "timings" (hint: "timings: List[] = ...") [var-annotated] posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Argument 1 to "extend" of "list" has incompatible type "list[QueryTiming] | None"; expected "Iterable[Any]" [arg-type] posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Statement is unreachable [unreachable] -posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: List item 0 has incompatible type "str | float | None"; expected "str | float" [list-item] -posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown" [union-attr] posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Argument 1 to "FormulaAST" has incompatible type "map[Any]"; expected "list[list[float]]" [arg-type] posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Argument 1 to "FormulaAST" has incompatible type "map[Any]"; expected "list[list[float]]" [arg-type] posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Item "None" of "BreakdownFilter | None" has no attribute "breakdown_type" [union-attr] @@ -634,11 +612,6 @@ posthog/hogql_queries/insights/test/test_events_query_runner.py:0: error: Item " posthog/hogql_queries/insights/test/test_events_query_runner.py:0: error: Item "None" of "Expr | None" has no attribute "exprs" [union-attr] posthog/hogql_queries/insights/test/test_events_query_runner.py:0: error: Item "Expr" of "Expr | None" has no attribute "exprs" [union-attr] posthog/hogql_queries/insights/test/test_events_query_runner.py:0: error: Item "None" of "Expr | None" has no attribute "exprs" [union-attr] -posthog/hogql/transforms/test/test_in_cohort.py:0: error: "TestInCohort" has no attribute "snapshot" [attr-defined] -posthog/hogql/transforms/test/test_in_cohort.py:0: error: Argument 1 to "len" has incompatible type "list[Any] | None"; expected "Sized" [arg-type] -posthog/hogql/transforms/test/test_in_cohort.py:0: error: Value of type "list[Any] | None" is not indexable [index] -posthog/hogql/transforms/test/test_in_cohort.py:0: error: "TestInCohort" has no attribute "snapshot" [attr-defined] -posthog/hogql/transforms/test/test_in_cohort.py:0: error: "TestInCohort" has no attribute "snapshot" [attr-defined] posthog/hogql/test/test_timings.py:0: error: No overload variant of "__setitem__" of "list" matches argument types "int", "float" [call-overload] posthog/hogql/test/test_timings.py:0: note: Possible overload variants: posthog/hogql/test/test_timings.py:0: note: def __setitem__(self, SupportsIndex, int, /) -> None diff --git a/posthog/hogql/functions/cohort.py b/posthog/hogql/functions/cohort.py index eb9c71993a0e5..07e701bac6391 100644 --- a/posthog/hogql/functions/cohort.py +++ b/posthog/hogql/functions/cohort.py @@ -26,8 +26,10 @@ def cohort(node: ast.Expr, args: List[ast.Expr], context: HogQLContext) -> ast.E from posthog.models import Cohort - if isinstance(arg.value, int) and not isinstance(arg.value, bool): - cohorts = Cohort.objects.filter(id=arg.value, team_id=context.team_id).values_list("id", "is_static", "name") + if (isinstance(arg.value, int) or isinstance(arg.value, float)) and not isinstance(arg.value, bool): + cohorts = Cohort.objects.filter(id=int(arg.value), team_id=context.team_id).values_list( + "id", "is_static", "name" + ) if len(cohorts) == 1: context.add_notice( start=arg.start, diff --git a/posthog/hogql/modifiers.py b/posthog/hogql/modifiers.py index fd49ba2bc270c..20b451d978fd0 100644 --- a/posthog/hogql/modifiers.py +++ b/posthog/hogql/modifiers.py @@ -1,6 +1,6 @@ from typing import Optional, TYPE_CHECKING -from posthog.schema import HogQLQueryModifiers, MaterializationMode +from posthog.schema import HogQLQueryModifiers, InCohortVia, MaterializationMode from posthog.utils import PersonOnEventsMode if TYPE_CHECKING: @@ -22,7 +22,7 @@ def create_default_modifiers_for_team( modifiers.personsArgMaxVersion = "auto" if modifiers.inCohortVia is None: - modifiers.inCohortVia = "subquery" + modifiers.inCohortVia = InCohortVia.subquery if modifiers.materializationMode is None or modifiers.materializationMode == MaterializationMode.auto: modifiers.materializationMode = MaterializationMode.legacy_null_as_null diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index 1e2c574846e35..428ab73a70811 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -33,7 +33,7 @@ from posthog.hogql.modifiers import create_default_modifiers_for_team from posthog.hogql.resolver import ResolverException, resolve_types from posthog.hogql.resolver_utils import lookup_field_by_name -from posthog.hogql.transforms.in_cohort import resolve_in_cohorts +from posthog.hogql.transforms.in_cohort import resolve_in_cohorts, resolve_in_cohorts_conjoined from posthog.hogql.transforms.lazy_tables import resolve_lazy_tables from posthog.hogql.transforms.property_types import resolve_property_types from posthog.hogql.visitor import Visitor, clone_expr @@ -41,7 +41,7 @@ from posthog.models.team.team import WeekStartDay from posthog.models.team import Team from posthog.models.utils import UUIDT -from posthog.schema import MaterializationMode +from posthog.schema import InCohortVia, MaterializationMode from posthog.utils import PersonOnEventsMode @@ -99,9 +99,12 @@ def prepare_ast_for_printing( with context.timings.measure("create_hogql_database"): context.database = context.database or create_hogql_database(context.team_id, context.modifiers) + if context.modifiers.inCohortVia == InCohortVia.leftjoin_conjoined: + with context.timings.measure("resolve_in_cohorts_conjoined"): + resolve_in_cohorts_conjoined(node, dialect, context, stack) with context.timings.measure("resolve_types"): node = resolve_types(node, context, dialect=dialect, scopes=[node.type for node in stack] if stack else None) - if context.modifiers.inCohortVia == "leftjoin": + if context.modifiers.inCohortVia == InCohortVia.leftjoin: with context.timings.measure("resolve_in_cohorts"): resolve_in_cohorts(node, dialect, stack, context) if dialect == "clickhouse": diff --git a/posthog/hogql/transforms/in_cohort.py b/posthog/hogql/transforms/in_cohort.py index a565391e309f3..3d1a075212282 100644 --- a/posthog/hogql/transforms/in_cohort.py +++ b/posthog/hogql/transforms/in_cohort.py @@ -1,10 +1,11 @@ -from typing import List, Optional, cast, Literal +from typing import List, Optional, Tuple, cast, Literal + from posthog.hogql import ast from posthog.hogql.context import HogQLContext from posthog.hogql.errors import HogQLException from posthog.hogql.escape_sql import escape_clickhouse_string -from posthog.hogql.parser import parse_expr +from posthog.hogql.parser import parse_expr, parse_select from posthog.hogql.resolver import resolve_types from posthog.hogql.visitor import TraversingVisitor, clone_expr @@ -18,6 +19,247 @@ def resolve_in_cohorts( InCohortResolver(stack=stack, dialect=dialect, context=context).visit(node) +def resolve_in_cohorts_conjoined( + node: ast.Expr, + dialect: Literal["hogql", "clickhouse"], + context: HogQLContext, + stack: Optional[List[ast.SelectQuery]] = None, +): + MultipleInCohortResolver(stack=stack, dialect=dialect, context=context).visit(node) + + +class CohortCompareOperationTraverser(TraversingVisitor): + ops: List[ast.CompareOperation] = [] + + def __init__(self, expr: ast.Expr): + self.ops = [] + super().visit(expr) + + def visit_compare_operation(self, node: ast.CompareOperation): + if node.op == ast.CompareOperationOp.InCohort or node.op == ast.CompareOperationOp.NotInCohort: + self.ops.append(node) + + +StaticOrDynamic = Literal["dynamic"] | Literal["static"] + + +class MultipleInCohortResolver(TraversingVisitor): + dialect: Literal["hogql", "clickhouse"] + + def __init__( + self, + dialect: Literal["hogql", "clickhouse"], + context: HogQLContext, + stack: Optional[List[ast.SelectQuery]] = None, + ): + super().__init__() + self.stack: List[ast.SelectQuery] = stack or [] + self.context = context + self.dialect = dialect + + def visit_select_query(self, node: ast.SelectQuery): + self.stack.append(node) + + super().visit_select_query(node) + + if node.where is not None: + compare_operations = CohortCompareOperationTraverser(node.where).ops + self._execute(node, compare_operations) + + self.stack.pop() + + def _execute(self, node: ast.SelectQuery, compare_operations: List[ast.CompareOperation]): + if len(compare_operations) == 0: + return + + cohorts = self._resolve_cohorts(compare_operations) + self._add_join(cohorts=cohorts, select=node, compare_operations=compare_operations) + + for compare_node in compare_operations: + compare_node.op = ast.CompareOperationOp.Eq + compare_node.left = ast.Constant(value=1) + compare_node.right = ast.Constant(value=1) + + def _resolve_cohorts( + self, compare_operations: List[ast.CompareOperation] + ) -> List[Tuple[int, StaticOrDynamic, int]]: + from posthog.models import Cohort + + cohorts: List[Tuple[int, StaticOrDynamic, int]] = [] + + for node in compare_operations: + arg = node.right + if not isinstance(arg, ast.Constant): + raise HogQLException("IN COHORT only works with constant arguments", node=arg) + + if (isinstance(arg.value, int) or isinstance(arg.value, float)) and not isinstance(arg.value, bool): + int_cohorts = Cohort.objects.filter(id=int(arg.value), team_id=self.context.team_id).values_list( + "id", "is_static", "version" + ) + if len(int_cohorts) == 1: + if node.op == ast.CompareOperationOp.NotInCohort: + raise HogQLException("NOT IN COHORT is not supported by this cohort mode") + + id = int_cohorts[0][0] + is_static = int_cohorts[0][1] + version = int_cohorts[0][2] or 0 + + cohorts.append((id, "static" if is_static else "dynamic", version)) + continue + raise HogQLException(f"Could not find cohort with id {arg.value}", node=arg) + + if isinstance(arg.value, str): + str_cohorts = Cohort.objects.filter(name=arg.value, team_id=self.context.team_id).values_list( + "id", "is_static", "version" + ) + if len(str_cohorts) == 1: + if node.op == ast.CompareOperationOp.NotInCohort: + raise HogQLException("NOT IN COHORT is not supported by this cohort mode") + + id = str_cohorts[0][0] + is_static = str_cohorts[0][1] + version = str_cohorts[0][2] or 0 + + cohorts.append((id, "static" if is_static else "dynamic", version)) + continue + elif len(str_cohorts) > 1: + raise HogQLException(f"Found multiple cohorts with name '{arg.value}'", node=arg) + raise HogQLException(f"Could not find a cohort with the name '{arg.value}'", node=arg) + + raise HogQLException("cohort() takes exactly one string or integer argument", node=arg) + + return cohorts + + def _add_join( + self, + cohorts: List[Tuple[int, StaticOrDynamic, int]], + select: ast.SelectQuery, + compare_operations: List[ast.CompareOperation], + ): + must_add_join = True + last_join = select.select_from + + while last_join: + if isinstance(last_join.table, ast.Field) and last_join.table.chain[0] == "__in_cohort": + must_add_join = False + break + if last_join.next_join: + last_join = last_join.next_join + else: + break + + if must_add_join: + static_cohorts = list(filter(lambda cohort: cohort[1] == "static", cohorts)) + dynamic_cohorts = list(filter(lambda cohort: cohort[1] == "dynamic", cohorts)) + + any_static = len(static_cohorts) > 0 + any_dynamic = len(dynamic_cohorts) > 0 + + def get_static_cohort_clause(): + return ast.CompareOperation( + left=ast.Field(chain=["cohort_id"]), + op=ast.CompareOperationOp.In, + right=ast.Array(exprs=[ast.Constant(value=id) for id, is_static, version in static_cohorts]), + ) + + def get_dynamic_cohort_clause(): + cohort_or = ast.Or( + exprs=[ + ast.And( + exprs=[ + ast.CompareOperation( + left=ast.Field(chain=["cohort_id"]), + op=ast.CompareOperationOp.Eq, + right=ast.Constant(value=id), + ), + ast.CompareOperation( + left=ast.Field(chain=["version"]), + op=ast.CompareOperationOp.Eq, + right=ast.Constant(value=version), + ), + ] + ) + for id, is_static, version in dynamic_cohorts + ] + ) + + if len(cohort_or.exprs) == 1: + return cohort_or.exprs[0] + + return cohort_or + + # TODO: Extract these `SELECT` clauses out into their own vars and inject + # via placeholders once the HogQL SELECT placeholders functionality is done + if any_static and any_dynamic: + static_clause = get_static_cohort_clause() + dynamic_clause = get_dynamic_cohort_clause() + + table_query = parse_select( + """ + SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id + FROM static_cohort_people + WHERE {static_clause} + UNION ALL + SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id + FROM raw_cohort_people + WHERE {dynamic_clause} + """, + placeholders={"static_clause": static_clause, "dynamic_clause": dynamic_clause}, + ) + elif any_static: + clause = get_static_cohort_clause() + table_query = parse_select( + """ + SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id + FROM static_cohort_people + WHERE {cohort_clause} + """, + placeholders={"cohort_clause": clause}, + ) + else: + clause = get_dynamic_cohort_clause() + table_query = parse_select( + """ + SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id + FROM raw_cohort_people + WHERE {cohort_clause} + """, + placeholders={"cohort_clause": clause}, + ) + + new_join = ast.JoinExpr( + alias=f"__in_cohort", + table=table_query, + join_type="LEFT JOIN", + next_join=None, + constraint=ast.JoinConstraint( + expr=ast.CompareOperation( + op=ast.CompareOperationOp.Eq, + left=ast.Constant(value=1), + right=ast.Constant(value=1), + ) + ), + ) + + new_join.constraint.expr.left = ast.Field(chain=[f"__in_cohort", "cohort_person_id"]) # type: ignore + new_join.constraint.expr.right = clone_expr(compare_operations[0].left) # type: ignore + if last_join: + last_join.next_join = new_join + else: + select.select_from = new_join + + cohort_match_compare_op = ast.CompareOperation( + left=ast.Field(chain=["__in_cohort", "matched"]), + op=ast.CompareOperationOp.Eq, + right=ast.Constant(value=1), + ) + + if select.where is not None: + select.where = ast.And(exprs=[select.where, cohort_match_compare_op]) + else: + select.where = cohort_match_compare_op + + class InCohortResolver(TraversingVisitor): def __init__( self, @@ -43,8 +285,8 @@ def visit_compare_operation(self, node: ast.CompareOperation): from posthog.models import Cohort - if isinstance(arg.value, int) and not isinstance(arg.value, bool): - cohorts = Cohort.objects.filter(id=arg.value, team_id=self.context.team_id).values_list( + if (isinstance(arg.value, int) or isinstance(arg.value, float)) and not isinstance(arg.value, bool): + cohorts = Cohort.objects.filter(id=int(arg.value), team_id=self.context.team_id).values_list( "id", "is_static", "name" ) if len(cohorts) == 1: diff --git a/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr b/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr index 35f094ef7b162..6644b850bb35c 100644 --- a/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr +++ b/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr @@ -1,4 +1,76 @@ # serializer version: 1 +# name: TestInCohort.test_in_cohort_conjoined_dynamic + ''' + -- ClickHouse + + SELECT events.event AS event + FROM events LEFT JOIN ( + SELECT cohortpeople.person_id AS cohort_person_id, 1 AS matched, cohortpeople.cohort_id AS cohort_id + FROM cohortpeople + WHERE and(equals(cohortpeople.team_id, 420), equals(cohortpeople.cohort_id, 1), equals(cohortpeople.version, 0))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) + WHERE and(equals(events.team_id, 420), and(1, equals(events.event, %(hogql_val_0)s)), ifNull(equals(__in_cohort.matched, 1), 0)) + LIMIT 100 + SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 + + -- HogQL + + SELECT event + FROM events LEFT JOIN ( + SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id + FROM raw_cohort_people + WHERE and(equals(cohort_id, 1), equals(version, 0))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) + WHERE and(and(1, equals(event, 'RANDOM_TEST_ID::UUID')), equals(__in_cohort.matched, 1)) + LIMIT 100 + ''' +# --- +# name: TestInCohort.test_in_cohort_conjoined_int + ''' + -- ClickHouse + + SELECT events.event AS event + FROM events LEFT JOIN ( + SELECT person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id + FROM person_static_cohort + WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [2]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) + WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0)) + LIMIT 100 + SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 + + -- HogQL + + SELECT event + FROM events LEFT JOIN ( + SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id + FROM static_cohort_people + WHERE in(cohort_id, [2])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) + WHERE and(1, equals(__in_cohort.matched, 1)) + LIMIT 100 + ''' +# --- +# name: TestInCohort.test_in_cohort_conjoined_string + ''' + -- ClickHouse + + SELECT events.event AS event + FROM events LEFT JOIN ( + SELECT person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id + FROM person_static_cohort + WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [3]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) + WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0)) + LIMIT 100 + SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 + + -- HogQL + + SELECT event + FROM events LEFT JOIN ( + SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id + FROM static_cohort_people + WHERE in(cohort_id, [3])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) + WHERE and(1, equals(__in_cohort.matched, 1)) + LIMIT 100 + ''' +# --- # name: TestInCohort.test_in_cohort_dynamic ''' -- ClickHouse @@ -7,10 +79,10 @@ FROM events LEFT JOIN ( SELECT cohortpeople.person_id AS person_id, 1 AS matched FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 420), equals(cohortpeople.cohort_id, 1)) + WHERE and(equals(cohortpeople.team_id, 420), equals(cohortpeople.cohort_id, 4)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version - HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0)) AS in_cohort__1 ON equals(in_cohort__1.person_id, events.person_id) - WHERE and(equals(events.team_id, 420), ifNull(equals(in_cohort__1.matched, 1), 0), equals(events.event, %(hogql_val_0)s)) + HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0)) AS in_cohort__4 ON equals(in_cohort__4.person_id, events.person_id) + WHERE and(equals(events.team_id, 420), ifNull(equals(in_cohort__4.matched, 1), 0), equals(events.event, %(hogql_val_0)s)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 @@ -20,10 +92,10 @@ FROM events LEFT JOIN ( SELECT person_id, 1 AS matched FROM raw_cohort_people - WHERE equals(cohort_id, 1) + WHERE equals(cohort_id, 4) GROUP BY person_id, cohort_id, version - HAVING greater(sum(sign), 0)) AS in_cohort__1 ON equals(in_cohort__1.person_id, person_id) - WHERE and(equals(in_cohort__1.matched, 1), equals(event, 'RANDOM_TEST_ID::UUID')) + HAVING greater(sum(sign), 0)) AS in_cohort__4 ON equals(in_cohort__4.person_id, person_id) + WHERE and(equals(in_cohort__4.matched, 1), equals(event, 'RANDOM_TEST_ID::UUID')) LIMIT 100 ''' # --- @@ -35,8 +107,8 @@ FROM events LEFT JOIN ( SELECT person_static_cohort.person_id AS person_id, 1 AS matched FROM person_static_cohort - WHERE and(equals(person_static_cohort.team_id, 420), equals(person_static_cohort.cohort_id, 2))) AS in_cohort__2 ON equals(in_cohort__2.person_id, events.person_id) - WHERE and(equals(events.team_id, 420), ifNull(equals(in_cohort__2.matched, 1), 0)) + WHERE and(equals(person_static_cohort.team_id, 420), equals(person_static_cohort.cohort_id, 5))) AS in_cohort__5 ON equals(in_cohort__5.person_id, events.person_id) + WHERE and(equals(events.team_id, 420), ifNull(equals(in_cohort__5.matched, 1), 0)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 @@ -46,8 +118,8 @@ FROM events LEFT JOIN ( SELECT person_id, 1 AS matched FROM static_cohort_people - WHERE equals(cohort_id, 2)) AS in_cohort__2 ON equals(in_cohort__2.person_id, person_id) - WHERE equals(in_cohort__2.matched, 1) + WHERE equals(cohort_id, 5)) AS in_cohort__5 ON equals(in_cohort__5.person_id, person_id) + WHERE equals(in_cohort__5.matched, 1) LIMIT 100 ''' # --- @@ -59,8 +131,8 @@ FROM events LEFT JOIN ( SELECT person_static_cohort.person_id AS person_id, 1 AS matched FROM person_static_cohort - WHERE and(equals(person_static_cohort.team_id, 420), equals(person_static_cohort.cohort_id, 3))) AS in_cohort__3 ON equals(in_cohort__3.person_id, events.person_id) - WHERE and(equals(events.team_id, 420), ifNull(equals(in_cohort__3.matched, 1), 0)) + WHERE and(equals(person_static_cohort.team_id, 420), equals(person_static_cohort.cohort_id, 6))) AS in_cohort__6 ON equals(in_cohort__6.person_id, events.person_id) + WHERE and(equals(events.team_id, 420), ifNull(equals(in_cohort__6.matched, 1), 0)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 @@ -70,8 +142,8 @@ FROM events LEFT JOIN ( SELECT person_id, 1 AS matched FROM static_cohort_people - WHERE equals(cohort_id, 3)) AS in_cohort__3 ON equals(in_cohort__3.person_id, person_id) - WHERE equals(in_cohort__3.matched, 1) + WHERE equals(cohort_id, 6)) AS in_cohort__6 ON equals(in_cohort__6.person_id, person_id) + WHERE equals(in_cohort__6.matched, 1) LIMIT 100 ''' # --- diff --git a/posthog/hogql/transforms/test/test_in_cohort.py b/posthog/hogql/transforms/test/test_in_cohort.py index 5563ab3eda7e6..2fe6b6cc16c13 100644 --- a/posthog/hogql/transforms/test/test_in_cohort.py +++ b/posthog/hogql/transforms/test/test_in_cohort.py @@ -9,7 +9,7 @@ from posthog.models import Cohort from posthog.models.cohort.util import recalculate_cohortpeople from posthog.models.utils import UUIDT -from posthog.schema import HogQLQueryModifiers +from posthog.schema import HogQLQueryModifiers, InCohortVia from posthog.test.base import ( BaseTest, _create_person, @@ -48,11 +48,11 @@ def test_in_cohort_dynamic(self): response = execute_hogql_query( f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk} AND event='{random_uuid}'", self.team, - modifiers=HogQLQueryModifiers(inCohortVia="leftjoin"), + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.leftjoin), ) - assert pretty_print_response_in_tests(response, self.team.pk) == self.snapshot - self.assertEqual(len(response.results), 1) - self.assertEqual(response.results[0][0], random_uuid) + assert pretty_print_response_in_tests(response, self.team.pk) == self.snapshot # type: ignore + self.assertEqual(len(response.results or []), 1) + self.assertEqual((response.results or [])[0][0], random_uuid) @pytest.mark.usefixtures("unittest_snapshot") @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False) @@ -64,9 +64,9 @@ def test_in_cohort_static(self): response = execute_hogql_query( f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk}", self.team, - modifiers=HogQLQueryModifiers(inCohortVia="leftjoin"), + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.leftjoin), ) - assert pretty_print_response_in_tests(response, self.team.pk) == self.snapshot + assert pretty_print_response_in_tests(response, self.team.pk) == self.snapshot # type: ignore @pytest.mark.usefixtures("unittest_snapshot") @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False) @@ -79,20 +79,91 @@ def test_in_cohort_strings(self): response = execute_hogql_query( f"SELECT event FROM events WHERE person_id IN COHORT 'my cohort'", self.team, - modifiers=HogQLQueryModifiers(inCohortVia="leftjoin"), + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.leftjoin), ) - assert pretty_print_response_in_tests(response, self.team.pk) == self.snapshot + assert pretty_print_response_in_tests(response, self.team.pk) == self.snapshot # type: ignore @pytest.mark.usefixtures("unittest_snapshot") @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=True) def test_in_cohort_error(self): with self.assertRaises(HogQLException) as e: - execute_hogql_query(f"SELECT event FROM events WHERE person_id IN COHORT true", self.team) + execute_hogql_query( + f"SELECT event FROM events WHERE person_id IN COHORT true", + self.team, + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.subquery), + ) + self.assertEqual(str(e.exception), "cohort() takes exactly one string or integer argument") + + with self.assertRaises(HogQLException) as e: + execute_hogql_query( + f"SELECT event FROM events WHERE person_id IN COHORT 'blabla'", + self.team, + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.subquery), + ) + self.assertEqual(str(e.exception), "Could not find a cohort with the name 'blabla'") + + @pytest.mark.usefixtures("unittest_snapshot") + @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False) + def test_in_cohort_conjoined_string(self): + Cohort.objects.create( + team=self.team, + name="my cohort", + is_static=True, + ) + response = execute_hogql_query( + f"SELECT event FROM events WHERE person_id IN COHORT 'my cohort'", + self.team, + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.leftjoin_conjoined), + ) + assert pretty_print_response_in_tests(response, self.team.pk) == self.snapshot # type: ignore + + @pytest.mark.usefixtures("unittest_snapshot") + @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False) + def test_in_cohort_conjoined_int(self): + cohort = Cohort.objects.create( + team=self.team, + is_static=True, + ) + response = execute_hogql_query( + f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk}", + self.team, + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.leftjoin_conjoined), + ) + assert pretty_print_response_in_tests(response, self.team.pk) == self.snapshot # type: ignore + + @pytest.mark.usefixtures("unittest_snapshot") + @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False) + def test_in_cohort_conjoined_dynamic(self): + random_uuid = self._create_random_events() + cohort = Cohort.objects.create( + team=self.team, + groups=[{"properties": [{"key": "$os", "value": "Chrome", "type": "person"}]}], + ) + recalculate_cohortpeople(cohort, pending_version=0) + response = execute_hogql_query( + f"SELECT event FROM events WHERE person_id IN COHORT {cohort.pk} AND event='{random_uuid}'", + self.team, + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.leftjoin_conjoined), + ) + assert pretty_print_response_in_tests(response, self.team.pk) == self.snapshot # type: ignore + self.assertEqual(len(response.results or []), 1) + self.assertEqual((response.results or [])[0][0], random_uuid) + + @pytest.mark.usefixtures("unittest_snapshot") + @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=True) + def test_in_cohort_conjoined_error(self): + with self.assertRaises(HogQLException) as e: + execute_hogql_query( + f"SELECT event FROM events WHERE person_id IN COHORT true", + self.team, + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.leftjoin_conjoined), + ) self.assertEqual(str(e.exception), "cohort() takes exactly one string or integer argument") with self.assertRaises(HogQLException) as e: execute_hogql_query( f"SELECT event FROM events WHERE person_id IN COHORT 'blabla'", self.team, + modifiers=HogQLQueryModifiers(inCohortVia=InCohortVia.leftjoin_conjoined), ) self.assertEqual(str(e.exception), "Could not find a cohort with the name 'blabla'") diff --git a/posthog/hogql_queries/insights/test/__snapshots__/test_lifecycle_query_runner.ambr b/posthog/hogql_queries/insights/test/__snapshots__/test_lifecycle_query_runner.ambr index ef3b23794866d..2159fb2c49d5b 100644 --- a/posthog/hogql_queries/insights/test/__snapshots__/test_lifecycle_query_runner.ambr +++ b/posthog/hogql_queries/insights/test/__snapshots__/test_lifecycle_query_runner.ambr @@ -79,7 +79,7 @@ WHERE and(equals(events.team_id, 2), greaterOrEquals(toTimeZone(events.timestamp, 'UTC'), minus(toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-12 00:00:00', 6, 'UTC'))), toIntervalDay(1))), less(toTimeZone(events.timestamp, 'UTC'), plus(toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-19 23:59:59', 6, 'UTC'))), toIntervalDay(1))), ifNull(in(person_id, (SELECT cohortpeople.person_id AS person_id FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 4)) + WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 7)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0), equals(events.event, '$pageview')) GROUP BY person_id) diff --git a/posthog/hogql_queries/insights/trends/breakdown.py b/posthog/hogql_queries/insights/trends/breakdown.py index 2f2370de0204f..d719b4b1ca598 100644 --- a/posthog/hogql_queries/insights/trends/breakdown.py +++ b/posthog/hogql_queries/insights/trends/breakdown.py @@ -17,7 +17,7 @@ from posthog.hogql_queries.utils.query_date_range import QueryDateRange from posthog.models.filters.mixins.utils import cached_property from posthog.models.team.team import Team -from posthog.schema import ActionsNode, EventsNode, TrendsQuery +from posthog.schema import ActionsNode, EventsNode, HogQLQueryModifiers, InCohortVia, TrendsQuery class Breakdown: @@ -26,6 +26,7 @@ class Breakdown: series: EventsNode | ActionsNode query_date_range: QueryDateRange timings: HogQLTimings + modifiers: HogQLQueryModifiers events_filter: ast.Expr def __init__( @@ -35,6 +36,7 @@ def __init__( series: EventsNode | ActionsNode, query_date_range: QueryDateRange, timings: HogQLTimings, + modifiers: HogQLQueryModifiers, events_filter: ast.Expr, ): self.team = team @@ -42,6 +44,7 @@ def __init__( self.series = series self.query_date_range = query_date_range self.timings = timings + self.modifiers = modifiers self.events_filter = events_filter @cached_property @@ -70,8 +73,14 @@ def column_expr(self) -> ast.Expr: expr=parse_expr(self.query.breakdownFilter.breakdown), ) elif self.query.breakdownFilter.breakdown_type == "cohort": + if self.modifiers.inCohortVia == InCohortVia.leftjoin_conjoined: + return ast.Alias( + alias="breakdown_value", + expr=ast.Field(chain=["__in_cohort", "cohort_id"]), + ) + cohort_breakdown = ( - 0 if self.query.breakdownFilter.breakdown == "all" else int(self.query.breakdownFilter.breakdown) + 0 if self.query.breakdownFilter.breakdown == "all" else int(self.query.breakdownFilter.breakdown) # type: ignore ) return ast.Alias( alias="breakdown_value", @@ -91,17 +100,44 @@ def column_expr(self) -> ast.Expr: return ast.Alias(alias="breakdown_value", expr=self._get_breakdown_transform_func) def events_where_filter(self) -> ast.Expr | None: - if self.query.breakdownFilter.breakdown_type == "cohort": + if ( + self.query.breakdownFilter is not None + and self.query.breakdownFilter.breakdown is not None + and self.query.breakdownFilter.breakdown_type == "cohort" + ): if self.query.breakdownFilter.breakdown == "all": return None + if isinstance(self.query.breakdownFilter.breakdown, List): + or_clause = ast.Or( + exprs=[ + ast.CompareOperation( + left=ast.Field(chain=["person", "id"]), + op=ast.CompareOperationOp.InCohort, + right=ast.Constant(value=breakdown), + ) + for breakdown in self.query.breakdownFilter.breakdown + ] + ) + if len(self.query.breakdownFilter.breakdown) > 1: + return or_clause + elif len(self.query.breakdownFilter.breakdown) == 1: + return or_clause.exprs[0] + else: + return ast.Constant(value=True) + return ast.CompareOperation( left=ast.Field(chain=["person", "id"]), op=ast.CompareOperationOp.InCohort, - right=ast.Constant(value=int(self.query.breakdownFilter.breakdown)), + right=ast.Constant(value=self.query.breakdownFilter.breakdown), ) - if self.query.breakdownFilter.breakdown_type == "hogql": + if ( + self.query.breakdownFilter is not None + and self.query.breakdownFilter.breakdown is not None + and self.query.breakdownFilter.breakdown_type == "hogql" + and isinstance(self.query.breakdownFilter.breakdown, str) + ): left = parse_expr(self.query.breakdownFilter.breakdown) else: left = ast.Field(chain=self._properties_chain) @@ -176,8 +212,8 @@ def _get_breakdown_values(self) -> List[str | int]: with self.timings.measure("breakdown_values_query"): breakdown = BreakdownValues( team=self.team, - event_name=series_event_name(self.series), - breakdown_field=self.query.breakdownFilter.breakdown, + event_name=series_event_name(self.series) or "", + breakdown_field=self.query.breakdownFilter.breakdown, # type: ignore breakdown_type=self.query.breakdownFilter.breakdown_type, query_date_range=self.query_date_range, events_filter=self.events_filter, diff --git a/posthog/hogql_queries/insights/trends/breakdown_values.py b/posthog/hogql_queries/insights/trends/breakdown_values.py index 870d1d3fb44dd..b1cd87868590a 100644 --- a/posthog/hogql_queries/insights/trends/breakdown_values.py +++ b/posthog/hogql_queries/insights/trends/breakdown_values.py @@ -17,7 +17,7 @@ class BreakdownValues: team: Team event_name: str - breakdown_field: Union[str, float] + breakdown_field: Union[str, float, List[Union[str, float]]] breakdown_type: str query_date_range: QueryDateRange events_filter: ast.Expr @@ -31,7 +31,7 @@ def __init__( self, team: Team, event_name: str, - breakdown_field: Union[str, float], + breakdown_field: Union[str, float, List[Union[str, float]]], query_date_range: QueryDateRange, breakdown_type: str, events_filter: ast.Expr, @@ -58,7 +58,10 @@ def get_breakdown_values(self) -> List[str | int]: if self.breakdown_field == "all": return [0] - return [int(self.breakdown_field)] + if isinstance(self.breakdown_field, List): + return [value if isinstance(value, str) else int(value) for value in self.breakdown_field] + + return [self.breakdown_field if isinstance(self.breakdown_field, str) else int(self.breakdown_field)] if self.breakdown_type == "hogql": select_field = ast.Alias( @@ -70,8 +73,8 @@ def get_breakdown_values(self) -> List[str | int]: alias="value", expr=ast.Field( chain=get_properties_chain( - breakdown_type=self.breakdown_type, - breakdown_field=self.breakdown_field, + breakdown_type=self.breakdown_type, # type: ignore + breakdown_field=str(self.breakdown_field), group_type_index=self.group_type_index, ) ), diff --git a/posthog/hogql_queries/insights/trends/query_builder.py b/posthog/hogql_queries/insights/trends/query_builder.py index e8350f23c4081..90c6d6488046c 100644 --- a/posthog/hogql_queries/insights/trends/query_builder.py +++ b/posthog/hogql_queries/insights/trends/query_builder.py @@ -13,7 +13,7 @@ from posthog.models.action.action import Action from posthog.models.filters.mixins.utils import cached_property from posthog.models.team.team import Team -from posthog.schema import ActionsNode, EventsNode, TrendsQuery +from posthog.schema import ActionsNode, EventsNode, HogQLQueryModifiers, TrendsQuery class TrendsQueryBuilder: @@ -22,6 +22,7 @@ class TrendsQueryBuilder: query_date_range: QueryDateRange series: EventsNode | ActionsNode timings: HogQLTimings + modifiers: HogQLQueryModifiers def __init__( self, @@ -30,12 +31,14 @@ def __init__( query_date_range: QueryDateRange, series: EventsNode | ActionsNode, timings: HogQLTimings, + modifiers: HogQLQueryModifiers, ): self.query = trends_query self.team = team self.query_date_range = query_date_range self.series = series self.timings = timings + self.modifiers = modifiers def build_query(self) -> ast.SelectQuery | ast.SelectUnionQuery: if self._trends_display.should_aggregate_values(): @@ -448,6 +451,7 @@ def _breakdown(self): series=self.series, query_date_range=self.query_date_range, timings=self.timings, + modifiers=self.modifiers, events_filter=self._events_filter(ignore_breakdowns=True), ) diff --git a/posthog/hogql_queries/insights/trends/test/__snapshots__/test_trends.ambr b/posthog/hogql_queries/insights/trends/test/__snapshots__/test_trends.ambr index 68b58b6d9b72e..aa5870b1810f5 100644 --- a/posthog/hogql_queries/insights/trends/test/__snapshots__/test_trends.ambr +++ b/posthog/hogql_queries/insights/trends/test/__snapshots__/test_trends.ambr @@ -85,7 +85,7 @@ WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-01 00:00:00', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-07 23:59:59', 6, 'UTC'))), ifNull(equals(e__pdi__person.`properties___$bool_prop`, 'x'), 0), and(equals(e.event, 'sign up'), ifNull(in(e__pdi.person_id, (SELECT cohortpeople.person_id AS person_id FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 5)) + WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 8)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0))) GROUP BY day_start) @@ -172,7 +172,7 @@ WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-01 00:00:00', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-07 23:59:59', 6, 'UTC'))), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.person_properties, '$bool_prop'), ''), 'null'), '^"|"$', ''), 'x'), 0), and(equals(e.event, 'sign up'), ifNull(in(ifNull(nullIf(e__override.override_person_id, '00000000-0000-0000-0000-000000000000'), e.person_id), (SELECT cohortpeople.person_id AS person_id FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 6)) + WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 9)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0))) GROUP BY day_start) @@ -688,7 +688,7 @@ WHERE and(equals(e.team_id, 2), equals(e.event, '$pageview'), and(or(ifNull(equals(e__pdi__person.properties___name, 'p1'), 0), ifNull(equals(e__pdi__person.properties___name, 'p2'), 0), ifNull(equals(e__pdi__person.properties___name, 'p3'), 0)), ifNull(in(e__pdi.person_id, (SELECT cohortpeople.person_id AS person_id FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 25)) + WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 28)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0))) GROUP BY value @@ -757,7 +757,7 @@ WHERE and(equals(e.team_id, 2), and(and(equals(e.event, '$pageview'), and(or(ifNull(equals(e__pdi__person.properties___name, 'p1'), 0), ifNull(equals(e__pdi__person.properties___name, 'p2'), 0), ifNull(equals(e__pdi__person.properties___name, 'p3'), 0)), ifNull(in(e__pdi.person_id, (SELECT cohortpeople.person_id AS person_id FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 25)) + WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 28)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0))), or(ifNull(equals(transform(ifNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, 'key'), ''), 'null'), '^"|"$', ''), '$$_posthog_breakdown_null_$$'), ['$$_posthog_breakdown_other_$$', 'val'], ['$$_posthog_breakdown_other_$$', 'val'], '$$_posthog_breakdown_other_$$'), '$$_posthog_breakdown_other_$$'), 0), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, 'key'), ''), 'null'), '^"|"$', ''), 'val'), 0))), ifNull(greaterOrEquals(timestamp, minus(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-01 00:00:00', 6, 'UTC')), toIntervalDay(7))), 0), ifNull(lessOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-12 23:59:59', 6, 'UTC'))), 0)) GROUP BY timestamp, actor_id, @@ -1592,7 +1592,7 @@ WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 13:01:01', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))), and(equals(e.event, 'sign up'), ifNull(in(e__pdi.person_id, (SELECT cohortpeople.person_id AS person_id FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 38)) + WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 41)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0))) GROUP BY value @@ -1640,7 +1640,7 @@ WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 13:01:01', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))), and(equals(e.event, 'sign up'), ifNull(in(e__pdi.person_id, (SELECT cohortpeople.person_id AS person_id FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 38)) + WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 41)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0)), or(ifNull(equals(transform(ifNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', ''), '$$_posthog_breakdown_null_$$'), ['$$_posthog_breakdown_other_$$', 'value', 'other_value'], ['$$_posthog_breakdown_other_$$', 'value', 'other_value'], '$$_posthog_breakdown_other_$$'), '$$_posthog_breakdown_other_$$'), 0), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', ''), 'value'), 0), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', ''), 'other_value'), 0))) GROUP BY day_start, @@ -1691,7 +1691,7 @@ WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 13:01:01', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))), and(equals(e.event, 'sign up'), ifNull(in(ifNull(nullIf(e__override.override_person_id, '00000000-0000-0000-0000-000000000000'), e.person_id), (SELECT cohortpeople.person_id AS person_id FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 39)) + WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 42)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0))) GROUP BY value @@ -1738,7 +1738,7 @@ WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 13:01:01', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))), and(equals(e.event, 'sign up'), ifNull(in(ifNull(nullIf(e__override.override_person_id, '00000000-0000-0000-0000-000000000000'), e.person_id), (SELECT cohortpeople.person_id AS person_id FROM cohortpeople - WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 39)) + WHERE and(equals(cohortpeople.team_id, 2), equals(cohortpeople.cohort_id, 42)) GROUP BY cohortpeople.person_id, cohortpeople.cohort_id, cohortpeople.version HAVING ifNull(greater(sum(cohortpeople.sign), 0), 0))), 0)), or(ifNull(equals(transform(ifNull(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', ''), '$$_posthog_breakdown_null_$$'), ['$$_posthog_breakdown_other_$$', 'value', 'other_value'], ['$$_posthog_breakdown_other_$$', 'value', 'other_value'], '$$_posthog_breakdown_other_$$'), '$$_posthog_breakdown_other_$$'), 0), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', ''), 'value'), 0), ifNull(equals(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', ''), 'other_value'), 0))) GROUP BY day_start, diff --git a/posthog/hogql_queries/insights/trends/test/test_query_builder.py b/posthog/hogql_queries/insights/trends/test/test_query_builder.py index 31457a8a5bb2c..c978c9f2767a5 100644 --- a/posthog/hogql_queries/insights/trends/test/test_query_builder.py +++ b/posthog/hogql_queries/insights/trends/test/test_query_builder.py @@ -1,5 +1,6 @@ from datetime import datetime from freezegun import freeze_time +from posthog.hogql.modifiers import create_default_modifiers_for_team from posthog.hogql.query import execute_hogql_query from posthog.hogql.timings import HogQLTimings @@ -45,6 +46,7 @@ def get_response(self, trends_query: TrendsQuery) -> HogQLQueryResponse: ) timings = HogQLTimings() + modifiers = create_default_modifiers_for_team(self.team) query_builder = TrendsQueryBuilder( trends_query=trends_query, @@ -52,6 +54,7 @@ def get_response(self, trends_query: TrendsQuery) -> HogQLQueryResponse: query_date_range=query_date_range, series=trends_query.series[0], timings=timings, + modifiers=modifiers, ) query = query_builder.build_query() diff --git a/posthog/hogql_queries/insights/trends/trends_query_runner.py b/posthog/hogql_queries/insights/trends/trends_query_runner.py index 9b38c6e644404..cb4b5f5ca654d 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/trends_query_runner.py @@ -39,6 +39,7 @@ ChartDisplayType, EventsNode, HogQLQueryResponse, + InCohortVia, TrendsQuery, TrendsQueryResponse, HogQLQueryModifiers, @@ -83,7 +84,7 @@ def _refresh_frequency(self): return refresh_frequency - def to_query(self) -> List[ast.SelectQuery]: + def to_query(self) -> List[ast.SelectQuery | ast.SelectUnionQuery]: # type: ignore queries = [] with self.timings.measure("trends_query"): for series in self.series: @@ -98,6 +99,7 @@ def to_query(self) -> List[ast.SelectQuery]: query_date_range=query_date_range, series=series.series, timings=self.timings, + modifiers=self.modifiers, ) queries.append(query_builder.build_query()) @@ -118,6 +120,7 @@ def to_actors_query(self) -> ast.SelectQuery | ast.SelectUnionQuery: query_date_range=query_date_range, series=series.series, timings=self.timings, + modifiers=self.modifiers, ) queries.append(query_builder.build_persons_query()) @@ -341,17 +344,21 @@ def setup_series(self) -> List[SeriesWithExtras]: for series in self.query.series ] - if self.query.breakdownFilter is not None and self.query.breakdownFilter.breakdown_type == "cohort": + if ( + self.modifiers.inCohortVia != InCohortVia.leftjoin_conjoined + and self.query.breakdownFilter is not None + and self.query.breakdownFilter.breakdown_type == "cohort" + ): updated_series = [] if isinstance(self.query.breakdownFilter.breakdown, List): cohort_ids = self.query.breakdownFilter.breakdown else: - cohort_ids = [self.query.breakdownFilter.breakdown] + cohort_ids = [self.query.breakdownFilter.breakdown] # type: ignore for cohort_id in cohort_ids: for series in series_with_extras: copied_query = deepcopy(self.query) - copied_query.breakdownFilter.breakdown = cohort_id + copied_query.breakdownFilter.breakdown = cohort_id # type: ignore updated_series.append( SeriesWithExtras( diff --git a/posthog/hogql_queries/insights/trends/utils.py b/posthog/hogql_queries/insights/trends/utils.py index 1510a87a76bef..cd877757b2e24 100644 --- a/posthog/hogql_queries/insights/trends/utils.py +++ b/posthog/hogql_queries/insights/trends/utils.py @@ -12,7 +12,7 @@ def get_properties_chain( breakdown_type: Union[Literal["person"], Literal["session"], Literal["group"], Literal["event"]], breakdown_field: str, group_type_index: Optional[float | int], -) -> List[str]: +) -> List[str | int]: if breakdown_type == "person": return ["person", "properties", breakdown_field] diff --git a/posthog/schema.py b/posthog/schema.py index ddee598cbd80b..c8df43c2e0345 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -267,6 +267,7 @@ class HogQLNotice(BaseModel): class InCohortVia(str, Enum): leftjoin = "leftjoin" subquery = "subquery" + leftjoin_conjoined = "leftjoin_conjoined" class MaterializationMode(str, Enum):