Skip to content

Commit

Permalink
Cover C++ unquote_string with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Twixes committed Oct 9, 2023
1 parent 9c7e56c commit f47e72f
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 91 deletions.
83 changes: 50 additions & 33 deletions hogql_parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
return NULL; \
} \
string err_what = e.what(); \
PyObject* py_err = PyObject_CallObject(error_type, Py_BuildValue("(s#)", err_what.c_str(), err_what.size())); \
PyObject* py_err = PyObject_CallObject(error_type, Py_BuildValue("(s#)", err_what.data(), err_what.size())); \
if (!py_err) { \
Py_DECREF(error_type); \
return NULL; \
Expand Down Expand Up @@ -53,7 +53,7 @@ PyObject* X_PyList_FromStrings(const vector<string>& items) {
return NULL;
}
for (size_t i = 0; i < items.size(); i++) {
PyObject* value = PyUnicode_FromStringAndSize(items[i].c_str(), items[i].size());
PyObject* value = PyUnicode_FromStringAndSize(items[i].data(), items[i].size());
if (!value) {
Py_DECREF(list);
return NULL;
Expand Down Expand Up @@ -396,7 +396,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
if (join_op_ctx) {
string join_op = visitAsString(join_op_ctx);
join_op.append(" JOIN");
PyObject_SetAttrString(join2, "join_type", PyUnicode_FromStringAndSize(join_op.c_str(), join_op.size()));
PyObject_SetAttrString(join2, "join_type", PyUnicode_FromStringAndSize(join_op.data(), join_op.size()));
} else {
PyObject_SetAttrString(join2, "join_type", PyUnicode_FromString("JOIN"));
}
Expand Down Expand Up @@ -623,7 +623,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
} else if (ctx->identifier()) {
alias = visitAsString(ctx->identifier());
} else if (ctx->STRING_LITERAL()) {
alias = parse_string_literal(ctx->STRING_LITERAL());
alias = unquote_string_terminal(ctx->STRING_LITERAL());
} else {
throw HogQLParsingException("A ColumnExprAlias must have the alias in some form");
}
Expand All @@ -635,7 +635,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
throw HogQLSyntaxException("\"" + alias + "\" cannot be an alias or identifier, as it's a reserved keyword");
}

return build_ast_node("Alias", "{s:N,s:s#}", "expr", expr, "alias", alias.c_str(), alias.size());
return build_ast_node("Alias", "{s:N,s:s#}", "expr", expr, "alias", alias.data(), alias.size());
}

VISIT_UNSUPPORTED(ColumnExprExtract)
Expand Down Expand Up @@ -834,7 +834,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
VISIT(ColumnExprPropertyAccess) {
PyObject* object = visitAsPyObject(ctx->columnExpr());
string identifier = visitAsString(ctx->identifier());
PyObject* property = build_ast_node("Constant", "{s:s#}", "value", identifier.c_str(), identifier.size());
PyObject* property = build_ast_node("Constant", "{s:s#}", "value", identifier.data(), identifier.size());
return build_ast_node("ArrayAccess", "{s:N,s:N}", "array", object, "property", property);
}

Expand Down Expand Up @@ -936,8 +936,8 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
string over_identifier = visitAsString(ctx->identifier(1));
PyObject* args = visitAsPyObjectOrEmptyList(column_expr_list_ctx);
return build_ast_node(
"WindowFunction", "{s:s#,s:N,s:s#}", "name", name.c_str(), name.size(), "args", args, "over_identifier",
over_identifier.c_str(), over_identifier.size()
"WindowFunction", "{s:s#,s:N,s:s#}", "name", name.data(), name.size(), "args", args, "over_identifier",
over_identifier.data(), over_identifier.size()

);
}
Expand All @@ -948,7 +948,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
PyObject* args = visitAsPyObjectOrEmptyList(column_expr_list_ctx);
PyObject* over_expr = visitAsPyObjectOrNone(ctx->windowExpr());
return build_ast_node(
"WindowFunction", "{s:s#,s:N,s:N}", "name", identifier.c_str(), identifier.size(), "args", args, "over_expr",
"WindowFunction", "{s:s#,s:N,s:N}", "name", identifier.data(), identifier.size(), "args", args, "over_expr",
over_expr
);
}
Expand All @@ -962,7 +962,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
PyObject* args = visitAsPyObjectOrEmptyList(column_arg_list_ctx);
PyObject* distinct = ctx->DISTINCT() ? Py_True : Py_False;
return build_ast_node(
"Call", "{s:s#,s:N,s:N,s:O}", "name", name.c_str(), name.size(), "params", parameters, "args", args, "distinct",
"Call", "{s:s#,s:N,s:N,s:O}", "name", name.data(), name.size(), "params", parameters, "args", args, "distinct",
distinct
);
}
Expand Down Expand Up @@ -1001,23 +1001,23 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
PyObject* subquery = visitAsPyObject(ctx->selectUnionStmt());
string name = visitAsString(ctx->identifier());
return build_ast_node(
"CTE", "{s:s#,s:N,s:s}", "name", name.c_str(), name.size(), "expr", subquery, "cte_type", "subquery"
"CTE", "{s:s#,s:N,s:s}", "name", name.data(), name.size(), "expr", subquery, "cte_type", "subquery"
);
}

