diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py index f0367da08ff9b..57b14907f94a3 100644 --- a/posthog/hogql/database/database.py +++ b/posthog/hogql/database/database.py @@ -159,7 +159,10 @@ def create_hogql_database(team_id: int, modifiers: Optional[HogQLQueryModifiers] ) database.events.fields["person_id"] = ExpressionField( name="person_id", - expr=parse_expr("ifNull(override.override_person_id, event_person_id)", start=None), + expr=parse_expr( + "ifNull(nullIf(override.override_person_id, '00000000-0000-0000-0000-000000000000'), event_person_id)", + start=None, + ), ) database.events.fields["poe"].fields["id"] = database.events.fields["person_id"] database.events.fields["person"] = FieldTraverser(chain=["poe"]) diff --git a/posthog/hogql/database/schema/test/test_event_sessions.py b/posthog/hogql/database/schema/test/test_event_sessions.py index 95752668f3a2f..5973b3530fba1 100644 --- a/posthog/hogql/database/schema/test/test_event_sessions.py +++ b/posthog/hogql/database/schema/test/test_event_sessions.py @@ -19,7 +19,7 @@ def setUp(self): def _select(self, query: str) -> ast.SelectQuery: select_query = cast(ast.SelectQuery, clone_expr(parse_select(query), clear_locations=True)) - return cast(ast.SelectQuery, resolve_types(select_query, self.context)) + return cast(ast.SelectQuery, resolve_types(select_query, self.context, dialect="clickhouse")) def _compare_operators(self, query: ast.SelectQuery, table_name: str) -> List[ast.Expr]: assert query.where is not None and query.type is not None @@ -143,7 +143,7 @@ def setUp(self): def _select(self, query: str) -> ast.SelectQuery: select_query = cast(ast.SelectQuery, clone_expr(parse_select(query), clear_locations=True)) - return cast(ast.SelectQuery, resolve_types(select_query, self.context)) + return cast(ast.SelectQuery, resolve_types(select_query, self.context, dialect="clickhouse")) def _clean(self, table_name: str, query: ast.SelectQuery, expr: ast.Expr) -> ast.Expr: assert query.type is not None diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index ab588239b012d..74501d11b429f 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -96,15 +96,15 @@ def prepare_ast_for_printing( context.database = context.database or create_hogql_database(context.team_id, context.modifiers) with context.timings.measure("resolve_types"): - node = resolve_types(node, context, scopes=[node.type for node in stack] if stack else None) + node = resolve_types(node, context, dialect=dialect, scopes=[node.type for node in stack] if stack else None) if context.modifiers.inCohortVia == "leftjoin": with context.timings.measure("resolve_in_cohorts"): - resolve_in_cohorts(node, stack, context) + resolve_in_cohorts(node, dialect, stack, context) if dialect == "clickhouse": with context.timings.measure("resolve_property_types"): node = resolve_property_types(node, context) with context.timings.measure("resolve_lazy_tables"): - resolve_lazy_tables(node, stack, context) + resolve_lazy_tables(node, dialect, stack, context) # We support global query settings, and local subquery settings. # If the global query is a select query with settings, merge the two. diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py index 0f254a1ed39b8..abfd3561e4892 100644 --- a/posthog/hogql/resolver.py +++ b/posthog/hogql/resolver.py @@ -1,5 +1,5 @@ from datetime import date, datetime -from typing import List, Optional, Any, cast +from typing import List, Optional, Any, cast, Literal from uuid import UUID from posthog.hogql import ast @@ -56,20 +56,27 @@ def resolve_constant_data_type(constant: Any) -> ConstantType: def resolve_types( node: ast.Expr, context: HogQLContext, + dialect: Literal["hogql", "clickhouse"], scopes: Optional[List[ast.SelectQueryType]] = None, ) -> ast.Expr: - return Resolver(scopes=scopes, context=context).visit(node) + return Resolver(scopes=scopes, context=context, dialect=dialect).visit(node) class Resolver(CloningVisitor): """The Resolver visits an AST and 1) resolves all fields, 2) assigns types to nodes, 3) expands all CTEs.""" - def __init__(self, context: HogQLContext, scopes: Optional[List[ast.SelectQueryType]] = None): + def __init__( + self, + context: HogQLContext, + dialect: Literal["hogql", "clickhouse"] = "clickhouse", + scopes: Optional[List[ast.SelectQueryType]] = None, + ): super().__init__() # Each SELECT query creates a new scope (type). Store all of them in a list as we traverse the tree. self.scopes: List[ast.SelectQueryType] = scopes or [] self.current_view_depth: int = 0 self.context = context + self.dialect = dialect self.database = context.database self.cte_counter = 0 @@ -461,10 +468,12 @@ def visit_field(self, node: ast.Field): node.type = loop_type if isinstance(node.type, ast.ExpressionFieldType): - new_expr = clone_expr(node.type.expr) - new_node = ast.Alias(alias=node.type.name, expr=new_expr, hidden=True) - new_node = self.visit(new_node) - return new_node + # only swap out expression fields in ClickHouse + if self.dialect == "clickhouse": + new_expr = clone_expr(node.type.expr) + new_node = ast.Alias(alias=node.type.name, expr=new_expr, hidden=True) + new_node = self.visit(new_node) + return new_node if isinstance(node.type, ast.FieldType) and node.start is not None and node.end is not None: self.context.add_notice( diff --git a/posthog/hogql/test/test_metadata.py b/posthog/hogql/test/test_metadata.py index 5b1a7da4e3e62..3bbfba2bb204a 100644 --- a/posthog/hogql/test/test_metadata.py +++ b/posthog/hogql/test/test_metadata.py @@ -2,6 +2,7 @@ from posthog.models import PropertyDefinition, Cohort from posthog.schema import HogQLMetadata, HogQLMetadataResponse from posthog.test.base import APIBaseTest, ClickhouseTestMixin +from django.test import override_settings class TestMetadata(ClickhouseTestMixin, APIBaseTest): @@ -135,6 +136,7 @@ def test_metadata_table(self): metadata = self._expr("is_identified", "persons") self.assertEqual(metadata.isValid, True) + @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False) def test_metadata_in_cohort(self): cohort = Cohort.objects.create(team=self.team, name="cohort_name") query = ( diff --git a/posthog/hogql/test/test_modifiers.py b/posthog/hogql/test/test_modifiers.py index 744ad4f063fdb..9e253c30042af 100644 --- a/posthog/hogql/test/test_modifiers.py +++ b/posthog/hogql/test/test_modifiers.py @@ -70,9 +70,9 @@ def test_modifiers_persons_on_events_mode_mapping(self): ( PersonsOnEventsMode.v2_enabled, "events.event", - "ifNull(events__override.override_person_id, events.person_id) AS id", + "ifNull(nullIf(events__override.override_person_id, %(hogql_val_0)s), events.person_id) AS id", "events.person_properties", - "toTimeZone(events.person_created_at, %(hogql_val_0)s) AS created_at", + "toTimeZone(events.person_created_at, %(hogql_val_1)s) AS created_at", ), ] diff --git a/posthog/hogql/test/test_printer.py b/posthog/hogql/test/test_printer.py index 9d63479c0f5a2..2195971a317f8 100644 --- a/posthog/hogql/test/test_printer.py +++ b/posthog/hogql/test/test_printer.py @@ -774,7 +774,7 @@ def test_select_sample(self): f"AS persons SAMPLE 0.1 ON equals(persons.id, events__pdi.person_id) WHERE equals(events.team_id, {self.team.pk}) LIMIT 10000", ) - with override_settings(PERSON_ON_EVENTS_OVERRIDE=True): + with override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False): context = HogQLContext( team_id=self.team.pk, enable_select_queries=True, diff --git a/posthog/hogql/test/test_query.py b/posthog/hogql/test/test_query.py index f6dc78a7e4108..7f9d9fe65e99e 100644 --- a/posthog/hogql/test/test_query.py +++ b/posthog/hogql/test/test_query.py @@ -397,7 +397,7 @@ def test_query_select_person_with_joins_without_poe(self): self.assertEqual(response.results[0][3], "tim@posthog.com") @pytest.mark.usefixtures("unittest_snapshot") - @override_settings(PERSON_ON_EVENTS_OVERRIDE=True) + @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False) def test_query_select_person_with_poe_without_joins(self): with freeze_time("2020-01-10"): self._create_random_events() @@ -473,7 +473,7 @@ def test_prop_cohort_basic(self): ) self.assertEqual(response.results, [("$pageview", 2)]) - with override_settings(PERSON_ON_EVENTS_OVERRIDE=True): + with override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False): response = execute_hogql_query( "SELECT event, count(*) FROM events WHERE {cohort_filter} GROUP BY event", team=self.team, diff --git a/posthog/hogql/test/test_resolver.py b/posthog/hogql/test/test_resolver.py index c2fc2324271bb..b7afa076bb362 100644 --- a/posthog/hogql/test/test_resolver.py +++ b/posthog/hogql/test/test_resolver.py @@ -47,14 +47,14 @@ def setUp(self): @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_events_table(self): expr = self._select("SELECT event, events.timestamp FROM events WHERE events.event = 'test'") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot def test_will_not_run_twice(self): expr = self._select("SELECT event, events.timestamp FROM events WHERE events.event = 'test'") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") with self.assertRaises(ResolverException) as context: - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") self.assertEqual( str(context.exception), "Type already resolved for SelectQuery (SelectQueryType). Can't run again.", @@ -63,19 +63,19 @@ def test_will_not_run_twice(self): @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_events_table_alias(self): expr = self._select("SELECT event, e.timestamp FROM events e WHERE e.event = 'test'") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_events_table_column_alias(self): expr = self._select("SELECT event as ee, ee, ee as e, e.timestamp FROM events e WHERE e.event = 'test'") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_events_table_column_alias_inside_subquery(self): expr = self._select("SELECT b FROM (select event as b, timestamp as c from events) e WHERE e.b = 'test'") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot def test_resolve_subquery_no_field_access(self): @@ -84,7 +84,7 @@ def test_resolve_subquery_no_field_access(self): "SELECT event, (select count() from events where event = e.event) as c FROM events e where event = '$pageview'" ) with self.assertRaises(ResolverException) as e: - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") self.assertEqual(str(e.exception), "Unable to resolve field: e") @pytest.mark.usefixtures("unittest_snapshot") @@ -101,13 +101,13 @@ def test_resolve_constant_type(self): "tuple": ast.Constant(value=(1, 2, 3)), }, ) - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_boolean_operation_types(self): expr = self._select("SELECT 1 and 1, 1 or 1, not true") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot def test_resolve_errors(self): @@ -120,55 +120,55 @@ def test_resolve_errors(self): ] for query in queries: with self.assertRaises(ResolverException) as e: - resolve_types(self._select(query), self.context) + resolve_types(self._select(query), self.context, dialect="clickhouse") self.assertIn("Unable to resolve field:", str(e.exception)) @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_lazy_pdi_person_table(self): expr = self._select("select distinct_id, person.id from person_distinct_ids") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_lazy_events_pdi_table(self): expr = self._select("select event, pdi.person_id from events") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_lazy_events_pdi_table_aliased(self): expr = self._select("select event, e.pdi.person_id from events e") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_lazy_events_pdi_person_table(self): expr = self._select("select event, pdi.person.id from events") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_lazy_events_pdi_person_table_aliased(self): expr = self._select("select event, e.pdi.person.id from events e") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_virtual_events_poe(self): expr = self._select("select event, poe.id from events") - expr = resolve_types(expr, self.context) + expr = resolve_types(expr, self.context, dialect="clickhouse") assert pretty_dataclasses(expr) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_union_all(self): node = self._select("select event, timestamp from events union all select event, timestamp from events") - node = resolve_types(node, self.context) + node = resolve_types(node, self.context, dialect="clickhouse") assert pretty_dataclasses(node) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_call_type(self): node = self._select("select max(timestamp) from events") - node = resolve_types(node, self.context) + node = resolve_types(node, self.context, dialect="clickhouse") assert pretty_dataclasses(node) == self.snapshot def test_ctes_loop(self): @@ -235,7 +235,7 @@ def test_ctes_subquery_recursion(self): def test_asterisk_expander_table(self): self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False node = self._select("select * from events") - node = resolve_types(node, self.context) + node = resolve_types(node, self.context, dialect="clickhouse") assert pretty_dataclasses(node) == self.snapshot @override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False) @@ -243,19 +243,19 @@ def test_asterisk_expander_table(self): def test_asterisk_expander_table_alias(self): self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False node = self._select("select * from events e") - node = resolve_types(node, self.context) + node = resolve_types(node, self.context, dialect="clickhouse") assert pretty_dataclasses(node) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_asterisk_expander_subquery(self): node = self._select("select * from (select 1 as a, 2 as b)") - node = resolve_types(node, self.context) + node = resolve_types(node, self.context, dialect="clickhouse") assert pretty_dataclasses(node) == self.snapshot @pytest.mark.usefixtures("unittest_snapshot") def test_asterisk_expander_subquery_alias(self): node = self._select("select x.* from (select 1 as a, 2 as b) x") - node = resolve_types(node, self.context) + node = resolve_types(node, self.context, dialect="clickhouse") assert pretty_dataclasses(node) == self.snapshot @override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False) @@ -263,13 +263,13 @@ def test_asterisk_expander_subquery_alias(self): def test_asterisk_expander_from_subquery_table(self): self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False node = self._select("select * from (select * from events)") - node = resolve_types(node, self.context) + node = resolve_types(node, self.context, dialect="clickhouse") assert pretty_dataclasses(node) == self.snapshot def test_asterisk_expander_multiple_table_error(self): node = self._select("select * from (select 1 as a, 2 as b) x left join (select 1 as a, 2 as b) y on x.a = y.a") with self.assertRaises(ResolverException) as e: - resolve_types(node, self.context) + resolve_types(node, self.context, dialect="clickhouse") self.assertEqual( str(e.exception), "Cannot use '*' without table name when there are multiple tables in the query", @@ -280,13 +280,13 @@ def test_asterisk_expander_multiple_table_error(self): def test_asterisk_expander_select_union(self): self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False node = self._select("select * from (select * from events union all select * from events)") - node = resolve_types(node, self.context) + node = resolve_types(node, self.context, dialect="clickhouse") assert pretty_dataclasses(node) == self.snapshot def test_lambda_parent_scope(self): # does not raise node = self._select("select timestamp, arrayMap(x -> x + timestamp, [2]) from events") - node = cast(ast.SelectQuery, resolve_types(node, self.context)) + node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse")) # found a type lambda_type: ast.SelectQueryType = cast(ast.SelectQueryType, cast(ast.Call, node.select[1]).args[0].type) @@ -304,7 +304,7 @@ def test_field_traverser_double_dot(self): self.database.events.fields["poe"].fields["properties"] = StringJSONDatabaseField(name="person_properties") node = self._select("SELECT event, person.id, person.properties, person.created_at FROM events") - node = cast(ast.SelectQuery, resolve_types(node, self.context)) + node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse")) # all columns resolve to a type in the end assert cast(ast.FieldType, node.select[0].type).resolve_database_field() == StringDatabaseField( @@ -322,7 +322,7 @@ def test_field_traverser_double_dot(self): def test_visit_hogqlx_tag(self): node = self._select("select event from ") - node = cast(ast.SelectQuery, resolve_types(node, self.context)) + node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse")) table_node = cast(ast.SelectQuery, node).select_from.table expected = ast.SelectQuery( select=[ast.Alias(hidden=True, alias="event", expr=ast.Field(chain=["event"]))], @@ -332,7 +332,7 @@ def test_visit_hogqlx_tag(self): def test_visit_hogqlx_tag_alias(self): node = self._select("select event from a") - node = cast(ast.SelectQuery, resolve_types(node, self.context)) + node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse")) assert cast(ast.SelectQuery, node).select_from.alias == "a" def test_visit_hogqlx_tag_source(self): @@ -346,7 +346,7 @@ def test_visit_hogqlx_tag_source(self): /> ) """ - node = cast(ast.SelectQuery, resolve_types(self._select(query), self.context)) + node = cast(ast.SelectQuery, resolve_types(self._select(query), self.context, dialect="hogql")) hogql = print_prepared_ast(node, HogQLContext(team_id=self.team.pk, enable_select_queries=True), "hogql") expected = ( f"SELECT id, email FROM " diff --git a/posthog/hogql/transforms/in_cohort.py b/posthog/hogql/transforms/in_cohort.py index 670d0a8e73c2a..a565391e309f3 100644 --- a/posthog/hogql/transforms/in_cohort.py +++ b/posthog/hogql/transforms/in_cohort.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast +from typing import List, Optional, cast, Literal from posthog.hogql import ast from posthog.hogql.context import HogQLContext @@ -11,21 +11,24 @@ def resolve_in_cohorts( node: ast.Expr, + dialect: Literal["hogql", "clickhouse"], stack: Optional[List[ast.SelectQuery]] = None, context: HogQLContext = None, ): - InCohortResolver(stack=stack, context=context).visit(node) + InCohortResolver(stack=stack, dialect=dialect, context=context).visit(node) class InCohortResolver(TraversingVisitor): def __init__( self, + dialect: Literal["hogql", "clickhouse"], stack: Optional[List[ast.SelectQuery]] = None, context: HogQLContext = None, ): super().__init__() self.stack: List[ast.SelectQuery] = stack or [] self.context = context + self.dialect = dialect def visit_select_query(self, node: ast.SelectQuery): self.stack.append(node) @@ -130,11 +133,12 @@ def _add_join_for_cohort( ) new_join = cast( ast.JoinExpr, - resolve_types(new_join, self.context, [self.stack[-1].type]), + resolve_types(new_join, self.context, self.dialect, [self.stack[-1].type]), ) new_join.constraint.expr.left = resolve_types( ast.Field(chain=[f"in_cohort__{cohort_id}", "person_id"]), self.context, + self.dialect, [self.stack[-1].type], ) new_join.constraint.expr.right = clone_expr(compare.left) @@ -147,6 +151,7 @@ def _add_join_for_cohort( compare.left = resolve_types( ast.Field(chain=[f"in_cohort__{cohort_id}", "matched"]), self.context, + self.dialect, [self.stack[-1].type], ) - compare.right = resolve_types(ast.Constant(value=1), self.context, [self.stack[-1].type]) + compare.right = resolve_types(ast.Constant(value=1), self.context, self.dialect, [self.stack[-1].type]) diff --git a/posthog/hogql/transforms/lazy_tables.py b/posthog/hogql/transforms/lazy_tables.py index b2a9a7d12bf4d..4734fed012c91 100644 --- a/posthog/hogql/transforms/lazy_tables.py +++ b/posthog/hogql/transforms/lazy_tables.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Dict, List, Optional, cast +from typing import Dict, List, Optional, cast, Literal from posthog.hogql import ast from posthog.hogql.context import HogQLContext @@ -12,10 +12,11 @@ def resolve_lazy_tables( node: ast.Expr, + dialect: Literal["hogql", "clickhouse"], stack: Optional[List[ast.SelectQuery]] = None, context: HogQLContext = None, ): - LazyTableResolver(stack=stack, context=context).visit(node) + LazyTableResolver(stack=stack, context=context, dialect=dialect).visit(node) @dataclasses.dataclass @@ -35,12 +36,14 @@ class TableToAdd: class LazyTableResolver(TraversingVisitor): def __init__( self, + dialect: Literal["hogql", "clickhouse"], stack: Optional[List[ast.SelectQuery]] = None, context: HogQLContext = None, ): super().__init__() self.stack_of_fields: List[List[ast.FieldType | ast.PropertyType]] = [[]] if stack else [] self.context = context + self.dialect = dialect def visit_property_type(self, node: ast.PropertyType): if node.joined_subquery is not None: @@ -181,7 +184,7 @@ def visit_select_query(self, node: ast.SelectQuery): for table_name, table_to_add in tables_to_add.items(): subquery = table_to_add.lazy_table.lazy_select(table_to_add.fields_accessed, self.context.modifiers) subquery = cast(ast.SelectQuery, clone_expr(subquery, clear_locations=True)) - subquery = cast(ast.SelectQuery, resolve_types(subquery, self.context, [node.type])) + subquery = cast(ast.SelectQuery, resolve_types(subquery, self.context, self.dialect, [node.type])) old_table_type = select_type.tables[table_name] select_type.tables[table_name] = ast.SelectQueryAliasType(alias=table_name, select_query_type=subquery.type) @@ -204,7 +207,7 @@ def visit_select_query(self, node: ast.SelectQuery): node, ) join_to_add = cast(ast.JoinExpr, clone_expr(join_to_add, clear_locations=True)) - join_to_add = cast(ast.JoinExpr, resolve_types(join_to_add, self.context, [node.type])) + join_to_add = cast(ast.JoinExpr, resolve_types(join_to_add, self.context, self.dialect, [node.type])) select_type.tables[to_table] = join_to_add.type diff --git a/posthog/warehouse/models/datawarehouse_saved_query.py b/posthog/warehouse/models/datawarehouse_saved_query.py index 9117fa7c4eaf0..f52ee05cb6926 100644 --- a/posthog/warehouse/models/datawarehouse_saved_query.py +++ b/posthog/warehouse/models/datawarehouse_saved_query.py @@ -71,7 +71,7 @@ def s3_tables(self): node = parse_select(self.query["query"]) context.database = create_hogql_database(context.team_id) - node = resolve_types(node, context) + node = resolve_types(node, context, dialect="clickhouse") table_collector = S3TableVisitor() table_collector.visit(node)