Skip to content

Commit

Permalink
feat: Add IntervalType to HogQL types, extend type signatures for t…
Browse files Browse the repository at this point in the history
…ime arithmetic and comparison (#25640)
  • Loading branch information
tkaemming authored Oct 23, 2024
1 parent 3d5b183 commit 23f3b19
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 57 deletions.
8 changes: 8 additions & 0 deletions posthog/hogql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,14 @@ def print_type(self) -> str:
return "DateTime"


@dataclass(kw_only=True)
class IntervalType(ConstantType):
    """HogQL constant type representing an interval value."""

    # Fixed to "unknown" and excluded from the generated __init__ (init=False).
    data_type: ConstantDataType = field(init=False, default="unknown")

    def print_type(self) -> str:
        """Return the printed name of this type, ``"IntervalType"``."""
        return "IntervalType"


@dataclass(kw_only=True)
class UUIDType(ConstantType):
data_type: ConstantDataType = field(default="uuid", init=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def test_select_with_timestamp(self):
FROM
sessions
WHERE
and(equals(sessions.team_id, <TEAM_ID>), ifNull(greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_1)s), toIntervalDay(3)), %(hogql_val_2)s), 0))
and(equals(sessions.team_id, <TEAM_ID>), greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_1)s), toIntervalDay(3)), %(hogql_val_2)s))
GROUP BY
sessions.session_id,
sessions.session_id) AS sessions
Expand Down Expand Up @@ -379,7 +379,7 @@ def test_join_with_events(self):
FROM
sessions
WHERE
and(equals(sessions.team_id, <TEAM_ID>), ifNull(greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_0)s), toIntervalDay(3)), %(hogql_val_1)s), 0))
and(equals(sessions.team_id, <TEAM_ID>), greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_0)s), toIntervalDay(3)), %(hogql_val_1)s))
GROUP BY
sessions.session_id,
sessions.session_id) AS sessions ON equals(events.`$session_id`, sessions.session_id)
Expand Down Expand Up @@ -495,7 +495,7 @@ def test_session_breakdown(self):
FROM
sessions
WHERE
and(equals(sessions.team_id, <TEAM_ID>), ifNull(greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_3)s), toIntervalDay(3)), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_4)s, 6, %(hogql_val_5)s)))), 0), ifNull(lessOrEquals(minus(toTimeZone(sessions.min_timestamp, %(hogql_val_6)s), toIntervalDay(3)), assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_7)s, 6, %(hogql_val_8)s))), 0))
and(equals(sessions.team_id, <TEAM_ID>), greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_3)s), toIntervalDay(3)), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_4)s, 6, %(hogql_val_5)s)))), lessOrEquals(minus(toTimeZone(sessions.min_timestamp, %(hogql_val_6)s), toIntervalDay(3)), assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_7)s, 6, %(hogql_val_8)s))))
GROUP BY
sessions.session_id,
sessions.session_id) AS e__session ON equals(e.`$session_id`, e__session.session_id)
Expand Down Expand Up @@ -537,7 +537,7 @@ def test_session_replay_query(self):
FROM
sessions
WHERE
and(equals(sessions.team_id, <TEAM_ID>), ifNull(greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_2)s), toIntervalDay(3)), %(hogql_val_3)s), 0), ifNull(lessOrEquals(minus(toTimeZone(sessions.min_timestamp, %(hogql_val_4)s), toIntervalDay(3)), now64(6, %(hogql_val_5)s)), 0))
and(equals(sessions.team_id, <TEAM_ID>), greaterOrEquals(plus(toTimeZone(sessions.min_timestamp, %(hogql_val_2)s), toIntervalDay(3)), %(hogql_val_3)s), lessOrEquals(minus(toTimeZone(sessions.min_timestamp, %(hogql_val_4)s), toIntervalDay(3)), now64(6, %(hogql_val_5)s)))
GROUP BY
sessions.session_id,
sessions.session_id) AS s__session ON equals(s.session_id, s__session.session_id)
Expand Down
87 changes: 78 additions & 9 deletions posthog/hogql/functions/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DateTimeType,
DateType,
FloatType,
IntervalType,
StringType,
TupleType,
IntegerType,
Expand Down Expand Up @@ -57,6 +58,7 @@ def validate_function_args(
| UnknownType
| IntegerType
| FloatType
| IntervalType
)