VISIT(WithExprColumn) {
PyObject* expr = visitAsPyObject(ctx->columnExpr());
string name = visitAsString(ctx->identifier());
return build_ast_node(
"CTE", "{s:s#,s:N,s:s}", "name", name.c_str(), name.size(), "expr", expr, "cte_type", "column"
"CTE", "{s:s#,s:N,s:s}", "name", name.data(), name.size(), "expr", expr, "cte_type", "column"
);
}

VISIT(ColumnIdentifier) {
auto placeholder_ctx = ctx->PLACEHOLDER();
if (placeholder_ctx) {
string placeholder = parse_string_literal(placeholder_ctx);
return build_ast_node("Placeholder", "{s:s#}", "field", placeholder.c_str(), placeholder.size());
string placeholder = unquote_string_terminal(placeholder_ctx);
return build_ast_node("Placeholder", "{s:s#}", "field", placeholder.data(), placeholder.size());
}

auto table_identifier_ctx = ctx->tableIdentifier();
Expand Down Expand Up @@ -1053,8 +1053,8 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
VISIT(TableExprSubquery) { return visit(ctx->selectUnionStmt()); }

VISIT(TableExprPlaceholder) {
string placeholder = parse_string_literal(ctx->PLACEHOLDER());
return build_ast_node("Placeholder", "{s:s#}", "field", placeholder.c_str(), placeholder.size());
string placeholder = unquote_string_terminal(ctx->PLACEHOLDER());
return build_ast_node("Placeholder", "{s:s#}", "field", placeholder.data(), placeholder.size());
}

VISIT(TableExprAlias) {
Expand All @@ -1065,7 +1065,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
throw HogQLSyntaxException("ALIAS is a reserved keyword");
}
PyObject* table = visitAsPyObject(ctx->tableExpr());
PyObject* py_alias = PyUnicode_FromStringAndSize(alias.c_str(), alias.size());
PyObject* py_alias = PyUnicode_FromStringAndSize(alias.data(), alias.size());
if (is_ast_node_instance(table, "JoinExpr")) {
PyObject_SetAttrString(table, "alias", py_alias);
return table;
Expand All @@ -1085,7 +1085,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
table_args = Py_NewRef(Py_None);
}
return build_ast_node(
"JoinExpr", "{s:N,s:N}", "table", build_ast_node("Field", "{s:[s#]}", "chain", name.c_str(), name.size()),
"JoinExpr", "{s:N,s:N}", "table", build_ast_node("Field", "{s:[s#]}", "chain", name.data(), name.size()),
"table_args", table_args
);
}
Expand All @@ -1112,7 +1112,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
PyObject* result;
if (text.find(".") != string::npos || text.find("e") != string::npos || !text.compare("-inf") ||
!text.compare("inf") || !text.compare("nan")) {
PyObject* pyText = PyUnicode_FromStringAndSize(text.c_str(), text.size());
PyObject* pyText = PyUnicode_FromStringAndSize(text.data(), text.size());
value = PyFloat_FromString(pyText);
result = build_ast_node("Constant", "{s:N}", "value", value);
Py_DECREF(pyText);
Expand All @@ -1130,8 +1130,8 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
}
auto string_literal_terminal = ctx->STRING_LITERAL();
if (string_literal_terminal) {
string text = parse_string_literal(string_literal_terminal);
return build_ast_node("Constant", "{s:s#}", "value", text.c_str(), text.size());
string text = unquote_string_terminal(string_literal_terminal);
return build_ast_node("Constant", "{s:s#}", "value", text.data(), text.size());
}
return visitChildren(ctx);
}
Expand All @@ -1148,7 +1148,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
char first_char = text.front();
char last_char = text.back();
if ((first_char == '`' && last_char == '`') || (first_char == '"' && last_char == '"')) {
return parse_string(text);
return unquote_string(text);
}
}
return text;
Expand All @@ -1160,7 +1160,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
char first_char = text.front();
char last_char = text.back();
if ((first_char == '`' && last_char == '`') || (first_char == '"' && last_char == '"')) {
return parse_string(text);
return unquote_string(text);
}
}
return text;
Expand Down Expand Up @@ -1226,7 +1226,8 @@ parser_state* get_module_state(PyObject* module) {

// MODULE METHODS

static PyObject* parse_expr(PyObject* self, PyObject* args, PyObject* kwargs) {
static PyObject* method_parse_expr(PyObject* self, PyObject* args, PyObject* kwargs) {
parser_state* state = get_module_state(self);
const char* str;
int start; // TODO: Determine if this `start` kwarg of `parse_expr` is needed for anything

Expand All @@ -1239,7 +1240,6 @@ static PyObject* parse_expr(PyObject* self, PyObject* args, PyObject* kwargs) {
HogQLParser parser = get_parser(str);
parser.removeErrorListeners();
parser.addErrorListener(new HogQLErrorListener(str));
parser_state* state = get_module_state(self);
HogQLParser::ExprContext* parse_tree;
try {
parse_tree = parser.expr();
Expand All @@ -1248,15 +1248,15 @@ static PyObject* parse_expr(PyObject* self, PyObject* args, PyObject* kwargs) {
return converter.visitAsPyObjectFinal(parse_tree);
}

static PyObject* parse_order_expr(PyObject* self, PyObject* args) {
static PyObject* method_parse_order_expr(PyObject* self, PyObject* args) {
parser_state* state = get_module_state(self);
const char* str;
if (!PyArg_ParseTuple(args, "s", &str)) {
return NULL;
}
HogQLParser parser = get_parser(str);
parser.removeErrorListeners();
parser.addErrorListener(new HogQLErrorListener(str));
parser_state* state = get_module_state(self);
HogQLParser::OrderExprContext* parse_tree;
try {
parse_tree = parser.orderExpr();
Expand All @@ -1265,15 +1265,15 @@ static PyObject* parse_order_expr(PyObject* self, PyObject* args) {
return converter.visitAsPyObjectFinal(parse_tree);
}

static PyObject* parse_select(PyObject* self, PyObject* args) {
static PyObject* method_parse_select(PyObject* self, PyObject* args) {
parser_state* state = get_module_state(self);
const char* str;
if (!PyArg_ParseTuple(args, "s", &str)) {
return NULL;
}
HogQLParser parser = get_parser(str);
parser.removeErrorListeners();
parser.addErrorListener(new HogQLErrorListener(str));
parser_state* state = get_module_state(self);
HogQLParser::SelectContext* parse_tree;
try {
parse_tree = parser.select();
Expand All @@ -1282,23 +1282,40 @@ static PyObject* parse_select(PyObject* self, PyObject* args) {
return converter.visitAsPyObjectFinal(parse_tree);
}

static PyObject* method_unquote_string(PyObject* self, PyObject* args) {
parser_state* state = get_module_state(self);
const char* str;
if (!PyArg_ParseTuple(args, "s", &str)) {
return NULL;
}
string unquoted_string;
try {
unquoted_string = unquote_string(str);
} catch HANDLE_HOGQL_EXCEPTION(SyntaxException);
return PyUnicode_FromStringAndSize(unquoted_string.data(), unquoted_string.size());
}

// MODULE SETUP

static PyMethodDef parser_methods[] = {
{.ml_name = "parse_expr",
// The cast of the function is necessary since PyCFunction values only take two
// PyObject* parameters, and parse_expr() takes three.
.ml_meth = (PyCFunction)(void (*)(void))parse_expr,
// PyObject* parameters, and method_parse_expr() takes three.
.ml_meth = (PyCFunction)(void (*)(void))method_parse_expr,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Parse the HogQL expression string into an AST"},
{.ml_name = "parse_order_expr",
.ml_meth = parse_order_expr,
.ml_meth = method_parse_order_expr,
.ml_flags = METH_VARARGS,
.ml_doc = "Parse the ORDER BY clause string into an AST"},
{.ml_name = "parse_select",
.ml_meth = parse_select,
.ml_meth = method_parse_select,
.ml_flags = METH_VARARGS,
.ml_doc = "Parse the HogQL SELECT statement string into an AST"},
{.ml_name = "unquote_string",
.ml_meth = method_unquote_string,
.ml_flags = METH_VARARGS,
.ml_doc = "Unquote the string (an identifier or a string literal))"},
{NULL, NULL, 0, NULL}};

static int parser_modexec(PyObject* module) {
Expand Down
2 changes: 2 additions & 0 deletions hogql_parser/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,7 @@
"Operating System :: POSIX :: Linux",
"Programming Language :: Python",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
)
16 changes: 8 additions & 8 deletions hogql_parser/string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

using namespace std;

// TODO: Cover with tests

string parse_string(string text) {
string unquote_string(string text) {
size_t original_text_size = text.size();
if (original_text_size == 0) {
throw HogQLParsingException("Encountered an unexpected empty string input");
Expand All @@ -33,23 +31,25 @@ string parse_string(string text) {
} else {
throw HogQLSyntaxException("Invalid string literal, must start and end with the same quote type: " + text);
}

// Copied from clickhouse_driver/util/escape.py
boost::replace_all(text, "\\a", "\a");
boost::replace_all(text, "\\b", "\b");
boost::replace_all(text, "\\f", "\f");
boost::replace_all(text, "\\r", "\r");
boost::replace_all(text, "\\n", "\n");
boost::replace_all(text, "\\r", "\r");
boost::replace_all(text, "\\t", "\t");
boost::replace_all(text, "\\0", "\0");
boost::replace_all(text, "\\a", "\a");
boost::replace_all(text, "\\v", "\v");
boost::replace_all(text, "\\0", "\0");
boost::replace_all(text, "\\\\", "\\");

return text;
}

string parse_string_literal(antlr4::tree::TerminalNode* node) {
string unquote_string_terminal(antlr4::tree::TerminalNode* node) {
string text = node->getText();
try {
return parse_string(text);
return unquote_string(text);
} catch (HogQLException& e) {
throw HogQLSyntaxException(e.what(), node->getSymbol()->getStartIndex(), node->getSymbol()->getStopIndex() + 1);
}
Expand Down
4 changes: 2 additions & 2 deletions hogql_parser/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@

#include "antlr4-runtime.h"

std::string parse_string(std::string text);
std::string unquote_string(std::string text);

std::string parse_string_literal(antlr4::tree::TerminalNode* node);
std::string unquote_string_terminal(antlr4::tree::TerminalNode* node);
6 changes: 3 additions & 3 deletions posthog/hogql/parse_string.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from antlr4 import ParserRuleContext

from posthog.hogql.errors import HogQLException
from posthog.hogql.errors import SyntaxException


def parse_string(text: str) -> str:
Expand All @@ -22,15 +22,15 @@ def parse_string(text: str) -> str:
text = text.replace("{{", "{")
text = text.replace("\\{", "{")
else:
raise HogQLException(f"Invalid string literal, must start and end with the same quote type: {text}")
raise SyntaxException(f"Invalid string literal, must start and end with the same quote type: {text}")

# copied from clickhouse_driver/util/escape.py
text = text.replace("\\b", "\b")
text = text.replace("\\f", "\f")
text = text.replace("\\r", "\r")
text = text.replace("\\n", "\n")
text = text.replace("\\t", "\t")
text = text.replace("\\0", "\0")
text = text.replace("\\0", "") # Null characters are ignored
text = text.replace("\\a", "\a")
text = text.replace("\\v", "\v")
text = text.replace("\\\\", "\\")
Expand Down
Loading

0 comments on commit f47e72f

Please sign in to comment.