Skip to content

Commit

Permalink
feat: unknown type arg matching (#25500)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
aspicer and github-actions[bot] authored Oct 10, 2024
1 parent 8c5d063 commit 7ceac1f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 15 deletions.
80 changes: 66 additions & 14 deletions posthog/hogql/functions/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from itertools import chain
from typing import Optional


from posthog.cloud_utils import is_cloud, is_ci
from posthog.hogql import ast
from posthog.hogql.ast import (
Expand Down Expand Up @@ -80,16 +81,13 @@ class HogQLFunctionMeta:


def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantType, ...]):
_sig_arg_types = list(sig_arg_types)
if len(arg_types) != len(sig_arg_types):
return False

for index, arg_type in enumerate(arg_types):
_sig_arg_type = _sig_arg_types[index]
if not isinstance(arg_type, _sig_arg_type.__class__):
return False

return True
return all(
isinstance(sig_arg_type, UnknownType) or isinstance(arg_type, sig_arg_type.__class__)
for arg_type, sig_arg_type in zip(arg_types, sig_arg_types)
)


HOGQL_COMPARISON_MAPPING: dict[str, ast.CompareOperationOp] = {
Expand Down Expand Up @@ -448,15 +446,59 @@ def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantTy
"toStartOfYear": HogQLFunctionMeta("toStartOfYear", 1, 1),
"toStartOfISOYear": HogQLFunctionMeta("toStartOfISOYear", 1, 1),
"toStartOfQuarter": HogQLFunctionMeta("toStartOfQuarter", 1, 1),
"toStartOfMonth": HogQLFunctionMeta("toStartOfMonth", 1, 1),
"toStartOfMonth": HogQLFunctionMeta(
"toStartOfMonth",
1,
1,
signatures=[
((UnknownType(),), DateType()),
],
),
"toLastDayOfMonth": HogQLFunctionMeta("toLastDayOfMonth", 1, 1),
"toMonday": HogQLFunctionMeta("toMonday", 1, 1),
"toStartOfWeek": HogQLFunctionMeta("toStartOfWeek", 1, 2),
"toStartOfDay": HogQLFunctionMeta("toStartOfDay", 1, 2),
"toStartOfWeek": HogQLFunctionMeta(
"toStartOfWeek",
1,
2,
signatures=[
((UnknownType(),), DateType()),
((UnknownType(), UnknownType()), DateType()),
],
),
"toStartOfDay": HogQLFunctionMeta(
"toStartOfDay",
1,
2,
signatures=[
((UnknownType(),), DateTimeType()),
((UnknownType(), UnknownType()), DateTimeType()),
],
),
"toLastDayOfWeek": HogQLFunctionMeta("toLastDayOfWeek", 1, 2),
"toStartOfHour": HogQLFunctionMeta("toStartOfHour", 1, 1),
"toStartOfMinute": HogQLFunctionMeta("toStartOfMinute", 1, 1),
"toStartOfSecond": HogQLFunctionMeta("toStartOfSecond", 1, 1),
"toStartOfHour": HogQLFunctionMeta(
"toStartOfHour",
1,
1,
signatures=[
((UnknownType(),), DateTimeType()),
],
),
"toStartOfMinute": HogQLFunctionMeta(
"toStartOfMinute",
1,
1,
signatures=[
((UnknownType(),), DateTimeType()),
],
),
"toStartOfSecond": HogQLFunctionMeta(
"toStartOfSecond",
1,
1,
signatures=[
((UnknownType(),), DateTimeType()),
],
),
"toStartOfFiveMinutes": HogQLFunctionMeta("toStartOfFiveMinutes", 1, 1),
"toStartOfTenMinutes": HogQLFunctionMeta("toStartOfTenMinutes", 1, 1),
"toStartOfFifteenMinutes": HogQLFunctionMeta("toStartOfFifteenMinutes", 1, 1),
Expand All @@ -472,7 +514,17 @@ def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantTy
"dateSub": HogQLFunctionMeta("dateSub", 3, 3),
"timeStampAdd": HogQLFunctionMeta("timeStampAdd", 2, 2),
"timeStampSub": HogQLFunctionMeta("timeStampSub", 2, 2),
"now": HogQLFunctionMeta("now64", 0, 1, tz_aware=True, case_sensitive=False),
"now": HogQLFunctionMeta(
"now64",
0,
1,
tz_aware=True,
case_sensitive=False,
signatures=[
((), DateTimeType()),
((UnknownType(),), DateTimeType()),
],
),
"nowInBlock": HogQLFunctionMeta("nowInBlock", 1, 1),
"rowNumberInBlock": HogQLFunctionMeta("rowNumberInBlock", 0, 0),
"rowNumberInAllBlocks": HogQLFunctionMeta("rowNumberInAllBlocks", 0, 0),
Expand Down
30 changes: 29 additions & 1 deletion posthog/hogql/test/test_mapping.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from posthog.hogql.ast import FloatType, IntegerType
from posthog.hogql.ast import FloatType, IntegerType, DateType
from posthog.hogql.base import UnknownType
from posthog.hogql.context import HogQLContext
from posthog.hogql.parser import parse_expr
from posthog.hogql.printer import print_ast
from posthog.test.base import BaseTest
from typing import Optional
from posthog.hogql.functions.mapping import (
Expand All @@ -7,6 +11,7 @@
find_hogql_aggregation,
find_hogql_posthog_function,
HogQLFunctionMeta,
HOGQL_CLICKHOUSE_FUNCTIONS,
)


Expand Down Expand Up @@ -60,3 +65,26 @@ def test_compare_types_mismatch_lengths(self):
def test_compare_types_mismatch_differing_order(self):
res = compare_types([IntegerType(), FloatType()], (FloatType(), IntegerType()))
assert res is False

def test_unknown_type_mapping(self):
HOGQL_CLICKHOUSE_FUNCTIONS["overloadedFunction"] = HogQLFunctionMeta(
"overloadFailure",
1,
1,
overloads=[((DateType,), "overloadSuccess")],
)

HOGQL_CLICKHOUSE_FUNCTIONS["dateEmittingFunction"] = HogQLFunctionMeta(
"dateEmittingFunction",
1,
1,
signatures=[
((UnknownType(),), DateType()),
],
)
ast = print_ast(
parse_expr("overloadedFunction(dateEmittingFunction('123123'))"),
HogQLContext(self.team.pk, enable_select_queries=True),
"clickhouse",
)
assert "overloadSuccess" in ast

0 comments on commit 7ceac1f

Please sign in to comment.