Skip to content

Commit

Permalink
expression fields
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra committed Nov 20, 2023
1 parent 9bcb5c3 commit e944e91
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 175 deletions.
12 changes: 12 additions & 0 deletions posthog/hogql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
FieldOrTable,
DatabaseField,
StringArrayDatabaseField,
ExpressionField,
)
from posthog.hogql.errors import HogQLException, NotImplementedException

Expand Down Expand Up @@ -60,6 +61,8 @@ def get_child(self, name: str) -> Type:
return FieldTraverserType(table_type=self, chain=field.chain)
if isinstance(field, VirtualTable):
return VirtualTableType(table_type=self, field=name, virtual_table=field)
if isinstance(field, ExpressionField):
return ExpressionFieldType(table_type=self, name=name, expr=field.expr)
return FieldType(name=name, table_type=self)
raise HogQLException(f"Field not found: {name}")

Expand Down Expand Up @@ -121,6 +124,8 @@ class SelectQueryType(Type):

# all aliases a select query has access to in its scope
aliases: Dict[str, FieldAliasType] = field(default_factory=dict)
# # fields that may be converted to expressions
# expression_fields: Dict[str, "ExpressionFieldType"] = field(default_factory=dict)
# all types a select query exports
columns: Dict[str, Type] = field(default_factory=dict)
# all from and join, tables and subqueries with aliases
Expand Down Expand Up @@ -274,6 +279,13 @@ class FieldTraverserType(Type):
table_type: TableOrSelectType


@dataclass(kw_only=True)
class ExpressionFieldType(Type):
name: str
expr: Expr
table_type: TableOrSelectType


