diff --git a/posthog/hogql_queries/hogql_cohort_query.py b/posthog/hogql_queries/hogql_cohort_query.py index 1c5ab7761a46d..d6727b3624ae8 100644 --- a/posthog/hogql_queries/hogql_cohort_query.py +++ b/posthog/hogql_queries/hogql_cohort_query.py @@ -159,7 +159,9 @@ def _get_series(self, prop: Property, math=None): else: raise ValueError(f"Event type must be 'events' or 'actions'") - def _actors_query_from_source(self, source: Union[InsightActorsQuery, FunnelsActorsQuery]) -> ast.SelectQuery: + def _actors_query_from_source( + self, source: Union[InsightActorsQuery, FunnelsActorsQuery, StickinessActorsQuery] + ) -> ast.SelectQuery: actors_query = ActorsQuery( source=source, select=["id"], @@ -217,21 +219,23 @@ def get_performed_event_multiple(self, prop: Property) -> ast.SelectQuery: funnelStep = count + 1 elif prop.operator == "lt": funnelCustomSteps = list(range(1, count)) - elif prop.operator == "eq" or prop.operator == "exact" or prop.operator is None: + elif ( + prop.operator == "exact" or prop.operator is None + ): # mypy refuses this: if any errors crop up prop.operator == "eq" # People who dropped out at count + 1 funnelStep = -(count + 1) else: - raise ValidationError("count_operator must be gte, lte, eq, or None") + raise ValidationError("count_operator must be gt(e), lt(e), exact, or None") if prop.event_filters: - filter = Filter(data={"properties": prop.event_filters}).property_groups + property_group = Filter(data={"properties": prop.event_filters}).property_groups # TODO: this is testing - we need to figure out how to handle ORs here - if isinstance(filter, PropertyGroup): - if filter.type == PropertyOperatorType.OR: + if isinstance(property_group.values[0], PropertyGroup): + if property_group.type == PropertyOperatorType.OR: raise Exception("Don't support OR at the event level") - series[0].properties = filter.values + series[0].properties = property_group.values else: - series[0].properties = filter + series[0].properties = cast(list[Property], property_group.values) if prop.explicit_datetime: date_from = prop.explicit_datetime @@ -253,7 +257,9 @@ def get_performed_event_multiple(self, prop: Property) -> ast.SelectQuery: def get_performed_event_sequence(self, prop: Property) -> ast.SelectQuery: # either an action or an event - series = [] + series: list[EventsNode | ActionsNode] = [] + assert prop.seq_event is not None + if prop.event_type == "events": series.append(EventsNode(event=prop.key)) elif prop.event_type == "actions": @@ -268,12 +274,6 @@ def get_performed_event_sequence(self, prop: Property) -> ast.SelectQuery: else: raise ValueError(f"Event type must be 'events' or 'actions'") - """ - if prop.event_filters: - filter = Filter(data={"properties": prop.event_filters}).property_groups - series[0].properties = filter - """ - if prop.explicit_datetime: date_from = prop.explicit_datetime else: @@ -361,7 +361,7 @@ def get_restarted_performing_event(self, prop: Property) -> ast.SelectSetQuery: ], ) - def get_performed_event_regularly(self, prop: Property) -> ast.SelectSetQuery: + def get_performed_event_regularly(self, prop: Property) -> ast.SelectQuery: # min_periods # operator (gte) # operator_value (int) @@ -369,21 +369,33 @@ def get_performed_event_regularly(self, prop: Property) -> ast.SelectSetQuery: # time_value # total periods - series = self._get_series(prop) + # this isn't correct - write a test for it, not honoring total periods - date_value = parse_and_validate_positive_integer(prop.time_value, "time_value") date_interval = validate_interval(prop.time_interval) + date_value = parse_and_validate_positive_integer(prop.time_value, "time_value") + operator_value = parse_and_validate_positive_integer(prop.operator_value, "operator_value") + min_period_count = parse_and_validate_positive_integer(prop.min_periods, "min_periods") + total_period_count = parse_and_validate_positive_integer(prop.total_periods, "total_periods") + if min_period_count > total_period_count: + raise ( + ValueError( + f"min_periods ({min_period_count}) cannot be greater than total_periods ({total_period_count})" + ) + ) + + series = self._get_series(prop) + date_from = f"-{date_value}{date_interval[:1]}" stickiness_query = StickinessQuery( series=series, dateRange=DateRange(date_from=date_from), stickinessFilter=StickinessFilter( - stickinessCriteria=StickinessCriteria(operator=prop.operator, value=prop.operator_value) + stickinessCriteria=StickinessCriteria(operator=prop.operator, value=operator_value) ), ) return self._actors_query_from_source( - StickinessActorsQuery(source=stickiness_query, day=prop.min_periods - 1, operator=prop.operator) + StickinessActorsQuery(source=stickiness_query, day=min_period_count - 1, operator=prop.operator) ) def get_person_condition(self, prop: Property) -> ast.SelectQuery: @@ -401,8 +413,11 @@ def get_person_condition(self, prop: Property) -> ast.SelectQuery: def get_static_cohort_condition(self, prop: Property) -> ast.SelectQuery: cohort = Cohort.objects.get(pk=cast(int, prop.value)) - return parse_select( - f"SELECT person_id FROM static_cohort_people WHERE cohort_id = {cohort.pk} AND team_id = {self.team.pk}", + return cast( + ast.SelectQuery, + parse_select( + f"SELECT person_id FROM static_cohort_people WHERE cohort_id = {cohort.pk} AND team_id = {self.team.pk}", + ), ) def _get_condition_for_property(self, prop: Property) -> ast.SelectQuery | ast.SelectSetQuery: @@ -421,6 +436,8 @@ def _get_condition_for_property(self, prop: Property) -> ast.SelectQuery | ast.S return self.get_restarted_performing_event(prop) elif prop.value == "performed_event_regularly": return self.get_performed_event_regularly(prop) + else: + raise ValueError(f"Invalid behavioral property value for Cohort: {prop.value}") elif prop.type == "person": return self.get_person_condition(prop) elif ( @@ -433,7 +450,7 @@ def _get_condition_for_property(self, prop: Property) -> ast.SelectQuery | ast.S def _get_conditions(self) -> ast.SelectQuery | ast.SelectSetQuery: def build_conditions( prop: Optional[Union[PropertyGroup, Property]], - ) -> (None | ast.SelectQuery | ast.SelectSetQuery, bool): + ) -> tuple[None | ast.SelectQuery | ast.SelectSetQuery, bool]: if not prop: # What do we do here? return (None, False) @@ -441,7 +458,7 @@ def build_conditions( if isinstance(prop, PropertyGroup): queries = [] for property in prop.values: - query, negation = build_conditions(property) # type: ignore + query, negation = build_conditions(property) if query is not None: queries.append((query, negation)) @@ -492,7 +509,10 @@ def build_conditions( all_negated or negated, ) else: - return (self._get_condition_for_property(prop), prop.negation) + return (self._get_condition_for_property(prop), prop.negation or False) + if self._outer_property_groups is None: + return parse_select("SELECT NULL WHERE 0") conditions, _ = build_conditions(self._outer_property_groups) + assert conditions is not None return conditions