Expand Down Expand Up @@ -124,6 +126,8 @@ def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantTy
),
((DateTimeType(), IntegerType()), DateTimeType()),
((IntegerType(), DateTimeType()), DateTimeType()),
((DateTimeType(), IntervalType()), DateTimeType()),
((IntervalType(), DateTimeType()), DateTimeType()),
],
),
"minus": HogQLFunctionMeta(
Expand All @@ -143,6 +147,8 @@ def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantTy
),
((DateTimeType(), IntegerType()), DateTimeType()),
((IntegerType(), DateTimeType()), DateTimeType()),
((DateTimeType(), IntervalType()), DateTimeType()),
((IntervalType(), DateTimeType()), DateTimeType()),
],
),
"multiply": HogQLFunctionMeta(
Expand Down Expand Up @@ -565,14 +571,70 @@ def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantTy
),
"toModifiedJulianDay": HogQLFunctionMeta("toModifiedJulianDayOrNull", 1, 1),
"fromModifiedJulianDay": HogQLFunctionMeta("fromModifiedJulianDayOrNull", 1, 1),
"toIntervalSecond": HogQLFunctionMeta("toIntervalSecond", 1, 1),
"toIntervalMinute": HogQLFunctionMeta("toIntervalMinute", 1, 1),
"toIntervalHour": HogQLFunctionMeta("toIntervalHour", 1, 1),
"toIntervalDay": HogQLFunctionMeta("toIntervalDay", 1, 1),
"toIntervalWeek": HogQLFunctionMeta("toIntervalWeek", 1, 1),
"toIntervalMonth": HogQLFunctionMeta("toIntervalMonth", 1, 1),
"toIntervalQuarter": HogQLFunctionMeta("toIntervalQuarter", 1, 1),
"toIntervalYear": HogQLFunctionMeta("toIntervalYear", 1, 1),
"toIntervalSecond": HogQLFunctionMeta(
"toIntervalSecond",
1,
1,
signatures=[
((IntegerType(),), IntervalType()),
],
),
"toIntervalMinute": HogQLFunctionMeta(
"toIntervalMinute",
1,
1,
signatures=[
((IntegerType(),), IntervalType()),
],
),
"toIntervalHour": HogQLFunctionMeta(
"toIntervalHour",
1,
1,
signatures=[
((IntegerType(),), IntervalType()),
],
),
"toIntervalDay": HogQLFunctionMeta(
"toIntervalDay",
1,
1,
signatures=[
((IntegerType(),), IntervalType()),
],
),
"toIntervalWeek": HogQLFunctionMeta(
"toIntervalWeek",
1,
1,
signatures=[
((IntegerType(),), IntervalType()),
],
),
"toIntervalMonth": HogQLFunctionMeta(
"toIntervalMonth",
1,
1,
signatures=[
((IntegerType(),), IntervalType()),
],
),
"toIntervalQuarter": HogQLFunctionMeta(
"toIntervalQuarter",
1,
1,
signatures=[
((IntegerType(),), IntervalType()),
],
),
"toIntervalYear": HogQLFunctionMeta(
"toIntervalYear",
1,
1,
signatures=[
((IntegerType(),), IntervalType()),
],
),
# strings
"left": HogQLFunctionMeta("left", 2, 2),
"right": HogQLFunctionMeta("right", 2, 2),
Expand Down Expand Up @@ -830,7 +892,14 @@ def compare_types(arg_types: list[ConstantType], sig_arg_types: tuple[ConstantTy
"coalesce": HogQLFunctionMeta("coalesce", 1, None, case_sensitive=False),
"ifnull": HogQLFunctionMeta("ifNull", 2, 2, case_sensitive=False),
"nullif": HogQLFunctionMeta("nullIf", 2, 2, case_sensitive=False),
"assumeNotNull": HogQLFunctionMeta("assumeNotNull", 1, 1),
"assumeNotNull": HogQLFunctionMeta(
"assumeNotNull",
1,
1,
signatures=[
((DateTimeType(),), DateTimeType()),
],
),
"toNullable": HogQLFunctionMeta("toNullable", 1, 1),
# tuples
"tuple": HogQLFunctionMeta("tuple", 0, None),
Expand Down
4 changes: 4 additions & 0 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,10 @@ def _is_nullable(self, node: ast.Expr) -> bool:
return node.value is None
elif isinstance(node.type, ast.PropertyType):
return True
elif isinstance(node.type, ast.ConstantType):
return node.type.nullable
elif isinstance(node.type, ast.CallType):
return node.type.return_type.nullable
elif isinstance(node.type, ast.FieldType):
return node.type.is_nullable(self.context)
elif isinstance(node, ast.Alias):
Expand Down
22 changes: 2 additions & 20 deletions posthog/hogql/test/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,9 +1359,6 @@ def test_field_nullable_equals(self):
generated_sql_statements1 = self._select(
"SELECT "
"start_time = toStartOfMonth(now()), "
"now() = now(), "
"1 = now(), "
"now() = 1, "
"1 = 1, "
"click_count = 1, "
"1 = click_count, "
Expand All @@ -1373,9 +1370,6 @@ def test_field_nullable_equals(self):
generated_sql_statements2 = self._select(
"SELECT "
"equals(start_time, toStartOfMonth(now())), "
"equals(now(), now()), "
"equals(1, now()), "
"equals(now(), 1), "
"equals(1, 1), "
"equals(click_count, 1), "
"equals(1, click_count), "
Expand All @@ -1391,12 +1385,6 @@ def test_field_nullable_equals(self):
# (the return of toStartOfMonth() is treated as "potentially nullable" since we don't yet have full typing support)
f"ifNull(equals(session_replay_events.start_time, toStartOfMonth(now64(6, %(hogql_val_1)s))), "
f"isNull(session_replay_events.start_time) and isNull(toStartOfMonth(now64(6, %(hogql_val_1)s)))), "
# now() = now() (also two nullable fields)
f"ifNull(equals(now64(6, %(hogql_val_2)s), now64(6, %(hogql_val_3)s)), isNull(now64(6, %(hogql_val_2)s)) and isNull(now64(6, %(hogql_val_3)s))), "
# 1 = now()
f"ifNull(equals(1, now64(6, %(hogql_val_4)s)), 0), "
# now() = 1
f"ifNull(equals(now64(6, %(hogql_val_5)s), 1), 0), "
# 1 = 1
f"1, "
# click_count = 1
Expand All @@ -1415,12 +1403,12 @@ def test_field_nullable_equals(self):

def test_field_nullable_not_equals(self):
generated_sql1 = self._select(
"SELECT start_time != toStartOfMonth(now()), now() != now(), 1 != now(), now() != 1, 1 != 1, "
"SELECT start_time != toStartOfMonth(now()), 1 != 1, "
"click_count != 1, 1 != click_count, click_count != keypress_count, click_count != null, null != click_count "
"FROM session_replay_events"
)
generated_sql2 = self._select(
"SELECT notEquals(start_time, toStartOfMonth(now())), notEquals(now(), now()), notEquals(1, now()), notEquals(now(), 1), notEquals(1, 1), "
"SELECT notEquals(start_time, toStartOfMonth(now())), notEquals(1, 1), "
"notEquals(click_count, 1), notEquals(1, click_count), notEquals(click_count, keypress_count), notEquals(click_count, null), notEquals(null, click_count) "
"FROM session_replay_events"
)
Expand All @@ -1431,12 +1419,6 @@ def test_field_nullable_not_equals(self):
# (the return of toStartOfMonth() is treated as "potentially nullable" since we don't yet have full typing support)
f"ifNull(notEquals(session_replay_events.start_time, toStartOfMonth(now64(6, %(hogql_val_1)s))), "
f"isNotNull(session_replay_events.start_time) or isNotNull(toStartOfMonth(now64(6, %(hogql_val_1)s)))), "
# now() = now() (also two nullable fields)
f"ifNull(notEquals(now64(6, %(hogql_val_2)s), now64(6, %(hogql_val_3)s)), isNotNull(now64(6, %(hogql_val_2)s)) or isNotNull(now64(6, %(hogql_val_3)s))), "
# 1 = now()
f"ifNull(notEquals(1, now64(6, %(hogql_val_4)s)), 1), "
# now() = 1
f"ifNull(notEquals(now64(6, %(hogql_val_5)s), 1), 1), "
# 1 = 1
f"0, "
# click_count = 1
Expand Down
23 changes: 23 additions & 0 deletions posthog/hogql/test/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,29 @@ def test_function_types(self):
node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse"))
self._assert_first_columm_is_type(node, ast.IntegerType(nullable=False))

def test_assume_not_null_type(self):
    """assumeNotNull() over a DateTime expression resolves to a non-nullable DateTime."""
    query = self._select("SELECT assumeNotNull(toDateTime('2020-01-01 00:00:00'))")
    resolved = cast(ast.SelectQuery, resolve_types(query, self.context, dialect="clickhouse"))

    # Single-element destructuring also fails if the select list is not exactly one column.
    [column] = resolved.select
    assert isinstance(column.type, ast.CallType)
    assert column.type.return_type == ast.DateTimeType(nullable=False)

def test_interval_type_arithmetic(self):
    """`timestamp` plus or minus any toInterval*() value resolves to a non-nullable DateTime."""
    operators = ["+", "-"]
    granularities = ["Second", "Minute", "Hour", "Day", "Week", "Month", "Quarter", "Year"]
    # One expression per (granularity, operator) pair; granularity varies slowest.
    exprs = [
        f"timestamp {operator} toInterval{granularity}(1)"
        for granularity in granularities
        for operator in operators
    ]

    select_list = ",".join(exprs)
    query = self._select(f"SELECT {select_list} FROM events")
    resolved = cast(ast.SelectQuery, resolve_types(query, self.context, dialect="clickhouse"))

    # Every generated expression must come back as its own selected column.
    assert len(resolved.select) == len(exprs)
    for column in resolved.select:
        assert column.type == ast.DateTimeType(nullable=False)

def test_recording_button_tag(self):
node: ast.SelectQuery = self._select("select <RecordingButton sessionId={'12345'} />")
node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse"))
Expand Down
Loading

0 comments on commit 23f3b19

Please sign in to comment.