Skip to content

Commit

Permalink
move parse string to its own home
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra committed Feb 7, 2023
1 parent fc4380e commit 27a0af0
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 7 deletions.
18 changes: 14 additions & 4 deletions posthog/hogql/parser_utils.py → posthog/hogql/parse_string.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
def parse_string_literal(ctx):
"""Converts a string literal received from antlr via ctx.getText() into a Python string"""
text = ctx.getText()
from antlr4 import ParserRuleContext


def parse_string(text: str) -> str:
"""Converts a string received from antlr via ctx.getText() into a Python string"""
if text.startswith("'") and text.endswith("'"):
text = text[1:-1]
text = text.replace("''", "'")
text = text.replace("\\'", "'")
elif text.startswith('"') and text.endswith('"'):
text = text[1:-1]
text = text.replace('""', '"')
text = text.replace('\\"', '"')
elif text.startswith("`") and text.endswith("`"):
text = text[1:-1]
text = text.replace("``", "`")
text = text.replace("\\`", "`")
else:
raise ValueError(f"Invalid string literal, must start and end with the same quote symbol: {text}")

Expand All @@ -22,7 +27,12 @@ def parse_string_literal(ctx):
text = text.replace("\\0", "\0")
text = text.replace("\\a", "\a")
text = text.replace("\\v", "\v")
text = text.replace("\\'", "'")
text = text.replace("\\\\", "\\")

return text


def parse_string_literal(ctx: ParserRuleContext) -> str:
"""Converts a STRING_LITERAL received from antlr via ctx.getText() into a Python string"""
text = ctx.getText()
return parse_string(text)
6 changes: 3 additions & 3 deletions posthog/hogql/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from posthog.hogql import ast
from posthog.hogql.grammar.HogQLLexer import HogQLLexer
from posthog.hogql.grammar.HogQLParser import HogQLParser
from posthog.hogql.parser_utils import parse_string_literal
from posthog.hogql.parse_string import parse_string, parse_string_literal


def parse_expr(expr: str) -> ast.Expr:
Expand Down Expand Up @@ -470,13 +470,13 @@ def visitKeywordForAlias(self, ctx: HogQLParser.KeywordForAliasContext):
def visitAlias(self, ctx: HogQLParser.AliasContext):
text = ctx.getText()
if len(text) >= 2 and text.startswith("`") and text.endswith("`"):
text = parse_string_literal(ctx)
text = parse_string(ctx)
return text

def visitIdentifier(self, ctx: HogQLParser.IdentifierContext):
text = ctx.getText()
if len(text) >= 2 and text.startswith("`") and text.endswith("`"):
text = parse_string_literal(ctx)
text = parse_string(ctx)
return text

def visitIdentifierOrNull(self, ctx: HogQLParser.IdentifierOrNullContext):
Expand Down
41 changes: 41 additions & 0 deletions posthog/hogql/test/test_parse_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from posthog.hogql.parse_string import parse_string
from posthog.test.base import BaseTest


class TestParseString(BaseTest):
def test_quote_types(self):
self.assertEqual(parse_string("`asd`"), "asd")
self.assertEqual(parse_string("'asd'"), "asd")
self.assertEqual(parse_string('"asd"'), "asd")

def test_escaped_quotes(self):
self.assertEqual(parse_string("`a``sd`"), "a`sd")
self.assertEqual(parse_string("'a''sd'"), "a'sd")
self.assertEqual(parse_string('"a""sd"'), 'a"sd')

def test_escaped_quotes_slash(self):
self.assertEqual(parse_string("`a\\`sd`"), "a`sd")
self.assertEqual(parse_string("'a\\'sd'"), "a'sd")
self.assertEqual(parse_string('"a\\"sd"'), 'a"sd')

def test_slash_escape(self):
self.assertEqual(parse_string("`a\nsd`"), "a\nsd")
self.assertEqual(parse_string("`a\\bsd`"), "a\bsd")
self.assertEqual(parse_string("`a\\fsd`"), "a\fsd")
self.assertEqual(parse_string("`a\\rsd`"), "a\rsd")
self.assertEqual(parse_string("`a\\nsd`"), "a\nsd")
self.assertEqual(parse_string("`a\\tsd`"), "a\tsd")
self.assertEqual(parse_string("`a\\0sd`"), "a\0sd")
self.assertEqual(parse_string("`a\\asd`"), "a\asd")
self.assertEqual(parse_string("`a\\vsd`"), "a\vsd")
self.assertEqual(parse_string("`a\\\\sd`"), "a\\sd")

def test_slash_escape_not_escaped(self):
self.assertEqual(parse_string("`a\\xsd`"), "a\\xsd")
self.assertEqual(parse_string("`a\\ysd`"), "a\\ysd")
self.assertEqual(parse_string("`a\\osd`"), "a\\osd")

def test_slash_escape_slash_multiple(self):
self.assertEqual(parse_string("`a\\\\nsd`"), "a\\\nsd")
self.assertEqual(parse_string("`a\\\\n\\sd`"), "a\\\n\\sd")
self.assertEqual(parse_string("`a\\\\n\\\\tsd`"), "a\\\n\\\tsd")

0 comments on commit 27a0af0

Please sign in to comment.