From 85a878c147d2ef3ec39984f4ae56e0e425279973 Mon Sep 17 00:00:00 2001 From: Robbie Date: Thu, 21 Mar 2024 20:36:02 +0000 Subject: [PATCH] More robust way of looking up tables --- posthog/hogql/ast.py | 8 ++++++ .../util/session_where_clause_extractor.py | 26 +++++++++++++------ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/posthog/hogql/ast.py b/posthog/hogql/ast.py index a459514f2524f0..70d6163934de60 100644 --- a/posthog/hogql/ast.py +++ b/posthog/hogql/ast.py @@ -48,6 +48,11 @@ def resolve_database_field(self, context: HogQLContext): return self.type.resolve_database_field(context) raise NotImplementedException("FieldAliasType.resolve_database_field not implemented") + def resolve_table_type(self, context: HogQLContext): + if isinstance(self.type, FieldType): + return self.type.table_type + raise NotImplementedException("FieldAliasType.resolve_table_type not implemented") + @dataclass(kw_only=True) class BaseTableType(Type): @@ -339,6 +344,9 @@ def get_child(self, name: str | int, context: HogQLContext) -> Type: f'Can not access property "{name}" on field "{self.name}" of type: {type(database_field).__name__}' ) + def resolve_table_type(self, context: HogQLContext): + return self.table_type + @dataclass(kw_only=True) class PropertyType(Type): diff --git a/posthog/hogql/database/schema/util/session_where_clause_extractor.py b/posthog/hogql/database/schema/util/session_where_clause_extractor.py index 3eed4b5b86f3d1..cc02cfeb8b78c2 100644 --- a/posthog/hogql/database/schema/util/session_where_clause_extractor.py +++ b/posthog/hogql/database/schema/util/session_where_clause_extractor.py @@ -5,6 +5,7 @@ from posthog.hogql.ast import CompareOperationOp, ArithmeticOperationOp from posthog.hogql.context import HogQLContext from posthog.hogql.database.models import DatabaseField + from posthog.hogql.visitor import clone_expr, CloningVisitor, Visitor SESSION_BUFFER_DAYS = 3 @@ -288,7 +289,7 @@ def visit_constant(self, node: ast.Constant) -> bool: def visit_field(self, node: ast.Field) -> bool: if node.type and isinstance(node.type, ast.FieldType): resolved_field = node.type.resolve_database_field(self.context) - if resolved_field and isinstance(resolved_field, DatabaseField): + if resolved_field and isinstance(resolved_field, DatabaseField) and resolved_field: return resolved_field.name in ["min_timestamp", "timestamp"] # no type information, so just use the name of the field return node.chain[-1] in ["min_timestamp", "timestamp"] @@ -344,9 +345,15 @@ def visit_placeholder(self, node: ast.Placeholder) -> bool: raise Exception() def visit_alias(self, node: ast.Alias) -> bool: + from posthog.hogql.database.schema.events import EventsTable + from posthog.hogql.database.schema.sessions import SessionsTable + if node.type and isinstance(node.type, ast.FieldAliasType): resolved_field = node.type.resolve_database_field(self.context) - return resolved_field.name in ["min_timestamp", "timestamp"] + table = node.type.resolve_table_type(self.context).table + return (isinstance(table, EventsTable) and resolved_field.name == "timestamp") or ( + isinstance(table, SessionsTable) and resolved_field.name == "min_timestamp" + ) return self.visit(node.expr) @@ -363,14 +370,17 @@ def __init__(self, context: HogQLContext, *args, **kwargs): self.context = context def visit_field(self, node: ast.Field) -> ast.Field: + from posthog.hogql.database.schema.events import EventsTable + from posthog.hogql.database.schema.sessions import SessionsTable + if node.type and isinstance(node.type, ast.FieldType): resolved_field = node.type.resolve_database_field(self.context) - if ( - resolved_field - and isinstance(resolved_field, DatabaseField) - and resolved_field.name in ["min_timestamp", "timestamp"] - ): - return ast.Field(chain=["raw_sessions", "min_timestamp"]) + table = node.type.resolve_table_type(self.context).table + if resolved_field and isinstance(resolved_field, DatabaseField): + if (isinstance(table, EventsTable) and resolved_field.name == "timestamp") or ( + isinstance(table, SessionsTable) and resolved_field.name == "min_timestamp" + ): + return ast.Field(chain=["raw_sessions", "min_timestamp"]) # no type information, so just use the name of the field if node.chain[-1] in ["min_timestamp", "timestamp"]: return ast.Field(chain=["raw_sessions", "min_timestamp"])