@dataclass(kw_only=True)
class FieldType(Type):
name: str
Expand Down
22 changes: 17 additions & 5 deletions posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DateDatabaseField,
FloatDatabaseField,
FunctionCallTable,
ExpressionField,
)
from posthog.hogql.database.schema.log_entries import (
LogEntriesTable,
Expand All @@ -35,13 +36,15 @@
from posthog.hogql.database.schema.person_overrides import (
PersonOverridesTable,
RawPersonOverridesTable,
join_with_person_overrides_table,
)
from posthog.hogql.database.schema.session_replay_events import (
RawSessionReplayEventsTable,
SessionReplayEventsTable,
)
from posthog.hogql.database.schema.static_cohort_people import StaticCohortPeople
from posthog.hogql.errors import HogQLException
from posthog.hogql.parser import parse_expr
from posthog.models.group_type_mapping import GroupTypeMapping
from posthog.models.team.team import WeekStartDay
from posthog.schema import HogQLQueryModifiers, PersonsOnEventsMode
Expand Down Expand Up @@ -143,14 +146,23 @@ def create_hogql_database(team_id: int, modifiers: Optional[HogQLQueryModifiers]
database.events.fields["poe"].fields["created_at"] = FieldTraverser(chain=["..", "pdi", "person", "created_at"])
database.events.fields["poe"].fields["properties"] = StringJSONDatabaseField(name="person_properties")

elif (
modifiers.personsOnEventsMode == PersonsOnEventsMode.v1_enabled
or modifiers.personsOnEventsMode == PersonsOnEventsMode.v2_enabled
):
# TODO: split PoE v1 and v2 once SQL Expression fields are supported #15180
elif modifiers.personsOnEventsMode == PersonsOnEventsMode.v1_enabled:
database.events.fields["person"] = FieldTraverser(chain=["poe"])
database.events.fields["person_id"] = StringDatabaseField(name="person_id")

elif modifiers.personsOnEventsMode == PersonsOnEventsMode.v2_enabled:
database.events.fields["old_person_id"] = StringDatabaseField(name="person_id")
database.events.fields["override"] = LazyJoin(
from_field="old_person_id",
join_table=PersonOverridesTable(),
join_function=join_with_person_overrides_table,
)
database.events.fields["person_id"] = ExpressionField(
name="person_id",
expr=parse_expr("override_person_id != null ? override_person_id : old_person_id", start=None),
)
database.events.fields["person"] = FieldTraverser(chain=["poe"])

for mapping in GroupTypeMapping.objects.filter(team=team):
if database.events.fields.get(mapping.group_type) is None:
database.events.fields[mapping.group_type] = FieldTraverser(chain=[f"group_{mapping.group_type_index}"])
Expand Down
5 changes: 5 additions & 0 deletions posthog/hogql/database/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from pydantic import ConfigDict, BaseModel

from posthog.hogql.base import Expr
from posthog.hogql.errors import HogQLException, NotImplementedException
from posthog.schema import HogQLQueryModifiers

Expand Down Expand Up @@ -57,6 +58,10 @@ class BooleanDatabaseField(DatabaseField):
pass


class ExpressionField(DatabaseField):
expr: Expr


class FieldTraverser(FieldOrTable):
model_config = ConfigDict(extra="forbid")

Expand Down
32 changes: 31 additions & 1 deletion posthog/hogql/database/test/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from parameterized import parameterized

from posthog.hogql.database.database import create_hogql_database, serialize_database
from posthog.hogql.database.models import FieldTraverser, StringDatabaseField
from posthog.hogql.database.models import FieldTraverser, StringDatabaseField, ExpressionField
from posthog.hogql.modifiers import create_default_modifiers_for_team
from posthog.hogql.parser import parse_expr, parse_select
from posthog.hogql.printer import print_ast
from posthog.hogql.context import HogQLContext
from posthog.models.group_type_mapping import GroupTypeMapping
from posthog.test.base import BaseTest
from posthog.warehouse.models import DataWarehouseTable, DataWarehouseCredential
Expand Down Expand Up @@ -80,3 +84,29 @@ def test_database_group_type_mappings_overwrite(self):
db = create_hogql_database(team_id=self.team.pk)

assert db.events.fields["event"] == StringDatabaseField(name="event")

def test_database_expression_fields(self):
db = create_hogql_database(team_id=self.team.pk)
db.numbers.fields["expression"] = ExpressionField(name="expression", expr=parse_expr("1 + 1"))
db.numbers.fields["double"] = ExpressionField(name="double", expr=parse_expr("number * 2"))
context = HogQLContext(
team_id=self.team.pk,
enable_select_queries=True,
database=db,
modifiers=create_default_modifiers_for_team(self.team),
)

sql = "select number, double, expression + number from numbers(2)"
query = print_ast(parse_select(sql), context, dialect="clickhouse")
assert (
query
== "SELECT numbers.number, multiply(numbers.number, 2), plus(plus(1, 1), numbers.number) FROM numbers(2) AS numbers LIMIT 10000"
), query

# sql = "select double from (select double from numbers(2))"
# query = print_ast(parse_select(sql), context, dialect="clickhouse")
# assert query == "SELECT numbers.number, multiply(numbers.number, 2), plus(plus(1, 1), numbers.number) FROM numbers(2) AS numbers LIMIT 10000", query
#
# sql = "select double from (select * from numbers(2))"
# query = print_ast(parse_select(sql), context, dialect="clickhouse")
# assert query == "SELECT numbers.number, multiply(numbers.number, 2), plus(plus(1, 1), numbers.number) FROM numbers(2) AS numbers LIMIT 10000", query
6 changes: 5 additions & 1 deletion posthog/hogql/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def visit_select_query(self, node: ast.SelectQuery):
self._expand_asterisk_columns(new_node, new_expr.type)
continue

# not an asterisk
# Any alias we can use to refer to this field?
alias = None
if isinstance(new_expr.type, ast.FieldAliasType):
alias = new_expr.type.alias
elif isinstance(new_expr.type, ast.FieldType):
Expand Down Expand Up @@ -466,6 +467,9 @@ def visit_field(self, node: ast.Field):
raise ResolverException(f"Cannot resolve type {'.'.join(node.chain)}. Unable to resolve {next_chain}.")
node.type = loop_type

if isinstance(node.type, ast.ExpressionFieldType):
return self.visit(ast.Alias(alias=node.type.name, expr=node.type.expr, hidden=True))

if isinstance(node.type, ast.FieldType) and node.start is not None and node.end is not None:
self.context.add_notice(
start=node.start,
Expand Down
12 changes: 6 additions & 6 deletions posthog/hogql/test/__snapshots__/test_printer.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
SELECT
groupArray(start_of_period) AS date,
groupArray(counts) AS total,
status
status AS status
FROM
(SELECT
if(equals(status, 'dormant'), negate(sum(counts)), negate(negate(sum(counts)))) AS counts,
start_of_period,
status
start_of_period AS start_of_period,
status AS status
FROM
(SELECT
periods.start_of_period AS start_of_period,
0 AS counts,
status
status AS status
FROM
(SELECT
minus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(number)) AS start_of_period
Expand All @@ -29,9 +29,9 @@
start_of_period ASC
UNION ALL
SELECT
start_of_period,
start_of_period AS start_of_period,
count(DISTINCT person_id) AS counts,
status
status AS status
FROM
(SELECT
events.person.id AS person_id,
Expand Down
Loading

0 comments on commit e944e91

Please sign in to comment.