diff --git a/hogql_parser/parser.cpp b/hogql_parser/parser.cpp index 8f90b77c6a77d..4fae49ad24a36 100644 --- a/hogql_parser/parser.cpp +++ b/hogql_parser/parser.cpp @@ -28,10 +28,15 @@ Py_DECREF(error_type); \ return NULL; \ } \ - PyObject_SetAttrString(py_err, "start", PyLong_FromSize_t(e.start)); \ - PyObject_SetAttrString(py_err, "end", PyLong_FromSize_t(e.end)); \ + PyObject* py_start = PyLong_FromSize_t(e.start); \ + PyObject* py_end = PyLong_FromSize_t(e.end); \ + PyObject_SetAttrString(py_err, "start", py_start); \ + PyObject_SetAttrString(py_err, "end", py_end); \ + Py_DECREF(py_start); \ + Py_DECREF(py_end); \ PyErr_SetObject(error_type, py_err); \ Py_DECREF(error_type); \ + Py_DECREF(py_err); \ return NULL; \ } @@ -46,6 +51,13 @@ void X_PyList_Extend(PyObject* list, PyObject* extension) { PyList_SetSlice(list, list_size, list_size + extension_size, extension); } +// Decref all elements of a vector. +void X_Py_DECREF_ALL(vector objects) { + for (PyObject* object : objects) { + Py_DECREF(object); + } +} + // Construct a Python list from a vector of strings. Return value: New reference. PyObject* X_PyList_FromStrings(const vector& items) { PyObject* list = PyList_New(items.size()); @@ -149,9 +161,12 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { if (node.has_value() && node.type() == typeid(PyObject*)) { PyObject* py_node = any_cast(node); if (py_node && is_ast_node_instance(py_node)) { - // FIXME: This is leak, because the value argument is not decref'd. Fix for all PyObject_SetAttrString calls. - PyObject_SetAttrString(py_node, "start", PyLong_FromSize_t(start)); - PyObject_SetAttrString(py_node, "end", PyLong_FromSize_t(stop + 1)); + PyObject* py_start = PyLong_FromSize_t(start); + PyObject* py_end = PyLong_FromSize_t(stop + 1); + PyObject_SetAttrString(py_node, "start", py_start); + PyObject_SetAttrString(py_node, "end", py_end); + Py_DECREF(py_start); + Py_DECREF(py_end); } } return node; @@ -173,7 +188,9 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { PyObject* visitAsPyObject(antlr4::tree::ParseTree* tree) { PyObject* cast_result = any_cast(visit(tree)); if (!cast_result) { - throw runtime_error("Rule resulted in a null PyObject pointer. A Python exception must be set at this point."); + throw HogQLParsingException( + "Rule resulted in a null PyObject pointer. A Python exception must be set at this point." + ); } return cast_result; } @@ -249,10 +266,12 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { X_PyList_Extend(flattened_queries, sub_select_queries); Py_DECREF(sub_select_queries); } else { - Py_DECREF(flattened_queries); // FIXME: Also decref all select_queries items + Py_DECREF(flattened_queries); + X_Py_DECREF_ALL(select_queries); throw HogQLParsingException("Unexpected query node type: " + string(Py_TYPE(query)->tp_name)); } - } // FIXME: Decref all select_queries items + } + X_Py_DECREF_ALL(select_queries); if (PyList_Size(flattened_queries) == 1) { PyObject* query = PyList_GET_ITEM(flattened_queries, 0); Py_INCREF(query); @@ -277,35 +296,45 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { auto window_expr_ctxs = window_clause_ctx->windowExpr(); auto identifier_ctxs = window_clause_ctx->identifier(); if (window_expr_ctxs.size() != identifier_ctxs.size()) { + Py_DECREF(select_query); throw HogQLParsingException("WindowClause must have a matching number of window exprs and identifiers"); } PyObject* window_exprs = PyDict_New(); - PyObject_SetAttrString(select_query, "window_exprs", window_exprs); for (size_t i = 0; i < window_expr_ctxs.size(); i++) { - PyDict_SetItemString( - window_exprs, visitAsString(identifier_ctxs[i]).c_str(), visitAsPyObject(window_expr_ctxs[i]) - ); + PyObject* window_expr = visitAsPyObject(window_expr_ctxs[i]); + PyDict_SetItemString(window_exprs, visitAsString(identifier_ctxs[i]).c_str(), window_expr); + Py_DECREF(window_expr); } + PyObject_SetAttrString(select_query, "window_exprs", window_exprs); + Py_DECREF(window_exprs); } auto limit_and_offset_clause_ctx = ctx->limitAndOffsetClause(); if (limit_and_offset_clause_ctx) { - PyObject_SetAttrString(select_query, "limit", visitAsPyObject(limit_and_offset_clause_ctx->columnExpr(0))); + PyObject* limit = visitAsPyObject(limit_and_offset_clause_ctx->columnExpr(0)); + PyObject_SetAttrString(select_query, "limit", limit); + Py_DECREF(limit); auto offset_ctx = limit_and_offset_clause_ctx->columnExpr(1); if (offset_ctx) { - PyObject_SetAttrString(select_query, "offset", visitAsPyObject(offset_ctx)); + PyObject* offset = visitAsPyObject(offset_ctx); + PyObject_SetAttrString(select_query, "offset", offset); + Py_DECREF(offset); } auto limit_by_exprs_ctx = limit_and_offset_clause_ctx->columnExprList(); if (limit_by_exprs_ctx) { - PyObject_SetAttrString(select_query, "limit_by", visitAsPyObject(limit_by_exprs_ctx)); + PyObject* limit_by_exprs = visitAsPyObject(limit_by_exprs_ctx); + PyObject_SetAttrString(select_query, "limit_by", limit_by_exprs); + Py_DECREF(limit_by_exprs); } if (limit_and_offset_clause_ctx->WITH() && limit_and_offset_clause_ctx->TIES()) { - PyObject_SetAttrString(select_query, "limit_with_ties", Py_NewRef(Py_True)); + PyObject_SetAttrString(select_query, "limit_with_ties", Py_True); } } else { auto offset_only_clause_ctx = ctx->offsetOnlyClause(); if (offset_only_clause_ctx) { - PyObject_SetAttrString(select_query, "offset", visitAsPyObject(offset_only_clause_ctx->columnExpr())); + PyObject* offset_only_clause = visitAsPyObject(offset_only_clause_ctx->columnExpr()); + PyObject_SetAttrString(select_query, "offset", offset_only_clause); + Py_DECREF(offset_only_clause); } } @@ -315,20 +344,20 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { Py_DECREF(select_query); throw HogQLSyntaxException("Using ARRAY JOIN without a FROM clause is not permitted"); } - PyObject_SetAttrString( - select_query, "array_join_op", - PyUnicode_FromString( - array_join_clause_ctx->LEFT() ? "LEFT ARRAY JOIN" - : array_join_clause_ctx->INNER() ? "INNER ARRAY JOIN" - : "ARRAY JOIN" - ) + PyObject* join_op = PyUnicode_FromString( + array_join_clause_ctx->LEFT() ? "LEFT ARRAY JOIN" + : array_join_clause_ctx->INNER() ? "INNER ARRAY JOIN" + : "ARRAY JOIN" ); + PyObject_SetAttrString(select_query, "array_join_op", join_op); + Py_DECREF(join_op); auto array_join_arrays_ctx = array_join_clause_ctx->columnExprList(); PyObject* array_join_list = visitAsPyObject(array_join_arrays_ctx); for (Py_ssize_t i = 0; i < PyList_Size(array_join_list); i++) { PyObject* expr = PyList_GET_ITEM(array_join_list, i); - if (!is_ast_node_instance(expr, "Alias")) { + bool is_alias = is_ast_node_instance(expr, "Alias"); + if (!is_alias) { Py_DECREF(array_join_list); Py_DECREF(select_query); auto relevant_column_expr_ctx = array_join_arrays_ctx->columnExpr(i); @@ -339,12 +368,15 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { } } PyObject_SetAttrString(select_query, "array_join_list", array_join_list); + Py_DECREF(array_join_list); } if (ctx->topClause()) { + Py_DECREF(select_query); throw HogQLNotImplementedException("Unsupported: SelectStmt.topClause()"); } if (ctx->settingsClause()) { + Py_DECREF(select_query); throw HogQLNotImplementedException("Unsupported: SelectStmt.settingsClause()"); } @@ -382,22 +414,29 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { PyObject* join2 = visitAsPyObject(ctx->joinExpr(1)); auto join_op_ctx = ctx->joinOp(); + PyObject* py_join_op; if (join_op_ctx) { - string join_op = visitAsString(join_op_ctx); - join_op.append(" JOIN"); - PyObject_SetAttrString(join2, "join_type", PyUnicode_FromStringAndSize(join_op.data(), join_op.size())); + string join_op = visitAsString(join_op_ctx) + " JOIN"; + py_join_op = PyUnicode_FromStringAndSize(join_op.data(), join_op.size()); } else { - PyObject_SetAttrString(join2, "join_type", PyUnicode_FromString("JOIN")); + py_join_op = PyUnicode_FromString("JOIN"); } - PyObject_SetAttrString(join2, "constraint", visitAsPyObject(ctx->joinConstraintClause())); + PyObject_SetAttrString(join2, "join_type", py_join_op); + Py_DECREF(py_join_op); + PyObject* constraint = visitAsPyObject(ctx->joinConstraintClause()); + PyObject_SetAttrString(join2, "constraint", constraint); + Py_DECREF(constraint); PyObject* last_join = join1; PyObject* next_join = PyObject_GetAttrString(last_join, "next_join"); while (!Py_IsNone(next_join)) { last_join = next_join; + Py_DECREF(next_join); next_join = PyObject_GetAttrString(last_join, "next_join"); } + Py_DECREF(next_join); PyObject_SetAttrString(last_join, "next_join", join2); + Py_DECREF(join2); return join1; } @@ -405,15 +444,14 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { VISIT(JoinExprTable) { PyObject* sample = visitAsPyObjectOrNone(ctx->sampleClause()); PyObject* table = visitAsPyObject(ctx->tableExpr()); - PyObject* table_final = Py_NewRef(ctx->FINAL() ? Py_True : Py_None); + PyObject* table_final = ctx->FINAL() ? Py_True : Py_None; if (is_ast_node_instance(table, "JoinExpr")) { - // visitTableExprAlias returns a JoinExpr to pass the alias - // visitTableExprFunction returns a JoinExpr to pass the args PyObject_SetAttrString(table, "table_final", table_final); PyObject_SetAttrString(table, "sample", sample); + Py_DECREF(sample); return table; } - return build_ast_node("JoinExpr", "{s:N,s:N,s:N}", "table", table, "table_final", table_final, "sample", sample); + return build_ast_node("JoinExpr", "{s:N,s:O,s:N}", "table", table, "table_final", table_final, "sample", sample); } VISIT(JoinExprParens) { return visit(ctx->joinExpr()); } @@ -421,15 +459,20 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { VISIT(JoinExprCrossOp) { PyObject* join1 = visitAsPyObject(ctx->joinExpr(0)); PyObject* join2 = visitAsPyObject(ctx->joinExpr(1)); - PyObject_SetAttrString(join2, "join_type", PyUnicode_FromString("CROSS JOIN")); + PyObject* join_type = PyUnicode_FromString("CROSS JOIN"); + PyObject_SetAttrString(join2, "join_type", join_type); + Py_DECREF(join_type); PyObject* last_join = join1; PyObject* next_join = PyObject_GetAttrString(last_join, "next_join"); while (!Py_IsNone(next_join)) { last_join = next_join; + Py_DECREF(next_join); next_join = PyObject_GetAttrString(last_join, "next_join"); } + Py_DECREF(next_join); PyObject_SetAttrString(last_join, "next_join", join2); + Py_DECREF(join2); return join1; } @@ -506,7 +549,9 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { Py_DECREF(column_expr_list); throw HogQLNotImplementedException("Unsupported: JOIN ... ON with multiple expressions"); } - return build_ast_node("JoinConstraint", "{s:N}", "expr", PyList_GET_ITEM(column_expr_list, 0)); + PyObject* expr = Py_NewRef(PyList_GET_ITEM(column_expr_list, 0)); + Py_DECREF(column_expr_list); + return build_ast_node("JoinConstraint", "{s:N}", "expr", expr); } VISIT(SampleClause) { @@ -532,7 +577,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { } else if (number_literal_ctxs.size() == 0) { throw HogQLParsingException("RatioExpr must have at least one number literal"); } - + auto left_ctx = number_literal_ctxs[0]; auto right_ctx = ctx->SLASH() && number_literal_ctxs.size() > 1 ? number_literal_ctxs[1] : NULL; @@ -548,13 +593,14 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { VISIT(WindowExpr) { auto frame_ctx = ctx->winFrameClause(); PyObject* frame = visitAsPyObjectOrNone(frame_ctx); + PyObject* frame_start = Py_NewRef(PyTuple_Check(frame) ? PyTuple_GetItem(frame, 0) : frame); + PyObject* frame_end = Py_NewRef(PyTuple_Check(frame) ? PyTuple_GetItem(frame, 1) : Py_None); + Py_DECREF(frame); PyObject* partition_by = visitAsPyObjectOrNone(ctx->winPartitionByClause()); PyObject* order_by = visitAsPyObjectOrNone(ctx->winOrderByClause()); PyObject* frame_method = frame_ctx && frame_ctx->RANGE() ? PyUnicode_FromString("RANGE") : frame_ctx && frame_ctx->ROWS() ? PyUnicode_FromString("ROWS") : Py_NewRef(Py_None); - PyObject* frame_start = PyTuple_Check(frame) ? PyTuple_GetItem(frame, 0) : frame; - PyObject* frame_end = PyTuple_Check(frame) ? PyTuple_GetItem(frame, 1) : Py_NewRef(Py_None); return build_ast_node( "WindowExpr", "{s:N,s:N,s:N,s:N,s:N}", "partition_by", partition_by, "order_by", order_by, "frame_method", frame_method, "frame_start", frame_start, "frame_end", frame_end @@ -577,7 +623,9 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { if (ctx->PRECEDING() || ctx->FOLLOWING()) { PyObject* number; if (ctx->numberLiteral()) { - number = PyObject_GetAttrString(visitAsPyObject(ctx->numberLiteral()), "value"); + PyObject* constant = visitAsPyObject(ctx->numberLiteral()); + number = PyObject_GetAttrString(constant, "value"); + Py_DECREF(constant); } else { number = Py_NewRef(Py_None); } @@ -606,7 +654,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { VISIT(ColumnExprTernaryOp) { return build_ast_node( - "Call", "{s:s, s:[O,O,O]}", "name", "if", "args", visitAsPyObject(ctx->columnExpr(0)), + "Call", "{s:s, s:[NNN]}", "name", "if", "args", visitAsPyObject(ctx->columnExpr(0)), visitAsPyObject(ctx->columnExpr(1)), visitAsPyObject(ctx->columnExpr(2)) ); } @@ -690,9 +738,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { PyObject_RichCompareBool(PyObject_GetAttrString(left, "name"), PyUnicode_FromString("concat"), Py_EQ)) { args = PyObject_GetAttrString(left, "args"); } else { - args = PyList_New(1); - PyList_SET_ITEM(args, 0, left); - Py_INCREF(left); // PyList_SET_ITEM doesn't increment refcount, as opposed to PyList_Append + args = Py_BuildValue("[O]", left); } if (is_ast_node_instance(right, "Call") && @@ -846,9 +892,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { if (is_ast_node_instance(left, "And")) { exprs = PyObject_GetAttrString(left, "exprs"); } else { - exprs = PyList_New(1); - PyList_SET_ITEM(exprs, 0, left); - Py_INCREF(left); + exprs = Py_BuildValue("[O]", left); } if (is_ast_node_instance(right, "And")) { PyObject* right_exprs = PyObject_GetAttrString(right, "exprs"); @@ -857,6 +901,8 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { } else { PyList_Append(exprs, right); } + Py_DECREF(right); + Py_DECREF(left); return build_ast_node("And", "{s:N}", "exprs", exprs); } @@ -868,9 +914,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { if (is_ast_node_instance(left, "Or")) { exprs = PyObject_GetAttrString(left, "exprs"); } else { - exprs = PyList_New(1); - PyList_SET_ITEM(exprs, 0, left); - Py_INCREF(left); + exprs = Py_BuildValue("[O]", left); } if (is_ast_node_instance(right, "Or")) { PyObject* right_exprs = PyObject_GetAttrString(right, "exprs"); @@ -879,6 +923,8 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { } else { PyList_Append(exprs, right); } + Py_DECREF(right); + Py_DECREF(left); return build_ast_node("Or", "{s:N}", "exprs", exprs); } @@ -899,21 +945,17 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { size_t columns_size = column_expr_ctx.size(); PyObject* columns = visitPyListOfObjects(column_expr_ctx); if (ctx->caseExpr) { - PyObject* args = PyList_New(4); - PyObject* arg_0 = Py_NewRef(PyList_GetItem(columns, 0)); + PyObject* arg_0 = PyList_GetItem(columns, 0); PyObject* arg_1 = build_ast_node("Array", "{s:[]}", "exprs"); PyObject* arg_2 = build_ast_node("Array", "{s:[]}", "exprs"); - PyObject* arg_3 = Py_NewRef(PyList_GetItem(columns, columns_size - 1)); - PyList_SET_ITEM(args, 0, arg_0); - PyList_SET_ITEM(args, 1, arg_1); - PyList_SET_ITEM(args, 2, arg_2); - PyList_SET_ITEM(args, 3, arg_3); - PyObject* expr_lists[2] = {PyObject_GetAttrString(arg_1, "exprs"), PyObject_GetAttrString(arg_2, "exprs")}; + PyObject* arg_3 = PyList_GetItem(columns, columns_size - 1); + PyObject* args = Py_BuildValue("[ONNO]", arg_0, arg_1, arg_2, arg_3); + PyObject* temp_expr_lists[2] = {PyObject_GetAttrString(arg_1, "exprs"), PyObject_GetAttrString(arg_2, "exprs")}; for (size_t index = 1; index < columns_size - 1; index++) { - PyList_Append(expr_lists[(index - 1) % 2], PyList_GetItem(columns, index)); + PyList_Append(temp_expr_lists[(index - 1) % 2], PyList_GetItem(columns, index)); } - Py_DECREF(expr_lists[0]); - Py_DECREF(expr_lists[1]); + Py_DECREF(temp_expr_lists[0]); + Py_DECREF(temp_expr_lists[1]); Py_DECREF(columns); return build_ast_node("Call", "{s:s,s:N}", "name", "transform", "args", args); } else { @@ -952,13 +994,10 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { VISIT(ColumnExprFunction) { string name = visitAsString(ctx->identifier()); - PyObject* parameters = visitAsPyObjectOrNone(ctx->columnExprList()); - auto column_arg_list_ctx = ctx->columnArgList(); - PyObject* args = visitAsPyObjectOrEmptyList(column_arg_list_ctx); - PyObject* distinct = ctx->DISTINCT() ? Py_True : Py_False; return build_ast_node( - "Call", "{s:s#,s:N,s:N,s:O}", "name", name.data(), name.size(), "params", parameters, "args", args, "distinct", - distinct + "Call", "{s:s#,s:N,s:N,s:O}", "name", name.data(), name.size(), "params", + visitAsPyObjectOrNone(ctx->columnExprList()), "args", visitAsPyObjectOrEmptyList(ctx->columnArgList()), + "distinct", ctx->DISTINCT() ? Py_True : Py_False ); } @@ -987,16 +1026,17 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { PyObject* cte = visitAsPyObject(with_expr_ctx); PyObject* name = PyObject_GetAttrString(cte, "name"); PyDict_SetItem(ctes, name, cte); + Py_DECREF(name); Py_DECREF(cte); } return ctes; } VISIT(WithExprSubquery) { - PyObject* subquery = visitAsPyObject(ctx->selectUnionStmt()); string name = visitAsString(ctx->identifier()); return build_ast_node( - "CTE", "{s:s#,s:N,s:s}", "name", name.data(), name.size(), "expr", subquery, "cte_type", "subquery" + "CTE", "{s:s#,s:N,s:s}", "name", name.data(), name.size(), "expr", visitAsPyObject(ctx->selectUnionStmt()), + "cte_type", "subquery" ); } @@ -1063,6 +1103,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { PyObject* py_alias = PyUnicode_FromStringAndSize(alias.data(), alias.size()); if (is_ast_node_instance(table, "JoinExpr")) { PyObject_SetAttrString(table, "alias", py_alias); + Py_DECREF(py_alias); return table; } return build_ast_node("JoinExpr", "{s:N,s:N}", "table", table, "alias", py_alias); @@ -1165,7 +1206,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor { VISIT(ColumnExprNullish) { return build_ast_node( - "Call", "{s:s, s:[O,O]}", "name", "ifNull", "args", visitAsPyObject(ctx->columnExpr(0)), + "Call", "{s:s, s:[NN]}", "name", "ifNull", "args", visitAsPyObject(ctx->columnExpr(0)), visitAsPyObject(ctx->columnExpr(1)) ); } @@ -1206,13 +1247,6 @@ class HogQLErrorListener : public antlr4::BaseErrorListener { } }; -HogQLParser get_parser(const char* statement) { - auto input_stream = new antlr4::ANTLRInputStream(statement, strnlen(statement, 65536)); - auto lexer = new HogQLLexer(input_stream); - auto stream = new antlr4::CommonTokenStream(lexer); - return HogQLParser(stream); -} - // MODULE STATE parser_state* get_module_state(PyObject* module) { @@ -1221,56 +1255,37 @@ parser_state* get_module_state(PyObject* module) { // MODULE METHODS -static PyObject* method_parse_expr(PyObject* self, PyObject* args) { - parser_state* state = get_module_state(self); - const char* str; - if (!PyArg_ParseTuple(args, "s", &str)) { - return NULL; - } - HogQLParser parser = get_parser(str); - parser.removeErrorListeners(); - parser.addErrorListener(new HogQLErrorListener(str)); - HogQLParser::ExprContext* parse_tree; - try { - parse_tree = parser.expr(); - } catch HANDLE_HOGQL_EXCEPTION(SyntaxException); - HogQLParseTreeConverter converter = HogQLParseTreeConverter(state); - return converter.visitAsPyObjectFinal(parse_tree); -} - -static PyObject* method_parse_order_expr(PyObject* self, PyObject* args) { - parser_state* state = get_module_state(self); - const char* str; - if (!PyArg_ParseTuple(args, "s", &str)) { - return NULL; - } - HogQLParser parser = get_parser(str); - parser.removeErrorListeners(); - parser.addErrorListener(new HogQLErrorListener(str)); - HogQLParser::OrderExprContext* parse_tree; - try { - parse_tree = parser.orderExpr(); - } catch HANDLE_HOGQL_EXCEPTION(SyntaxException); - HogQLParseTreeConverter converter = HogQLParseTreeConverter(state); - return converter.visitAsPyObjectFinal(parse_tree); -} - -static PyObject* method_parse_select(PyObject* self, PyObject* args) { - parser_state* state = get_module_state(self); - const char* str; - if (!PyArg_ParseTuple(args, "s", &str)) { - return NULL; - } - HogQLParser parser = get_parser(str); - parser.removeErrorListeners(); - parser.addErrorListener(new HogQLErrorListener(str)); - HogQLParser::SelectContext* parse_tree; - try { - parse_tree = parser.select(); - } catch HANDLE_HOGQL_EXCEPTION(SyntaxException); - HogQLParseTreeConverter converter = HogQLParseTreeConverter(state); - return converter.visitAsPyObjectFinal(parse_tree); -} +#define METHOD_PARSE_NODE(PascalCase, camelCase, snake_case) \ + static PyObject* method_parse_##snake_case(PyObject* self, PyObject* args) { \ + parser_state* state = get_module_state(self); \ + const char* str; \ + if (!PyArg_ParseTuple(args, "s", &str)) { \ + return NULL; \ + } \ + auto input_stream = new antlr4::ANTLRInputStream(str, strnlen(str, 65536)); \ + auto lexer = new HogQLLexer(input_stream); \ + auto stream = new antlr4::CommonTokenStream(lexer); \ + auto parser = new HogQLParser(stream); \ + parser->removeErrorListeners(); \ + auto error_listener = new HogQLErrorListener(str); \ + parser->addErrorListener(error_listener); \ + HogQLParser::PascalCase##Context* parse_tree; \ + try { \ + parse_tree = parser->camelCase(); \ + } catch HANDLE_HOGQL_EXCEPTION(SyntaxException); \ + HogQLParseTreeConverter converter = HogQLParseTreeConverter(state); \ + PyObject* result = converter.visitAsPyObjectFinal(parse_tree); \ + delete error_listener; \ + delete parser; \ + delete stream; \ + delete lexer; \ + delete input_stream; \ + return result; \ + } + +METHOD_PARSE_NODE(Expr, expr, expr) +METHOD_PARSE_NODE(OrderExpr, orderExpr, order_expr) +METHOD_PARSE_NODE(Select, select, select) static PyObject* method_unquote_string(PyObject* self, PyObject* args) { parser_state* state = get_module_state(self); diff --git a/hogql_parser/setup.py b/hogql_parser/setup.py index 6e9ee91b7475a..1a4ee2dfad726 100644 --- a/hogql_parser/setup.py +++ b/hogql_parser/setup.py @@ -32,7 +32,7 @@ setup( name="hogql_parser", - version="0.1.7", + version="0.1.8", url="https://github.com/PostHog/posthog/tree/master/hogql_parser", author="PostHog Inc.", author_email="hey@posthog.com", diff --git a/posthog/hogql/test/_test_parser.py b/posthog/hogql/test/_test_parser.py index fdd5dd946a2d1..765d4fbaab4de 100644 --- a/posthog/hogql/test/_test_parser.py +++ b/posthog/hogql/test/_test_parser.py @@ -6,11 +6,18 @@ from posthog.hogql.errors import HogQLException, SyntaxException from posthog.hogql.parser import parse_expr, parse_order_expr, parse_select from posthog.hogql.visitor import clear_locations -from posthog.test.base import BaseTest +from posthog.test.base import BaseTest, MemoryLeakTestMixin def parser_test_factory(backend: Literal["python", "cpp"]): - class TestParser(BaseTest): + base_classes = (MemoryLeakTestMixin, BaseTest) if backend == "cpp" else (BaseTest,) + + class TestParser(*base_classes): + MEMORY_INCREASE_PER_PARSE_LIMIT_B = 10_000 + MEMORY_INCREASE_INCREMENTAL_FACTOR_LIMIT = 0.1 + MEMORY_PRIMING_RUNS_N = 2 + MEMORY_LEAK_CHECK_RUNS_N = 100 + maxDiff = None def _expr(self, expr: str, placeholders: Optional[Dict[str, ast.Expr]] = None) -> ast.Expr: diff --git a/posthog/test/base.py b/posthog/test/base.py index 5457bbe4056bc..d76a2e28b732d 100644 --- a/posthog/test/base.py +++ b/posthog/test/base.py @@ -1,6 +1,7 @@ import datetime as dt import inspect import re +import resource import threading import uuid from contextlib import contextmanager @@ -194,6 +195,41 @@ def validate_basic_html(self, html_message, site_url, preheader=None): self.assertIn(preheader, html_message) # type: ignore +class MemoryLeakTestMixin: + MEMORY_INCREASE_PER_PARSE_LIMIT_B: int + """Parsing more than once can never increase memory by this much (on average)""" + MEMORY_INCREASE_INCREMENTAL_FACTOR_LIMIT: float + """Parsing cannot increase memory by more than this factor * priming's increase (on average)""" + MEMORY_PRIMING_RUNS_N: int + """How many times to run every test method to prime the heap""" + MEMORY_LEAK_CHECK_RUNS_N: int + """How many times to run every test method to check for memory leaks""" + + def _callTestMethod(self, method): + mem_original_b = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + for _ in range(self.MEMORY_PRIMING_RUNS_N): # Priming runs + method() + mem_primed_b = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + for _ in range(self.MEMORY_LEAK_CHECK_RUNS_N): # Memory leak check runs + method() + mem_tested_b = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + avg_memory_priming_increase_b = (mem_primed_b - mem_original_b) / self.MEMORY_PRIMING_RUNS_N + avg_memory_test_increase_b = (mem_tested_b - mem_primed_b) / self.MEMORY_LEAK_CHECK_RUNS_N + avg_memory_increase_factor = ( + avg_memory_test_increase_b / avg_memory_priming_increase_b if avg_memory_priming_increase_b else 0 + ) + self.assertLessEqual( # type: ignore + avg_memory_test_increase_b, + self.MEMORY_INCREASE_PER_PARSE_LIMIT_B, + f"Possible memory leak - exceeded {self.MEMORY_INCREASE_PER_PARSE_LIMIT_B}-byte limit of incremental memory per parse", + ) + self.assertLessEqual( # type: ignore + avg_memory_increase_factor, + self.MEMORY_INCREASE_INCREMENTAL_FACTOR_LIMIT, + f"Possible memory leak - exceeded {self.MEMORY_INCREASE_INCREMENTAL_FACTOR_LIMIT*100:.2f}% limit of incremental memory per parse", + ) + + class BaseTest(TestMixin, ErrorResponsesMixin, TestCase): """ Base class for performing Postgres-based backend unit tests on. diff --git a/requirements.in b/requirements.in index ef24be343f467..a1b3cdf01340e 100644 --- a/requirements.in +++ b/requirements.in @@ -87,4 +87,4 @@ django-two-factor-auth==1.14.0 phonenumberslite==8.13.6 openai==0.27.8 nh3==0.2.14 -hogql-parser==0.1.7 +hogql-parser==0.1.8 diff --git a/requirements.txt b/requirements.txt index 6f118bcb411a6..58e89b2e25e52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -248,7 +248,7 @@ gunicorn==20.1.0 # via -r requirements.in h11==0.13.0 # via wsproto -hogql-parser==0.1.7 +hogql-parser==0.1.8 # via -r requirements.in idna==2.8 # via