diff --git a/posthog/hogql/hogql.py b/posthog/hogql/hogql.py index dc16b3cf8e9f6..8e216b1654b46 100644 --- a/posthog/hogql/hogql.py +++ b/posthog/hogql/hogql.py @@ -6,15 +6,15 @@ def translate_hogql(query: str, context: HogQLContext, dialect: Literal["hogql", "clickhouse"] = "clickhouse") -> str: - """Translate a HogQL expression into a Clickhouse expression.""" + """Translate a HogQL expression into a Clickhouse expression. Raises if any placeholders found.""" if query == "": raise ValueError("Empty query") try: if context.select_team_id: - node = parse_select(query) + node = parse_select(query, no_placeholders=True) else: - node = parse_expr(query) + node = parse_expr(query, no_placeholders=True) except SyntaxError as err: raise ValueError(f"SyntaxError: {err.msg}") except NotImplementedError as err: diff --git a/posthog/hogql/parser.py b/posthog/hogql/parser.py index a8347a36a0035..a2646a16d6973 100644 --- a/posthog/hogql/parser.py +++ b/posthog/hogql/parser.py @@ -7,22 +7,27 @@ from posthog.hogql.grammar.HogQLLexer import HogQLLexer from posthog.hogql.grammar.HogQLParser import HogQLParser from posthog.hogql.parse_string import parse_string, parse_string_literal -from posthog.hogql.placeholders import replace_placeholders +from posthog.hogql.placeholders import assert_no_placeholders, replace_placeholders -def parse_expr(expr: str, placeholders: Optional[Dict[str, ast.Expr]] = None) -> ast.Expr: +def parse_expr(expr: str, placeholders: Optional[Dict[str, ast.Expr]] = None, no_placeholders=False) -> ast.Expr: parse_tree = get_parser(expr).expr() node = HogQLParseTreeConverter().visit(parse_tree) if placeholders: return replace_placeholders(node, placeholders) + elif no_placeholders: + assert_no_placeholders(node) + return node -def parse_select(statement: str, placeholders: Optional[Dict[str, ast.Expr]] = None) -> ast.Expr: +def parse_select(statement: str, placeholders: Optional[Dict[str, ast.Expr]] = None, no_placeholders=False) -> ast.Expr: parse_tree = get_parser(statement).select() node = HogQLParseTreeConverter().visit(parse_tree) if placeholders: node = replace_placeholders(node, placeholders) + elif no_placeholders: + assert_no_placeholders(node) return node diff --git a/posthog/hogql/placeholders.py b/posthog/hogql/placeholders.py index a3b543e2a0804..0399f6dc2d498 100644 --- a/posthog/hogql/placeholders.py +++ b/posthog/hogql/placeholders.py @@ -16,3 +16,12 @@ def visit_placeholder(self, node): if node.field in self.placeholders: return self.placeholders[node.field] raise ValueError(f"Placeholder '{node.field}' not found in provided dict: {', '.join(list(self.placeholders))}") + + +def assert_no_placeholders(node: ast.Expr): + AssertNoPlaceholders().visit(node) + + +class AssertNoPlaceholders(EverythingVisitor): + def visit_placeholder(self, node): + raise ValueError(f"Placeholder '{node.field}' not allowed in this context") diff --git a/posthog/hogql/test/test_placeholders.py b/posthog/hogql/test/test_placeholders.py index 4211c00238dfc..b7a8211c2ada4 100644 --- a/posthog/hogql/test/test_placeholders.py +++ b/posthog/hogql/test/test_placeholders.py @@ -1,6 +1,6 @@ from posthog.hogql import ast from posthog.hogql.parser import parse_expr -from posthog.hogql.placeholders import replace_placeholders +from posthog.hogql.placeholders import assert_no_placeholders, replace_placeholders from posthog.test.base import BaseTest @@ -17,6 +17,15 @@ def test_replace_placeholders_simple(self): ast.Constant(value="bar"), ) + def test_replace_placeholders_error(self): + expr = ast.Placeholder(field="foo") + with self.assertRaises(ValueError) as context: + replace_placeholders(expr, {}) + self.assertTrue("Placeholder 'foo' not found in provided dict:" in str(context.exception)) + with self.assertRaises(ValueError) as context: + replace_placeholders(expr, {"bar": ast.Constant(value=123)}) + self.assertTrue("Placeholder 'foo' not found in provided dict: bar" in str(context.exception)) + def test_replace_placeholders_comparison(self): expr = parse_expr("timestamp < {timestamp}") self.assertEqual( @@ -36,3 +45,9 @@ def test_replace_placeholders_comparison(self): right=ast.Constant(value=123), ), ) + + def test_assert_no_placeholders(self): + expr = ast.Placeholder(field="foo") + with self.assertRaises(ValueError) as context: + assert_no_placeholders(expr) + self.assertTrue("Placeholder 'foo' not allowed in this context" in str(context.exception))