From 975ae54d64c394df75a6bbae52d1fb36fbcaa891 Mon Sep 17 00:00:00 2001 From: Robbie Date: Mon, 11 Mar 2024 16:27:49 +0000 Subject: [PATCH] WIP --- posthog/hogql/database/__init__.py | 0 posthog/hogql/database/models.py | 2 +- posthog/hogql/database/schema/__init__.py | 0 posthog/hogql/database/schema/sessions.py | 8 +- .../hogql/database/schema/util/__init__.py | 0 .../util/session_where_clause_extractor.py | 105 ++++++++++++++++ .../test_session_where_clause_extractor.py | 118 ++++++++++++++++++ .../schema/util/where_clause_visitor.py | 72 +++++++++++ posthog/hogql/transforms/lazy_tables.py | 2 +- 9 files changed, 302 insertions(+), 5 deletions(-) create mode 100644 posthog/hogql/database/__init__.py create mode 100644 posthog/hogql/database/schema/__init__.py create mode 100644 posthog/hogql/database/schema/util/__init__.py create mode 100644 posthog/hogql/database/schema/util/session_where_clause_extractor.py create mode 100644 posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py create mode 100644 posthog/hogql/database/schema/util/where_clause_visitor.py diff --git a/posthog/hogql/database/__init__.py b/posthog/hogql/database/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/posthog/hogql/database/models.py b/posthog/hogql/database/models.py index e95a26614bed88..2bae10ec6e6497 100644 --- a/posthog/hogql/database/models.py +++ b/posthog/hogql/database/models.py @@ -134,7 +134,7 @@ class LazyTable(Table): model_config = ConfigDict(extra="forbid") - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers) -> Any: + def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers, node) -> Any: raise NotImplementedException("LazyTable.lazy_select not overridden") diff --git a/posthog/hogql/database/schema/__init__.py b/posthog/hogql/database/schema/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/posthog/hogql/database/schema/sessions.py b/posthog/hogql/database/schema/sessions.py index 2a4865798eeb86..c06244102ebc32 100644 --- a/posthog/hogql/database/schema/sessions.py +++ b/posthog/hogql/database/schema/sessions.py @@ -1,5 +1,6 @@ from typing import Dict, List, cast +from posthog.hogql import ast from posthog.hogql.database.models import ( StringDatabaseField, DateTimeDatabaseField, @@ -62,7 +63,7 @@ def avoid_asterisk_fields(self) -> List[str]: ] -def select_from_sessions_table(requested_fields: Dict[str, List[str | int]]): +def select_from_sessions_table(requested_fields: Dict[str, List[str | int]], node: ast.SelectQuery): from posthog.hogql import ast table_name = "raw_sessions" @@ -148,8 +149,9 @@ class SessionsTable(LazyTable): "channel_type": StringDatabaseField(name="channel_type"), } - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): - return select_from_sessions_table(requested_fields) + + def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers, node: ast.SelectQuery): + return select_from_sessions_table(requested_fields, node) def to_printed_clickhouse(self, context): return "sessions" diff --git a/posthog/hogql/database/schema/util/__init__.py b/posthog/hogql/database/schema/util/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/posthog/hogql/database/schema/util/session_where_clause_extractor.py b/posthog/hogql/database/schema/util/session_where_clause_extractor.py new file mode 100644 index 00000000000000..e1105c01c36405 --- /dev/null +++ b/posthog/hogql/database/schema/util/session_where_clause_extractor.py @@ -0,0 +1,105 @@ +from typing import Union + +from posthog.hogql import ast +from posthog.hogql.ast import CompareOperationOp +from posthog.hogql.database.schema.util.where_clause_visitor import ASTVisitor + +SESSION_BUFFER_DAYS = 3 + + +class AbortOptimisationException(Exception): + pass + + +class SessionWhereClauseExtractor(ASTVisitor): + def get_inner_where(self, parsed_query: ast.SelectQuery) -> Union[ast.Expr, None]: + if not parsed_query.where: + return None + + # visit the where clause + where = self.visit(parsed_query.where) + + if isinstance(where, ast.Constant): + return None + + return where + + def visit(self, node: ast.Expr) -> ast.Expr: + try: + return super().visit(node) + except AbortOptimisationException: + return ast.Constant(value=True) + + def visit_not(self, node: ast.Not) -> ast.Expr: + # don't even try to handle complex logic + raise AbortOptimisationException() + + def visit_call(self, node: ast.Call) -> ast.Expr: + if node.name.lower() == "and": + return self.visit_and(ast.And(exprs=node.args)) + elif node.name.lower() == "or": + return self.visit_or(ast.Or(exprs=node.args)) + raise AbortOptimisationException() + + def visit_field(self, node: ast.Field) -> ast.Expr: + raise AbortOptimisationException() + + def visit_constant(self, node: ast.Constant) -> ast.Expr: + raise AbortOptimisationException() + + def visit_placeholder(self, node: ast.Placeholder) -> ast.Expr: + raise Exception() # this should never happen, as placeholders should be resolved before this runs + + def visit_compare_operation(self, node: ast.CompareOperation) -> ast.Expr: + # this is somewhat leaky, as it doesn't handle aliasing correctly + if isinstance(node.left, ast.Field) and ( + node.left.chain == ["min_timestamp"] or node.left.chain == ["sessions", "min_timestamp"] + ): + if node.op == CompareOperationOp.Eq: + return ast.And( + exprs=[ + ast.CompareOperation( + op=ast.CompareOperationOp.LtEq, + left=ast.ArithmeticOperation( + op=ast.ArithmeticOperationOp.Sub, + left=node.left, + right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]), + ), + right=node.right, + ), + ast.CompareOperation( + op=ast.CompareOperationOp.GtEq, + left=ast.ArithmeticOperation( + op=ast.ArithmeticOperationOp.Add, + left=node.left, + right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]), + ), + right=node.right, + ), + ] + ) + elif node.op == CompareOperationOp.Gt or node.op == CompareOperationOp.GtEq: + return ast.CompareOperation( + op=ast.CompareOperationOp.GtEq, + left=ast.ArithmeticOperation( + op=ast.ArithmeticOperationOp.Add, + left=node.left, + right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]), + ), + right=node.right, + ) + elif node.op == CompareOperationOp.Lt or node.op == CompareOperationOp.LtEq: + return ast.CompareOperation( + op=ast.CompareOperationOp.LtEq, + left=ast.ArithmeticOperation( + op=ast.ArithmeticOperationOp.Sub, + left=node.left, + right=ast.Call(name="toIntervalDay", args=[ast.Constant(value=SESSION_BUFFER_DAYS)]), + ), + right=node.right, + ) + + raise AbortOptimisationException() + + def visit_arithmetric_operation(self, node: ast.ArithmeticOperation) -> ast.Expr: + raise AbortOptimisationException() diff --git a/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py new file mode 100644 index 00000000000000..17c89af98a6d72 --- /dev/null +++ b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py @@ -0,0 +1,118 @@ +from typing import Union, Optional + +from posthog.hogql import ast +from posthog.hogql.database.schema.util.session_where_clause_extractor import SessionWhereClauseExtractor +from posthog.hogql.database.schema.util.where_clause_visitor import StripInfoVisitor +from posthog.hogql.parser import parse_select, parse_expr + + +def f(s: Union[str, ast.Expr], placeholders: Optional[dict[str, ast.Expr]] = None) -> Union[ast.Expr, None]: + if s is None: + return None + if isinstance(s, str): + expr = parse_expr(s, placeholders=placeholders) + else: + expr = s + return StripInfoVisitor.strip_info(expr) + + +class TestSessionTimestampInliner: + def test_handles_select_with_no_where_claus(self): + inliner = SessionWhereClauseExtractor() + inner_where = inliner.get_inner_where(parse_select("SELECT * FROM sessions")) + assert inner_where is None + + def test_handles_select_with_eq(self): + inliner = SessionWhereClauseExtractor() + actual = f(inliner.get_inner_where(parse_select("SELECT * FROM sessions WHERE min_timestamp = '2021-01-01'"))) + expected = f( + "((min_timestamp - toIntervalDay(3)) <= '2021-01-01') AND ((min_timestamp + toIntervalDay(3)) >= '2021-01-01')" + ) + assert expected == actual + + def test_handles_select_with_simple_gt(self): + inliner = SessionWhereClauseExtractor() + actual = f(inliner.get_inner_where(parse_select("SELECT * FROM sessions WHERE min_timestamp > '2021-01-01'"))) + expected = f("((min_timestamp + toIntervalDay(3)) >= '2021-01-01')") + assert expected == actual + + def test_handles_select_with_simple_gte(self): + inliner = SessionWhereClauseExtractor() + actual = f(inliner.get_inner_where(parse_select("SELECT * FROM sessions WHERE min_timestamp >= '2021-01-01'"))) + expected = f("((min_timestamp + toIntervalDay(3)) >= '2021-01-01')") + assert expected == actual + + def test_handles_select_with_simple_lt(self): + inliner = SessionWhereClauseExtractor() + actual = f(inliner.get_inner_where(parse_select("SELECT * FROM sessions WHERE min_timestamp < '2021-01-01'"))) + expected = f("((min_timestamp - toIntervalDay(3)) <= '2021-01-01')") + assert expected == actual + + def test_handles_select_with_simple_lte(self): + inliner = SessionWhereClauseExtractor() + actual = f(inliner.get_inner_where(parse_select("SELECT * FROM sessions WHERE min_timestamp <= '2021-01-01'"))) + expected = f("((min_timestamp - toIntervalDay(3)) <= '2021-01-01')") + assert expected == actual + + def test_select_with_placeholder(self): + inliner = SessionWhereClauseExtractor() + actual = f( + inliner.get_inner_where( + parse_select( + "SELECT * FROM sessions WHERE min_timestamp > {timestamp}", + placeholders={"timestamp": ast.Constant(value="2021-01-01")}, + ) + ) + ) + expected = f("((min_timestamp + toIntervalDay(3)) >= '2021-01-01')") + assert expected == actual + + def test_unrelated_equals(self): + inliner = SessionWhereClauseExtractor() + actual = inliner.get_inner_where( + parse_select("SELECT * FROM sessions WHERE initial_utm_campaign = initial_utm_source") + ) + assert actual is None + + def test_timestamp_and(self): + inliner = SessionWhereClauseExtractor() + actual = f( + inliner.get_inner_where( + parse_select( + "SELECT * FROM sessions WHERE and(min_timestamp >= '2021-01-01', min_timestamp <= '2021-01-03')" + ) + ) + ) + expected = f( + "((min_timestamp + toIntervalDay(3)) >= '2021-01-01') AND ((min_timestamp - toIntervalDay(3)) <= '2021-01-03')" + ) + assert expected == actual + + def test_timestamp_or(self): + inliner = SessionWhereClauseExtractor() + actual = f( + inliner.get_inner_where( + parse_select( + "SELECT * FROM sessions WHERE and(min_timestamp <= '2021-01-01', min_timestamp >= '2021-01-03')" + ) + ) + ) + expected = f( + "((min_timestamp - toIntervalDay(3)) <= '2021-01-01') AND ((min_timestamp + toIntervalDay(3)) >= '2021-01-03')" + ) + assert expected == actual + + def test_unrelated_function(self): + inliner = SessionWhereClauseExtractor() + actual = f(inliner.get_inner_where(parse_select("SELECT * FROM sessions WHERE like('a', 'b')"))) + assert actual is None + + def test_timestamp_unrelated_function(self): + inliner = SessionWhereClauseExtractor() + actual = f(inliner.get_inner_where(parse_select("SELECT * FROM sessions WHERE like(toString(timestamp), 'b')"))) + assert actual is None + + def test_timestamp_unrelated_function_timestamp(self): + inliner = SessionWhereClauseExtractor() + actual = f(inliner.get_inner_where(parse_select("SELECT * FROM sessions WHERE like(toString(timestamp), 'b')"))) + assert actual is None diff --git a/posthog/hogql/database/schema/util/where_clause_visitor.py b/posthog/hogql/database/schema/util/where_clause_visitor.py new file mode 100644 index 00000000000000..d4178c29587fb0 --- /dev/null +++ b/posthog/hogql/database/schema/util/where_clause_visitor.py @@ -0,0 +1,72 @@ +from posthog.hogql import ast + + +class ASTVisitor: + def visit(self, node: ast.Expr) -> ast.Expr: + if isinstance(node, ast.And): + return self.visit_and(node) + elif isinstance(node, ast.Or): + return self.visit_or(node) + elif isinstance(node, ast.Not): + return self.visit_not(node) + elif isinstance(node, ast.Call): + return self.visit_call(node) + elif isinstance(node, ast.Field): + return self.visit_field(node) + elif isinstance(node, ast.Constant): + return self.visit_constant(node) + elif isinstance(node, ast.CompareOperation): + return self.visit_compare_operation(node) + elif isinstance(node, ast.ArithmeticOperation): + return self.visit_arithmetric_operation(node) + elif isinstance(node, ast.Placeholder): + return self.visit_placeholder(node) + else: + raise Exception(f"Unknown node type {type(node)}") + + def visit_and(self, node: ast.And) -> ast.Expr: + return ast.And(exprs=[self.visit(node) for node in node.exprs]) + + def visit_or(self, node: ast.Or) -> ast.Expr: + return ast.Or(exprs=[self.visit(node) for node in node.exprs]) + + def visit_not(self, node: ast.Not) -> ast.Expr: + return ast.Not(expr=self.visit(node.expr)) + + def visit_call(self, node: ast.Call) -> ast.Expr: + return ast.Call(name=node.name, args=[self.visit(arg) for arg in node.args]) + + def visit_field(self, node: ast.Field) -> ast.Expr: + return node + + def visit_constant(self, node: ast.Constant) -> ast.Expr: + return node + + def visit_compare_operation(self, node: ast.CompareOperation) -> ast.Expr: + return ast.CompareOperation( + op=node.op, + left=self.visit(node.left), + right=self.visit(node.right), + ) + + def visit_arithmetric_operation(self, node: ast.ArithmeticOperation) -> ast.Expr: + return ast.ArithmeticOperation( + op=node.op, + left=self.visit(node.left), + right=self.visit(node.right), + ) + + def visit_placeholder(self, node: ast.Placeholder) -> ast.Expr: + return node + + +class StripInfoVisitor(ASTVisitor): + def visit_field(self, node: ast.Field) -> ast.Expr: + return ast.Field(chain=node.chain) + + def visit_constant(self, node: ast.Constant) -> ast.Expr: + return ast.Constant(value=node.value) + + @staticmethod + def strip_info(node: ast.Expr) -> ast.Expr: + return StripInfoVisitor().visit(node) diff --git a/posthog/hogql/transforms/lazy_tables.py b/posthog/hogql/transforms/lazy_tables.py index bdbb322d54397b..bb42b191f33e51 100644 --- a/posthog/hogql/transforms/lazy_tables.py +++ b/posthog/hogql/transforms/lazy_tables.py @@ -309,7 +309,7 @@ def create_override(table_name: str, field_chain: List[str | int]) -> None: # For all the collected tables, create the subqueries, and add them to the table. for table_name, table_to_add in tables_to_add.items(): - subquery = table_to_add.lazy_table.lazy_select(table_to_add.fields_accessed, self.context.modifiers) + subquery = table_to_add.lazy_table.lazy_select(table_to_add.fields_accessed, self.context.modifiers, node=node) subquery = cast(ast.SelectQuery, clone_expr(subquery, clear_locations=True)) subquery = cast(ast.SelectQuery, resolve_types(subquery, self.context, self.dialect, [node.type])) old_table_type = select_type.tables[table_name]