diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py
index f0367da08ff9b..57b14907f94a3 100644
--- a/posthog/hogql/database/database.py
+++ b/posthog/hogql/database/database.py
@@ -159,7 +159,10 @@ def create_hogql_database(team_id: int, modifiers: Optional[HogQLQueryModifiers]
)
database.events.fields["person_id"] = ExpressionField(
name="person_id",
- expr=parse_expr("ifNull(override.override_person_id, event_person_id)", start=None),
+ expr=parse_expr(
+ "ifNull(nullIf(override.override_person_id, '00000000-0000-0000-0000-000000000000'), event_person_id)",
+ start=None,
+ ),
)
database.events.fields["poe"].fields["id"] = database.events.fields["person_id"]
database.events.fields["person"] = FieldTraverser(chain=["poe"])
diff --git a/posthog/hogql/database/schema/test/test_event_sessions.py b/posthog/hogql/database/schema/test/test_event_sessions.py
index 95752668f3a2f..5973b3530fba1 100644
--- a/posthog/hogql/database/schema/test/test_event_sessions.py
+++ b/posthog/hogql/database/schema/test/test_event_sessions.py
@@ -19,7 +19,7 @@ def setUp(self):
def _select(self, query: str) -> ast.SelectQuery:
select_query = cast(ast.SelectQuery, clone_expr(parse_select(query), clear_locations=True))
- return cast(ast.SelectQuery, resolve_types(select_query, self.context))
+ return cast(ast.SelectQuery, resolve_types(select_query, self.context, dialect="clickhouse"))
def _compare_operators(self, query: ast.SelectQuery, table_name: str) -> List[ast.Expr]:
assert query.where is not None and query.type is not None
@@ -143,7 +143,7 @@ def setUp(self):
def _select(self, query: str) -> ast.SelectQuery:
select_query = cast(ast.SelectQuery, clone_expr(parse_select(query), clear_locations=True))
- return cast(ast.SelectQuery, resolve_types(select_query, self.context))
+ return cast(ast.SelectQuery, resolve_types(select_query, self.context, dialect="clickhouse"))
def _clean(self, table_name: str, query: ast.SelectQuery, expr: ast.Expr) -> ast.Expr:
assert query.type is not None
diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py
index ab588239b012d..74501d11b429f 100644
--- a/posthog/hogql/printer.py
+++ b/posthog/hogql/printer.py
@@ -96,15 +96,15 @@ def prepare_ast_for_printing(
context.database = context.database or create_hogql_database(context.team_id, context.modifiers)
with context.timings.measure("resolve_types"):
- node = resolve_types(node, context, scopes=[node.type for node in stack] if stack else None)
+ node = resolve_types(node, context, dialect=dialect, scopes=[node.type for node in stack] if stack else None)
if context.modifiers.inCohortVia == "leftjoin":
with context.timings.measure("resolve_in_cohorts"):
- resolve_in_cohorts(node, stack, context)
+ resolve_in_cohorts(node, dialect, stack, context)
if dialect == "clickhouse":
with context.timings.measure("resolve_property_types"):
node = resolve_property_types(node, context)
with context.timings.measure("resolve_lazy_tables"):
- resolve_lazy_tables(node, stack, context)
+ resolve_lazy_tables(node, dialect, stack, context)
# We support global query settings, and local subquery settings.
# If the global query is a select query with settings, merge the two.
diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py
index 0f254a1ed39b8..abfd3561e4892 100644
--- a/posthog/hogql/resolver.py
+++ b/posthog/hogql/resolver.py
@@ -1,5 +1,5 @@
from datetime import date, datetime
-from typing import List, Optional, Any, cast
+from typing import List, Optional, Any, cast, Literal
from uuid import UUID
from posthog.hogql import ast
@@ -56,20 +56,27 @@ def resolve_constant_data_type(constant: Any) -> ConstantType:
def resolve_types(
node: ast.Expr,
context: HogQLContext,
+ dialect: Literal["hogql", "clickhouse"],
scopes: Optional[List[ast.SelectQueryType]] = None,
) -> ast.Expr:
- return Resolver(scopes=scopes, context=context).visit(node)
+ return Resolver(scopes=scopes, context=context, dialect=dialect).visit(node)
class Resolver(CloningVisitor):
"""The Resolver visits an AST and 1) resolves all fields, 2) assigns types to nodes, 3) expands all CTEs."""
- def __init__(self, context: HogQLContext, scopes: Optional[List[ast.SelectQueryType]] = None):
+ def __init__(
+ self,
+ context: HogQLContext,
+ dialect: Literal["hogql", "clickhouse"] = "clickhouse",
+ scopes: Optional[List[ast.SelectQueryType]] = None,
+ ):
super().__init__()
# Each SELECT query creates a new scope (type). Store all of them in a list as we traverse the tree.
self.scopes: List[ast.SelectQueryType] = scopes or []
self.current_view_depth: int = 0
self.context = context
+ self.dialect = dialect
self.database = context.database
self.cte_counter = 0
@@ -461,10 +468,12 @@ def visit_field(self, node: ast.Field):
node.type = loop_type
if isinstance(node.type, ast.ExpressionFieldType):
- new_expr = clone_expr(node.type.expr)
- new_node = ast.Alias(alias=node.type.name, expr=new_expr, hidden=True)
- new_node = self.visit(new_node)
- return new_node
+ # only swap out expression fields in ClickHouse
+ if self.dialect == "clickhouse":
+ new_expr = clone_expr(node.type.expr)
+ new_node = ast.Alias(alias=node.type.name, expr=new_expr, hidden=True)
+ new_node = self.visit(new_node)
+ return new_node
if isinstance(node.type, ast.FieldType) and node.start is not None and node.end is not None:
self.context.add_notice(
diff --git a/posthog/hogql/test/test_metadata.py b/posthog/hogql/test/test_metadata.py
index 5b1a7da4e3e62..3bbfba2bb204a 100644
--- a/posthog/hogql/test/test_metadata.py
+++ b/posthog/hogql/test/test_metadata.py
@@ -2,6 +2,7 @@
from posthog.models import PropertyDefinition, Cohort
from posthog.schema import HogQLMetadata, HogQLMetadataResponse
from posthog.test.base import APIBaseTest, ClickhouseTestMixin
+from django.test import override_settings
class TestMetadata(ClickhouseTestMixin, APIBaseTest):
@@ -135,6 +136,7 @@ def test_metadata_table(self):
metadata = self._expr("is_identified", "persons")
self.assertEqual(metadata.isValid, True)
+ @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False)
def test_metadata_in_cohort(self):
cohort = Cohort.objects.create(team=self.team, name="cohort_name")
query = (
diff --git a/posthog/hogql/test/test_modifiers.py b/posthog/hogql/test/test_modifiers.py
index 744ad4f063fdb..9e253c30042af 100644
--- a/posthog/hogql/test/test_modifiers.py
+++ b/posthog/hogql/test/test_modifiers.py
@@ -70,9 +70,9 @@ def test_modifiers_persons_on_events_mode_mapping(self):
(
PersonsOnEventsMode.v2_enabled,
"events.event",
- "ifNull(events__override.override_person_id, events.person_id) AS id",
+ "ifNull(nullIf(events__override.override_person_id, %(hogql_val_0)s), events.person_id) AS id",
"events.person_properties",
- "toTimeZone(events.person_created_at, %(hogql_val_0)s) AS created_at",
+ "toTimeZone(events.person_created_at, %(hogql_val_1)s) AS created_at",
),
]
diff --git a/posthog/hogql/test/test_printer.py b/posthog/hogql/test/test_printer.py
index 9d63479c0f5a2..2195971a317f8 100644
--- a/posthog/hogql/test/test_printer.py
+++ b/posthog/hogql/test/test_printer.py
@@ -774,7 +774,7 @@ def test_select_sample(self):
f"AS persons SAMPLE 0.1 ON equals(persons.id, events__pdi.person_id) WHERE equals(events.team_id, {self.team.pk}) LIMIT 10000",
)
- with override_settings(PERSON_ON_EVENTS_OVERRIDE=True):
+ with override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False):
context = HogQLContext(
team_id=self.team.pk,
enable_select_queries=True,
diff --git a/posthog/hogql/test/test_query.py b/posthog/hogql/test/test_query.py
index f6dc78a7e4108..7f9d9fe65e99e 100644
--- a/posthog/hogql/test/test_query.py
+++ b/posthog/hogql/test/test_query.py
@@ -397,7 +397,7 @@ def test_query_select_person_with_joins_without_poe(self):
self.assertEqual(response.results[0][3], "tim@posthog.com")
@pytest.mark.usefixtures("unittest_snapshot")
- @override_settings(PERSON_ON_EVENTS_OVERRIDE=True)
+ @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False)
def test_query_select_person_with_poe_without_joins(self):
with freeze_time("2020-01-10"):
self._create_random_events()
@@ -473,7 +473,7 @@ def test_prop_cohort_basic(self):
)
self.assertEqual(response.results, [("$pageview", 2)])
- with override_settings(PERSON_ON_EVENTS_OVERRIDE=True):
+ with override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False):
response = execute_hogql_query(
"SELECT event, count(*) FROM events WHERE {cohort_filter} GROUP BY event",
team=self.team,
diff --git a/posthog/hogql/test/test_resolver.py b/posthog/hogql/test/test_resolver.py
index c2fc2324271bb..b7afa076bb362 100644
--- a/posthog/hogql/test/test_resolver.py
+++ b/posthog/hogql/test/test_resolver.py
@@ -47,14 +47,14 @@ def setUp(self):
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_events_table(self):
expr = self._select("SELECT event, events.timestamp FROM events WHERE events.event = 'test'")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
def test_will_not_run_twice(self):
expr = self._select("SELECT event, events.timestamp FROM events WHERE events.event = 'test'")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
with self.assertRaises(ResolverException) as context:
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
self.assertEqual(
str(context.exception),
"Type already resolved for SelectQuery (SelectQueryType). Can't run again.",
@@ -63,19 +63,19 @@ def test_will_not_run_twice(self):
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_events_table_alias(self):
expr = self._select("SELECT event, e.timestamp FROM events e WHERE e.event = 'test'")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_events_table_column_alias(self):
expr = self._select("SELECT event as ee, ee, ee as e, e.timestamp FROM events e WHERE e.event = 'test'")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_events_table_column_alias_inside_subquery(self):
expr = self._select("SELECT b FROM (select event as b, timestamp as c from events) e WHERE e.b = 'test'")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
def test_resolve_subquery_no_field_access(self):
@@ -84,7 +84,7 @@ def test_resolve_subquery_no_field_access(self):
"SELECT event, (select count() from events where event = e.event) as c FROM events e where event = '$pageview'"
)
with self.assertRaises(ResolverException) as e:
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
self.assertEqual(str(e.exception), "Unable to resolve field: e")
@pytest.mark.usefixtures("unittest_snapshot")
@@ -101,13 +101,13 @@ def test_resolve_constant_type(self):
"tuple": ast.Constant(value=(1, 2, 3)),
},
)
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_boolean_operation_types(self):
expr = self._select("SELECT 1 and 1, 1 or 1, not true")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
def test_resolve_errors(self):
@@ -120,55 +120,55 @@ def test_resolve_errors(self):
]
for query in queries:
with self.assertRaises(ResolverException) as e:
- resolve_types(self._select(query), self.context)
+ resolve_types(self._select(query), self.context, dialect="clickhouse")
self.assertIn("Unable to resolve field:", str(e.exception))
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_lazy_pdi_person_table(self):
expr = self._select("select distinct_id, person.id from person_distinct_ids")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_lazy_events_pdi_table(self):
expr = self._select("select event, pdi.person_id from events")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_lazy_events_pdi_table_aliased(self):
expr = self._select("select event, e.pdi.person_id from events e")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_lazy_events_pdi_person_table(self):
expr = self._select("select event, pdi.person.id from events")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_lazy_events_pdi_person_table_aliased(self):
expr = self._select("select event, e.pdi.person.id from events e")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_virtual_events_poe(self):
expr = self._select("select event, poe.id from events")
- expr = resolve_types(expr, self.context)
+ expr = resolve_types(expr, self.context, dialect="clickhouse")
assert pretty_dataclasses(expr) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_union_all(self):
node = self._select("select event, timestamp from events union all select event, timestamp from events")
- node = resolve_types(node, self.context)
+ node = resolve_types(node, self.context, dialect="clickhouse")
assert pretty_dataclasses(node) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_call_type(self):
node = self._select("select max(timestamp) from events")
- node = resolve_types(node, self.context)
+ node = resolve_types(node, self.context, dialect="clickhouse")
assert pretty_dataclasses(node) == self.snapshot
def test_ctes_loop(self):
@@ -235,7 +235,7 @@ def test_ctes_subquery_recursion(self):
def test_asterisk_expander_table(self):
self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False
node = self._select("select * from events")
- node = resolve_types(node, self.context)
+ node = resolve_types(node, self.context, dialect="clickhouse")
assert pretty_dataclasses(node) == self.snapshot
@override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False)
@@ -243,19 +243,19 @@ def test_asterisk_expander_table(self):
def test_asterisk_expander_table_alias(self):
self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False
node = self._select("select * from events e")
- node = resolve_types(node, self.context)
+ node = resolve_types(node, self.context, dialect="clickhouse")
assert pretty_dataclasses(node) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_asterisk_expander_subquery(self):
node = self._select("select * from (select 1 as a, 2 as b)")
- node = resolve_types(node, self.context)
+ node = resolve_types(node, self.context, dialect="clickhouse")
assert pretty_dataclasses(node) == self.snapshot
@pytest.mark.usefixtures("unittest_snapshot")
def test_asterisk_expander_subquery_alias(self):
node = self._select("select x.* from (select 1 as a, 2 as b) x")
- node = resolve_types(node, self.context)
+ node = resolve_types(node, self.context, dialect="clickhouse")
assert pretty_dataclasses(node) == self.snapshot
@override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False)
@@ -263,13 +263,13 @@ def test_asterisk_expander_subquery_alias(self):
def test_asterisk_expander_from_subquery_table(self):
self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False
node = self._select("select * from (select * from events)")
- node = resolve_types(node, self.context)
+ node = resolve_types(node, self.context, dialect="clickhouse")
assert pretty_dataclasses(node) == self.snapshot
def test_asterisk_expander_multiple_table_error(self):
node = self._select("select * from (select 1 as a, 2 as b) x left join (select 1 as a, 2 as b) y on x.a = y.a")
with self.assertRaises(ResolverException) as e:
- resolve_types(node, self.context)
+ resolve_types(node, self.context, dialect="clickhouse")
self.assertEqual(
str(e.exception),
"Cannot use '*' without table name when there are multiple tables in the query",
@@ -280,13 +280,13 @@ def test_asterisk_expander_multiple_table_error(self):
def test_asterisk_expander_select_union(self):
self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False
node = self._select("select * from (select * from events union all select * from events)")
- node = resolve_types(node, self.context)
+ node = resolve_types(node, self.context, dialect="clickhouse")
assert pretty_dataclasses(node) == self.snapshot
def test_lambda_parent_scope(self):
# does not raise
node = self._select("select timestamp, arrayMap(x -> x + timestamp, [2]) from events")
- node = cast(ast.SelectQuery, resolve_types(node, self.context))
+ node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse"))
# found a type
lambda_type: ast.SelectQueryType = cast(ast.SelectQueryType, cast(ast.Call, node.select[1]).args[0].type)
@@ -304,7 +304,7 @@ def test_field_traverser_double_dot(self):
self.database.events.fields["poe"].fields["properties"] = StringJSONDatabaseField(name="person_properties")
node = self._select("SELECT event, person.id, person.properties, person.created_at FROM events")
- node = cast(ast.SelectQuery, resolve_types(node, self.context))
+ node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse"))
# all columns resolve to a type in the end
assert cast(ast.FieldType, node.select[0].type).resolve_database_field() == StringDatabaseField(
@@ -322,7 +322,7 @@ def test_field_traverser_double_dot(self):
def test_visit_hogqlx_tag(self):
node = self._select("select event from ")
- node = cast(ast.SelectQuery, resolve_types(node, self.context))
+ node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse"))
table_node = cast(ast.SelectQuery, node).select_from.table
expected = ast.SelectQuery(
select=[ast.Alias(hidden=True, alias="event", expr=ast.Field(chain=["event"]))],
@@ -332,7 +332,7 @@ def test_visit_hogqlx_tag(self):
def test_visit_hogqlx_tag_alias(self):
node = self._select("select event from a")
- node = cast(ast.SelectQuery, resolve_types(node, self.context))
+ node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse"))
assert cast(ast.SelectQuery, node).select_from.alias == "a"
def test_visit_hogqlx_tag_source(self):
@@ -346,7 +346,7 @@ def test_visit_hogqlx_tag_source(self):
/>
)
"""
- node = cast(ast.SelectQuery, resolve_types(self._select(query), self.context))
+ node = cast(ast.SelectQuery, resolve_types(self._select(query), self.context, dialect="hogql"))
hogql = print_prepared_ast(node, HogQLContext(team_id=self.team.pk, enable_select_queries=True), "hogql")
expected = (
f"SELECT id, email FROM "
diff --git a/posthog/hogql/transforms/in_cohort.py b/posthog/hogql/transforms/in_cohort.py
index 670d0a8e73c2a..a565391e309f3 100644
--- a/posthog/hogql/transforms/in_cohort.py
+++ b/posthog/hogql/transforms/in_cohort.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, cast
+from typing import List, Optional, cast, Literal
from posthog.hogql import ast
from posthog.hogql.context import HogQLContext
@@ -11,21 +11,24 @@
def resolve_in_cohorts(
node: ast.Expr,
+ dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[ast.SelectQuery]] = None,
context: HogQLContext = None,
):
- InCohortResolver(stack=stack, context=context).visit(node)
+ InCohortResolver(stack=stack, dialect=dialect, context=context).visit(node)
class InCohortResolver(TraversingVisitor):
def __init__(
self,
+ dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[ast.SelectQuery]] = None,
context: HogQLContext = None,
):
super().__init__()
self.stack: List[ast.SelectQuery] = stack or []
self.context = context
+ self.dialect = dialect
def visit_select_query(self, node: ast.SelectQuery):
self.stack.append(node)
@@ -130,11 +133,12 @@ def _add_join_for_cohort(
)
new_join = cast(
ast.JoinExpr,
- resolve_types(new_join, self.context, [self.stack[-1].type]),
+ resolve_types(new_join, self.context, self.dialect, [self.stack[-1].type]),
)
new_join.constraint.expr.left = resolve_types(
ast.Field(chain=[f"in_cohort__{cohort_id}", "person_id"]),
self.context,
+ self.dialect,
[self.stack[-1].type],
)
new_join.constraint.expr.right = clone_expr(compare.left)
@@ -147,6 +151,7 @@ def _add_join_for_cohort(
compare.left = resolve_types(
ast.Field(chain=[f"in_cohort__{cohort_id}", "matched"]),
self.context,
+ self.dialect,
[self.stack[-1].type],
)
- compare.right = resolve_types(ast.Constant(value=1), self.context, [self.stack[-1].type])
+ compare.right = resolve_types(ast.Constant(value=1), self.context, self.dialect, [self.stack[-1].type])
diff --git a/posthog/hogql/transforms/lazy_tables.py b/posthog/hogql/transforms/lazy_tables.py
index b2a9a7d12bf4d..4734fed012c91 100644
--- a/posthog/hogql/transforms/lazy_tables.py
+++ b/posthog/hogql/transforms/lazy_tables.py
@@ -1,5 +1,5 @@
import dataclasses
-from typing import Dict, List, Optional, cast
+from typing import Dict, List, Optional, cast, Literal
from posthog.hogql import ast
from posthog.hogql.context import HogQLContext
@@ -12,10 +12,11 @@
def resolve_lazy_tables(
node: ast.Expr,
+ dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[ast.SelectQuery]] = None,
context: HogQLContext = None,
):
- LazyTableResolver(stack=stack, context=context).visit(node)
+ LazyTableResolver(stack=stack, context=context, dialect=dialect).visit(node)
@dataclasses.dataclass
@@ -35,12 +36,14 @@ class TableToAdd:
class LazyTableResolver(TraversingVisitor):
def __init__(
self,
+ dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[ast.SelectQuery]] = None,
context: HogQLContext = None,
):
super().__init__()
self.stack_of_fields: List[List[ast.FieldType | ast.PropertyType]] = [[]] if stack else []
self.context = context
+ self.dialect = dialect
def visit_property_type(self, node: ast.PropertyType):
if node.joined_subquery is not None:
@@ -181,7 +184,7 @@ def visit_select_query(self, node: ast.SelectQuery):
for table_name, table_to_add in tables_to_add.items():
subquery = table_to_add.lazy_table.lazy_select(table_to_add.fields_accessed, self.context.modifiers)
subquery = cast(ast.SelectQuery, clone_expr(subquery, clear_locations=True))
- subquery = cast(ast.SelectQuery, resolve_types(subquery, self.context, [node.type]))
+ subquery = cast(ast.SelectQuery, resolve_types(subquery, self.context, self.dialect, [node.type]))
old_table_type = select_type.tables[table_name]
select_type.tables[table_name] = ast.SelectQueryAliasType(alias=table_name, select_query_type=subquery.type)
@@ -204,7 +207,7 @@ def visit_select_query(self, node: ast.SelectQuery):
node,
)
join_to_add = cast(ast.JoinExpr, clone_expr(join_to_add, clear_locations=True))
- join_to_add = cast(ast.JoinExpr, resolve_types(join_to_add, self.context, [node.type]))
+ join_to_add = cast(ast.JoinExpr, resolve_types(join_to_add, self.context, self.dialect, [node.type]))
select_type.tables[to_table] = join_to_add.type
diff --git a/posthog/warehouse/models/datawarehouse_saved_query.py b/posthog/warehouse/models/datawarehouse_saved_query.py
index 9117fa7c4eaf0..f52ee05cb6926 100644
--- a/posthog/warehouse/models/datawarehouse_saved_query.py
+++ b/posthog/warehouse/models/datawarehouse_saved_query.py
@@ -71,7 +71,7 @@ def s3_tables(self):
node = parse_select(self.query["query"])
context.database = create_hogql_database(context.team_id)
- node = resolve_types(node, context)
+ node = resolve_types(node, context, dialect="clickhouse")
table_collector = S3TableVisitor()
table_collector.visit(node)