diff --git a/posthog/hogql/functions/mapping.py b/posthog/hogql/functions/mapping.py index cc7124adca7ce..08bee305b933e 100644 --- a/posthog/hogql/functions/mapping.py +++ b/posthog/hogql/functions/mapping.py @@ -2,6 +2,7 @@ from itertools import chain from typing import Optional + from posthog.cloud_utils import is_cloud, is_ci from posthog.hogql import ast from posthog.hogql.ast import ( @@ -80,16 +81,13 @@ class HogQLFunctionMeta: def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantType, ...]): - _sig_arg_types = list(sig_arg_types) if len(arg_types) != len(sig_arg_types): return False - for index, arg_type in enumerate(arg_types): - _sig_arg_type = _sig_arg_types[index] - if not isinstance(arg_type, _sig_arg_type.__class__): - return False - - return True + return all( + isinstance(sig_arg_type, UnknownType) or isinstance(arg_type, sig_arg_type.__class__) + for arg_type, sig_arg_type in zip(arg_types, sig_arg_types) + ) HOGQL_COMPARISON_MAPPING: dict[str, ast.CompareOperationOp] = { @@ -448,15 +446,59 @@ def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantTy "toStartOfYear": HogQLFunctionMeta("toStartOfYear", 1, 1), "toStartOfISOYear": HogQLFunctionMeta("toStartOfISOYear", 1, 1), "toStartOfQuarter": HogQLFunctionMeta("toStartOfQuarter", 1, 1), - "toStartOfMonth": HogQLFunctionMeta("toStartOfMonth", 1, 1), + "toStartOfMonth": HogQLFunctionMeta( + "toStartOfMonth", + 1, + 1, + signatures=[ + ((UnknownType(),), DateType()), + ], + ), "toLastDayOfMonth": HogQLFunctionMeta("toLastDayOfMonth", 1, 1), "toMonday": HogQLFunctionMeta("toMonday", 1, 1), - "toStartOfWeek": HogQLFunctionMeta("toStartOfWeek", 1, 2), - "toStartOfDay": HogQLFunctionMeta("toStartOfDay", 1, 2), + "toStartOfWeek": HogQLFunctionMeta( + "toStartOfWeek", + 1, + 2, + signatures=[ + ((UnknownType(),), DateType()), + ((UnknownType(), UnknownType()), DateType()), + ], + ), + "toStartOfDay": HogQLFunctionMeta( + "toStartOfDay", + 1, + 2, + signatures=[ + ((UnknownType(),), DateTimeType()), + ((UnknownType(), UnknownType()), DateTimeType()), + ], + ), "toLastDayOfWeek": HogQLFunctionMeta("toLastDayOfWeek", 1, 2), - "toStartOfHour": HogQLFunctionMeta("toStartOfHour", 1, 1), - "toStartOfMinute": HogQLFunctionMeta("toStartOfMinute", 1, 1), - "toStartOfSecond": HogQLFunctionMeta("toStartOfSecond", 1, 1), + "toStartOfHour": HogQLFunctionMeta( + "toStartOfHour", + 1, + 1, + signatures=[ + ((UnknownType(),), DateTimeType()), + ], + ), + "toStartOfMinute": HogQLFunctionMeta( + "toStartOfMinute", + 1, + 1, + signatures=[ + ((UnknownType(),), DateTimeType()), + ], + ), + "toStartOfSecond": HogQLFunctionMeta( + "toStartOfSecond", + 1, + 1, + signatures=[ + ((UnknownType(),), DateTimeType()), + ], + ), "toStartOfFiveMinutes": HogQLFunctionMeta("toStartOfFiveMinutes", 1, 1), "toStartOfTenMinutes": HogQLFunctionMeta("toStartOfTenMinutes", 1, 1), "toStartOfFifteenMinutes": HogQLFunctionMeta("toStartOfFifteenMinutes", 1, 1), @@ -472,7 +514,17 @@ def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantTy "dateSub": HogQLFunctionMeta("dateSub", 3, 3), "timeStampAdd": HogQLFunctionMeta("timeStampAdd", 2, 2), "timeStampSub": HogQLFunctionMeta("timeStampSub", 2, 2), - "now": HogQLFunctionMeta("now64", 0, 1, tz_aware=True, case_sensitive=False), + "now": HogQLFunctionMeta( + "now64", + 0, + 1, + tz_aware=True, + case_sensitive=False, + signatures=[ + ((), DateTimeType()), + ((UnknownType(),), DateTimeType()), + ], + ), "nowInBlock": HogQLFunctionMeta("nowInBlock", 1, 1), "rowNumberInBlock": HogQLFunctionMeta("rowNumberInBlock", 0, 0), "rowNumberInAllBlocks": HogQLFunctionMeta("rowNumberInAllBlocks", 0, 0), diff --git a/posthog/hogql/test/test_mapping.py b/posthog/hogql/test/test_mapping.py index 6e074687d22cf..8dc805dace346 100644 --- a/posthog/hogql/test/test_mapping.py +++ b/posthog/hogql/test/test_mapping.py @@ -1,4 +1,8 @@ -from posthog.hogql.ast import FloatType, IntegerType +from posthog.hogql.ast import FloatType, IntegerType, DateType +from posthog.hogql.base import UnknownType +from posthog.hogql.context import HogQLContext +from posthog.hogql.parser import parse_expr +from posthog.hogql.printer import print_ast from posthog.test.base import BaseTest from typing import Optional from posthog.hogql.functions.mapping import ( @@ -7,6 +11,7 @@ find_hogql_aggregation, find_hogql_posthog_function, HogQLFunctionMeta, + HOGQL_CLICKHOUSE_FUNCTIONS, ) @@ -60,3 +65,26 @@ def test_compare_types_mismatch_lengths(self): def test_compare_types_mismatch_differing_order(self): res = compare_types([IntegerType(), FloatType()], (FloatType(), IntegerType())) assert res is False + + def test_unknown_type_mapping(self): + HOGQL_CLICKHOUSE_FUNCTIONS["overloadedFunction"] = HogQLFunctionMeta( + "overloadFailure", + 1, + 1, + overloads=[((DateType,), "overloadSuccess")], + ) + + HOGQL_CLICKHOUSE_FUNCTIONS["dateEmittingFunction"] = HogQLFunctionMeta( + "dateEmittingFunction", + 1, + 1, + signatures=[ + ((UnknownType(),), DateType()), + ], + ) + ast = print_ast( + parse_expr("overloadedFunction(dateEmittingFunction('123123'))"), + HogQLContext(self.team.pk, enable_select_queries=True), + "clickhouse", + ) + assert "overloadSuccess" in ast