diff --git a/src/optimizer/formatter.py b/src/optimizer/formatter.py index 7737f01..4fb68f8 100644 --- a/src/optimizer/formatter.py +++ b/src/optimizer/formatter.py @@ -62,6 +62,9 @@ def visit_long_literal(self, node, unmangle_names): def visit_double_literal(self, node, unmangle_names): return str(node.value) + def visit_field_type(self, node, unmangle_names): + return str(node) + def visit_time_literal(self, node, unmangle_names): return "TIME '%s'" % node.value diff --git a/src/parser/mysql_parser/lexer.py b/src/parser/mysql_parser/lexer.py index 25fbb2e..540ae95 100644 --- a/src/parser/mysql_parser/lexer.py +++ b/src/parser/mysql_parser/lexer.py @@ -49,6 +49,7 @@ 'SLASH', 'ASTERISK', 'NON_RESERVED', + 'NUMBER', 'QM', 'SCONST', ] @@ -99,13 +100,13 @@ def t_DOUBLE(t): if 'e' in t.value or 'E' in t.value or '.' in t.value: t.type = 'DOUBLE' else: - t.type = "INTEGER" + t.type = "NUMBER" return t -def t_INTEGER(t): +def t_NUMBER(t): r'\d+' - t.type = "INTEGER" + t.type = "NUMBER" return t diff --git a/src/parser/mysql_parser/parser.py b/src/parser/mysql_parser/parser.py index 0e4365c..5e8de37 100644 --- a/src/parser/mysql_parser/parser.py +++ b/src/parser/mysql_parser/parser.py @@ -57,6 +57,7 @@ from src.parser.tree.statement import Delete, Insert, Query, Update from src.parser.tree.table import Table, TableSubquery from src.parser.tree.values import Values +from src.parser.tree.field_type import UNSPECIFIEDLENGTH, FieldType, MySQLType from ply import yacc from src.optimizer.optimizer_enum import IndexType @@ -125,7 +126,7 @@ def p_create_table_end(p): r"""create_table_end : ENGINE EQ identifier create_table_end | DEFAULT CHARSET EQ identifier create_table_end | COLLATE EQ identifier create_table_end - | AUTO_INCREMENT EQ integer create_table_end + | AUTO_INCREMENT EQ number create_table_end | COMMENT EQ SCONST create_table_end | COMPRESSION EQ SCONST create_table_end | empty @@ -160,17 +161,17 @@ def p_column(p): def p_column_type(p): r""" column_type : INT 
column_end - | INT LPAREN integer RPAREN column_end + | INT LPAREN number RPAREN column_end | FLOAT column_end | BIGINT column_end - | BIGINT LPAREN integer RPAREN column_end - | TINYINT LPAREN integer RPAREN column_end + | BIGINT LPAREN number RPAREN column_end + | TINYINT LPAREN number RPAREN column_end | DATETIME column_end - | DATETIME LPAREN integer RPAREN column_end - | VARCHAR LPAREN integer RPAREN column_end - | CHAR LPAREN integer RPAREN column_end + | DATETIME LPAREN number RPAREN column_end + | VARCHAR LPAREN number RPAREN column_end + | CHAR LPAREN number RPAREN column_end | TIMESTAMP column_end - | DECIMAL LPAREN integer COMMA integer RPAREN column_end + | DECIMAL LPAREN number COMMA number RPAREN column_end """ p[0] = p[1].lower() @@ -412,7 +413,7 @@ def p_subquery(p): def p_for_update_opt(p): r"""for_update_opt : FOR UPDATE | FOR UPDATE NOWAIT - | FOR UPDATE WAIT integer + | FOR UPDATE WAIT number | LOCK IN SHARE MODE | empty""" if len(p) == 3: @@ -484,12 +485,12 @@ def p_null_ordering_opt(p): # LIMIT def p_limit_opt(p): - r"""limit_opt : LIMIT integer - | LIMIT integer COMMA integer + r"""limit_opt : LIMIT number + | LIMIT number COMMA number | LIMIT QM | LIMIT QM COMMA QM | LIMIT ALL - | LIMIT integer OFFSET integer + | LIMIT number OFFSET number | empty""" if len(p) < 5: p[0] = (0, p[2]) if p[1] else None @@ -500,8 +501,8 @@ def p_limit_opt(p): p[0] = (p[4], p[2]) -def p_integer(p): - r"""integer : INTEGER""" +def p_number(p): + r"""number : NUMBER""" p[0] = p[1] @@ -1083,7 +1084,7 @@ def p_base_primary_expression(p): def p_value(p): r"""value : NULL | SCONST - | number + | figure | boolean_value | QUOTED_IDENTIFIER | QM""" @@ -1149,7 +1150,7 @@ def p_searched_case(p): def p_cast_specification(p): - r"""cast_specification : CAST LPAREN value_expression AS data_type RPAREN""" + r"""cast_specification : CAST LPAREN expression AS cast_field RPAREN""" p[0] = Cast(p.lineno(1), p.lexpos(1), expression=p[3], data_type=p[5], safe=False) @@ -1182,34 
# NOTE: in PLY the docstring of every ``p_*`` function *is* the grammar rule,
# so all explanatory text below lives in ``#`` comments, never in docstrings.


# Build the FieldType node that describes the target type of
# ``CAST(expr AS <cast_field>)`` and hand it to the parent rule via ``p[0]``.
def p_cast_field(p):
    r"""cast_field : BINARY field_len_opt
                   | char_type field_len_opt field_param_list_opt
                   | DATE
                   | YEAR
                   | DATETIME field_len_opt
                   | DECIMAL float_opt
                   | TIME field_len_opt
                   | SIGNED integer_opt
                   | UNSIGNED integer_opt
                   | JSON
                   | DOUBLE
                   | FLOAT float_opt
                   | REAL"""
    field = FieldType(p.lineno(1), p.lexpos(1))
    # Token type (or non-terminal name) of the first symbol selects the branch.
    # NOTE(review): the lexer also emits a DOUBLE token type for numeric
    # literals containing '.'/'e'; confirm the DOUBLE keyword reaches here.
    symbol = p.slice[1].type

    if symbol == "char_type":
        # p[1] is the keyword as written (CHAR / CHARACTER).
        field.set_tp(MySQLType.CHAR, p[1])
        field.set_length(p[2])
        if p[3] is not None:
            # Stored WITHOUT parentheses: FieldType.__str__ adds them itself.
            field.set_charset_and_collation(",".join(str(item) for item in p[3]))
    elif symbol in ("BINARY", "DATETIME", "TIME"):
        # Types that accept a single optional length / precision.
        field.set_tp(getattr(MySQLType, symbol), symbol)
        field.set_length(p[2])
    elif symbol in ("DECIMAL", "FLOAT"):
        # float_opt yields a dict, so use key access (not attribute access).
        field.set_tp(getattr(MySQLType, symbol), symbol)
        field.set_length(p[2]["length"])
        field.set_decimal(p[2]["decimal"])
    elif symbol in ("SIGNED", "UNSIGNED"):
        # p[2] is the optional INTEGER / INT keyword (None when omitted).
        field.set_tp(MySQLType.INTEGER, p[2])
        field.set_is_signed(symbol == "SIGNED")
    else:
        # Parameterless types: DATE, YEAR, JSON, DOUBLE, REAL.
        field.set_tp(getattr(MySQLType, symbol), symbol)

    # Without this assignment the CAST node would receive data_type=None.
    p[0] = field


# Optional "(N)" length; yields an unsigned 32-bit int or UNSPECIFIEDLENGTH.
def p_field_len_opt(p):
    r"""field_len_opt : LPAREN NUMBER RPAREN
                      | empty"""
    if len(p) == 4:
        # int() accepts both a raw lexeme string and an already-converted int.
        p[0] = int(p[2]) & 0xFFFFFFFF  # convert to unsigned int
    else:
        p[0] = UNSPECIFIEDLENGTH


# Optional parenthesised parameter list; yields the list, or the (falsy)
# value produced by ``empty`` when absent.
def p_field_param_list_opt(p):
    r"""field_param_list_opt : LPAREN field_param_list RPAREN
                             | empty"""
    p[0] = p[2] if p[1] else p[1]


# Comma-separated parameters, accumulated by the shared _item_list helper.
def p_field_param_list(p):
    r"""field_param_list : field_param_list COMMA field_parameter
                         | field_parameter"""
    _item_list(p)


# A single parameter: a number or an identifier-like base data type.
def p_field_parameter(p):
    r"""field_parameter : number
                        | base_data_type"""
    p[0] = p[1]


# Optional "(length)" or "(length, decimal)" for FLOAT / DECIMAL.
# Yields {'length': ..., 'decimal': ...}; missing parts are UNSPECIFIEDLENGTH.
def p_float_opt(p):
    r"""float_opt : LPAREN NUMBER RPAREN
                  | LPAREN NUMBER COMMA NUMBER RPAREN
                  | empty"""
    if len(p) == 6:
        p[0] = {'length': int(p[2]), 'decimal': int(p[4])}
    elif len(p) == 4:
        p[0] = {'length': int(p[2]), 'decimal': UNSPECIFIEDLENGTH}
    else:
        p[0] = {'length': UNSPECIFIEDLENGTH, 'decimal': UNSPECIFIEDLENGTH}


# Character type keyword exactly as written in the statement.
def p_char_type(p):
    r"""char_type : CHARACTER
                  | CHAR"""
    p[0] = p[1]


# Optional INTEGER / INT keyword following SIGNED / UNSIGNED.
def p_integer_opt(p):
    r"""integer_opt : INTEGER
                    | INT
                    | empty"""
    p[0] = p[1]


# Identifier used as a type parameter.
def p_base_data_type(p):
    r"""base_data_type : identifier"""
    p[0] = p[1]


# Date/time pseudo-functions with an optional precision argument.
def p_date_time(p):
    r"""date_time : CURRENT_DATE
                  | CURRENT_TIME number_param_opt
                  | CURRENT_TIMESTAMP number_param_opt
                  | LOCALTIME number_param_opt
                  | LOCALTIMESTAMP number_param_opt"""
    precision = p[2] if len(p) == 3 else None
    p[0] = CurrentTime(p.lineno(1), p.lexpos(1), type=p[1], precision=precision)
from .node import Node

# Sentinel meaning "no length / decimal was given in the SQL text".
UNSPECIFIEDLENGTH = -1


class MySQLType:
    """Integer tags for the target types accepted by ``CAST(... AS <type>)``."""

    BINARY = 1
    CHAR = 2
    DATE = 3
    YEAR = 4
    DATETIME = 5
    DECIMAL = 6
    TIME = 7
    INTEGER = 8
    JSON = 9
    DOUBLE = 10
    FLOAT = 11
    REAL = 12


class FieldType(Node):
    """AST node describing a field (cast target) type.

    Every attribute is initialised to a neutral default so that ``str()``
    is safe on a node whose setters have not (all) been called.
    """

    def __init__(
        self,
        line=None,
        pos=None,
    ) -> None:
        super().__init__(line, pos)
        self.tp = None  # one of the MySQLType constants
        self.type_name = None  # keyword as it should be rendered
        self.length = UNSPECIFIEDLENGTH
        self.decimal = UNSPECIFIEDLENGTH
        self.flag = None
        self.is_signed = True  # only meaningful for MySQLType.INTEGER
        self.charset_and_collation = None

    def set_tp(self, tp, type_name):
        """Set the MySQLType tag and the keyword used when rendering."""
        self.tp = tp
        self.type_name = type_name

    def set_length(self, length):
        """Set the length/precision (``UNSPECIFIEDLENGTH`` when absent)."""
        self.length = length

    def set_decimal(self, decimal):
        """Set the decimal scale (``UNSPECIFIEDLENGTH`` when absent)."""
        self.decimal = decimal

    def set_flag(self, flag):
        """Store an opaque flag value on the node."""
        self.flag = flag

    def set_is_signed(self, is_signed):
        """Record whether an INTEGER type is SIGNED (truthy) or UNSIGNED."""
        self.is_signed = is_signed

    def set_charset_and_collation(self, charset_and_collation):
        """Set the text rendered in parentheses after a CHAR type."""
        self.charset_and_collation = charset_and_collation

    def accept(self, visitor, context):
        """Dispatch to ``visitor.visit_field_type``."""
        return visitor.visit_field_type(self, context)

    @staticmethod
    def _is_specified(value):
        """Return True when *value* carries a real length/decimal."""
        return value is not None and value != UNSPECIFIEDLENGTH

    def __str__(self):
        """Render the type as SQL text, e.g. ``CHAR (10)`` or ``SIGNED INT``."""
        result = self.type_name if self.type_name is not None else ""
        # Instance state must be inspected directly; checking dir(FieldType)
        # looks at the class and never sees values stored by the setters.
        has_length = self._is_specified(self.length)
        has_decimal = self._is_specified(self.decimal)

        if has_length and self.tp in (
            MySQLType.BINARY,
            MySQLType.CHAR,
            MySQLType.TIME,
            MySQLType.DATETIME,
        ):
            result += f" ({self.length})"

        if self.tp == MySQLType.CHAR and self.charset_and_collation is not None:
            result += f" ({self.charset_and_collation})"

        if self.tp == MySQLType.INTEGER:
            sign = "SIGNED" if self.is_signed else "UNSIGNED"
            # Avoid a trailing space when no INT/INTEGER keyword was given.
            result = f"{sign} {result}" if result else sign

        if has_length and self.tp in (MySQLType.FLOAT, MySQLType.DECIMAL):
            if has_decimal:
                result += f" ({self.length},{self.decimal})"
            else:
                result += f" ({self.length})"

        return result
src.parser.mysql_parser.lexer import lexer from src.parser.parser_utils import ParserUtils @@ -123,7 +124,7 @@ def setUpClass(self): self.catalog_object = MetaDataUtils.json_to_catalog( json.loads(self.catalog_json) ) - visitor = ParserUtils.format_statement(parser.parse(self.sql)) + visitor = ParserUtils.format_statement(parser.parse(self.sql, lexer=lexer)) self.table_list = visitor.table_list self.projection_column_list = visitor.projection_column_list self.order_list = visitor.order_list diff --git a/test/optimizer/test_formatter.py b/test/optimizer/test_formatter.py index 1f53fd7..0808d1f 100644 --- a/test/optimizer/test_formatter.py +++ b/test/optimizer/test_formatter.py @@ -2,6 +2,7 @@ from src.optimizer.formatter import format_sql from src.parser.mysql_parser.parser import parser +from src.parser.mysql_parser.lexer import lexer class MyTestCase(unittest.TestCase): @@ -10,7 +11,8 @@ def test_union_all(self): """ SELECT * FROM T1 WHERE C1 < 20000 UNION ALL SELECT * FROM T1 WHERE C2 < 30 AND LNNVL (C1 < 20000) - """ + """, + lexer=lexer, ) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -30,7 +32,8 @@ def test_union(self): """ SELECT * FROM T1 WHERE C1 < 20000 UNION SELECT * FROM T1 WHERE C2 < 30 - """ + """, + lexer=lexer, ) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -46,7 +49,7 @@ def test_union(self): ) def test_as(self): - statement = parser.parse("""SELECT a.* FROM d1 as a""") + statement = parser.parse("""SELECT a.* FROM d1 as a""", lexer=lexer) after_sql_rewrite_format = format_sql(statement, 0) assert ( after_sql_rewrite_format @@ -56,11 +59,11 @@ def test_as(self): ) def test_update(self): - statement = parser.parse("""update t set a = 1,b = 2 where c= 3""") + statement = parser.parse("""update t set a = 1,b = 2 where c= 3""", lexer=lexer) after_sql_rewrite_format = format_sql(statement, 0) assert after_sql_rewrite_format == """UPDATE t SET a = 1 , b = 2 WHERE c = 3""" statement = parser.parse( - """update t set a = 
1,b = 2 where c= 3 order by c limit 1""" + """update t set a = 1,b = 2 where c= 3 order by c limit 1""", lexer=lexer ) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -71,12 +74,14 @@ def test_update(self): ) def test_delete(self): - statement = parser.parse("""delete from t where c= 3 and a = 1""") + statement = parser.parse("""delete from t where c= 3 and a = 1""", lexer=lexer) after_sql_rewrite_format = format_sql(statement, 0) assert after_sql_rewrite_format == """DELETE FROM t WHERE c = 3 AND a = 1""" statement = parser.parse( - """delete from t where c= 3 and a = 1 order by c limit 1""" + """delete from t where c= 3 and a = 1 order by c limit 1""", + lexer=lexer, + debug=True, ) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -88,7 +93,8 @@ def test_delete(self): def test_sql_1(self): statement = parser.parse( - """select tnt_inst_id as tnt_inst_id,gmt_create as gmt_create,gmt_modified as gmt_modified,principal_id as principal_id,version as version from cu_version_control where (principal_id = 'TOKENREL|100100000003358587777|IPAY_HK' )""" + """select tnt_inst_id as tnt_inst_id,gmt_create as gmt_create,gmt_modified as gmt_modified,principal_id as principal_id,version as version from cu_version_control where (principal_id = 'TOKENREL|100100000003358587777|IPAY_HK' )""", + lexer=lexer, ) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -108,7 +114,8 @@ def test_subquery_limit(self): statement = parser.parse( """ SELECT COUNT(*) FROM ( SELECT * FROM customs_script_match_history LIMIT ? 
) a - """ + """, + lexer=lexer, ) after_sql_rewrite_format = format_sql(statement, 0) assert ( diff --git a/test/optimizer/test_pmd.py b/test/optimizer/test_pmd.py index b5e980c..fef92d3 100644 --- a/test/optimizer/test_pmd.py +++ b/test/optimizer/test_pmd.py @@ -11,67 +11,78 @@ PMDMultiTableRule, ) from src.parser.mysql_parser.parser import parser -from src.parser.mysql_parser import lexer +from src.parser.mysql_parser.lexer import lexer class MyTestCase(unittest.TestCase): def test_pmd_select_all_rule_match(self): - statement = parser.parse("SELECT * FROM T1 WHERE C1 < 20000 OR C2 < 30") + statement = parser.parse( + "SELECT * FROM T1 WHERE C1 < 20000 OR C2 < 30", lexer=lexer + ) pmd_result = PMDSelectAllRule().match(statement) assert pmd_result is True - statement = parser.parse("SELECT a FROM T1 WHERE C1 < 20000 OR C2 < 30") + statement = parser.parse( + "SELECT a FROM T1 WHERE C1 < 20000 OR C2 < 30", lexer=lexer + ) pmd_result = PMDSelectAllRule().match(statement) assert pmd_result is False - statement = parser.parse("SELECT a.* FROM T1 a WHERE C1 < 20000 OR C2 < 30") + statement = parser.parse( + "SELECT a.* FROM T1 a WHERE C1 < 20000 OR C2 < 30", lexer=lexer + ) pmd_result = PMDSelectAllRule().match(statement) assert pmd_result is True statement = parser.parse( - "SELECT a.* , a.b FROM T1 a WHERE C1 < 20000 OR C2 < 30" + "SELECT a.* , a.b FROM T1 a WHERE C1 < 20000 OR C2 < 30", lexer=lexer ) pmd_result = PMDSelectAllRule().match(statement) assert pmd_result is True def test_pmd_select_all_rule(self): - statement = parser.parse("SELECT * FROM T1 WHERE C1 < 20000 OR C2 < 30") + statement = parser.parse( + "SELECT * FROM T1 WHERE C1 < 20000 OR C2 < 30", lexer=lexer + ) pmd_result = PMDSelectAllRule().match_action(statement) assert pmd_result is not None def test_pmd_full_scan_rule(self): - statement = parser.parse("select 1 from a") + statement = parser.parse("select 1 from a", lexer=lexer) match = PMDFullScanRule().match(statement) assert match - statement = 
parser.parse("select 1 from a where b != 1") + statement = parser.parse("select 1 from a where b != 1", lexer=lexer) match = PMDFullScanRule().match(statement) assert match - statement = parser.parse("select 1 from a where b <> 1") + statement = parser.parse("select 1 from a where b <> 1", lexer=lexer) match = PMDFullScanRule().match(statement) assert match - statement = parser.parse("select 1 from a where b not like '1%' ") + statement = parser.parse("select 1 from a where b not like '1%' ", lexer=lexer) match = PMDFullScanRule().match(statement) assert match - statement = parser.parse("select 1 from a where b not in (1) ") + statement = parser.parse("select 1 from a where b not in (1) ", lexer=lexer) match = PMDFullScanRule().match(statement) assert match - statement = parser.parse("select 1 from a where not exists (select 1 from a) ") + statement = parser.parse( + "select 1 from a where not exists (select 1 from a) ", lexer=lexer + ) match = PMDFullScanRule().match(statement) assert match statement = parser.parse( - "select 1 from a where not exists (select 1 from a where c = 2) " + "select 1 from a where not exists (select 1 from a where c = 2) ", + lexer=lexer, ) match = PMDFullScanRule().match(statement) assert match - statement = parser.parse("select 1 from a where b like '%a' ") + statement = parser.parse("select 1 from a where b like '%a' ", lexer=lexer) match = PMDFullScanRule().match(statement) assert match - statement = parser.parse("select 1 from a where b like '%a%' ") + statement = parser.parse("select 1 from a where b like '%a%' ", lexer=lexer) match = PMDFullScanRule().match(statement) assert match @@ -79,13 +90,14 @@ def test_pmd_full_scan_rule(self): """SELECT * FROM product LEFT JOIN product_details ON (product.id = product_details.id) AND product.amount=200 -""" +""", + lexer=lexer, ) match = PMDFullScanRule().match(statement) assert not match statement = parser.parse( - """select 1 from a where b like '%a%' and c BETWEEN 1 AND 20""" + 
"""select 1 from a where b like '%a%' and c BETWEEN 1 AND 20""", lexer=lexer ) match = PMDFullScanRule().match(statement) assert not match @@ -95,7 +107,8 @@ def test_update_delete(self): """update sqless_base set nick=1231 where - a = 1""" + a = 1""", + lexer=lexer, ) match = PMDFullScanRule().match(statement) assert not match @@ -103,41 +116,50 @@ def test_update_delete(self): """delete from sqless_base where - a = 1""" + a = 1""", + lexer=lexer, ) match = PMDFullScanRule().match(statement) assert not match def test_is_null(self): - statement = parser.parse("select * from sqless_base where a is null") + statement = parser.parse( + "select * from sqless_base where a is null", lexer=lexer + ) match = PMDIsNullRule().match(statement) assert not match - statement = parser.parse("select * from sqless_base where a = null") + statement = parser.parse( + "select * from sqless_base where a = null", lexer=lexer + ) match = PMDIsNullRule().match(statement) assert match def test_count(self): - statement = parser.parse("select count(a) from sqless_base") + statement = parser.parse("select count(a) from sqless_base", lexer=lexer) match = PMDCountRule().match(statement) assert match - statement = parser.parse("select count(1) from sqless_base") + statement = parser.parse( + "select count(1) from sqless_base", debug=True, lexer=lexer + ) match = PMDCountRule().match(statement) assert match - statement = parser.parse("select count(DISTINCT a) from sqless_base") + statement = parser.parse( + "select count(DISTINCT a) from sqless_base", lexer=lexer + ) match = PMDCountRule().match(statement) assert match - statement = parser.parse("select count(*) from sqless_base") + statement = parser.parse("select count(*) from sqless_base", lexer=lexer) match = PMDCountRule().match(statement) assert not match def test_arithmetic_binary(self): statement = parser.parse( - "select count(a) from sqless_base where a * 2 > 1", lexer=lexer.lexer + "select count(a) from sqless_base where a * 2 > 1", 
lexer=lexer ) match = PMDArithmeticRule().match(statement) assert match statement = parser.parse( - "select count(1) from sqless_base where a > 1 * 2", lexer=lexer.lexer + "select count(1) from sqless_base where a > 1 * 2", lexer=lexer ) match = PMDArithmeticRule().match(statement) assert not match @@ -147,7 +169,8 @@ def test_update_delete_multi_table(self): """DELETE FROM Product P LEFT JOIN OrderItem I ON P.Id = I.ProductId - WHERE I.Id IS NULL""" + WHERE I.Id IS NULL""", + lexer=lexer, ) match = PMDUpdateDeleteMultiTableRule().match(statement) assert match @@ -159,7 +182,8 @@ def test_update_delete_multi_table(self): SET o.total_orders = 7 ,item= 'pendrive' WHERE o.order_id = 1 - AND order_detail_id = 1""" + AND order_detail_id = 1""", + lexer=lexer, ) match = PMDUpdateDeleteMultiTableRule().match(statement) assert match @@ -168,21 +192,24 @@ def test_nowait_or_wait(self): statement = parser.parse( """ SELECT * FROM match_record_id FOR UPDATE - """ + """, + lexer=lexer, ) match = PMDNowaitWaitRule().match(statement) assert match statement = parser.parse( """ SELECT * FROM match_record_id FOR UPDATE NOWAIT - """ + """, + lexer=lexer, ) match = PMDNowaitWaitRule().match(statement) assert not match statement = parser.parse( """ SELECT * FROM match_record_id FOR UPDATE WAIT 1 - """ + """, + lexer=lexer, ) match = PMDNowaitWaitRule().match(statement) assert not match @@ -196,7 +223,8 @@ def test_multi_table(self): JOIN suppliers s ON p.supplier_id = s.supplier_id JOIN orders o ON p.product_id = o.product_id WHERE o.order_date BETWEEN '2022-01-01' AND '2022-12-31' -ORDER BY o.order_date DESC""" +ORDER BY o.order_date DESC""", + lexer=lexer, ) match = PMDMultiTableRule().match(statement) assert match @@ -208,7 +236,8 @@ def test_multi_table(self): JOIN products p ON o.product_id = p.product_id WHERE o.order_date BETWEEN '2022-01-01' AND '2022-12-31' ORDER BY o.order_date DESC -""" +""", + lexer=lexer, ) match = PMDMultiTableRule().match(statement) assert not match 
diff --git a/test/optimizer/test_rewrite.py b/test/optimizer/test_rewrite.py index f8b1c0e..3280a05 100644 --- a/test/optimizer/test_rewrite.py +++ b/test/optimizer/test_rewrite.py @@ -9,11 +9,14 @@ RewriteSupplementColumnRule, ) from src.parser.mysql_parser.parser import parser +from src.parser.mysql_parser.lexer import lexer class MyTestCase(unittest.TestCase): def test_or(self): - statement = parser.parse("SELECT * FROM T1 WHERE C1 < 20000 OR C2 < 30") + statement = parser.parse( + "SELECT * FROM T1 WHERE C1 < 20000 OR C2 < 30", lexer=lexer + ) RewriteMySQLORRule().match_action(statement) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -29,7 +32,7 @@ def test_or(self): ) def test_supplement_column_rewrite(self): - statement = parser.parse("SELECT * FROM sqless_base") + statement = parser.parse("SELECT * FROM sqless_base", lexer=lexer) catalog_json = """ {"columns": [{"schema":"sqless_test","table":"sqless_base", "name":"a","type":"int(2)","nullable":false},{"schema":"sqless_test","table":"sqless_base", @@ -51,7 +54,7 @@ def test_supplement_column_rewrite(self): ) def test_supplement_column_rewrite_rule_match(self): - statement = parser.parse("SELECT * FROM sqless_base") + statement = parser.parse("SELECT * FROM sqless_base", lexer=lexer) catalog_json = """ {"columns": [{"schema":"sqless_test","table":"sqless_base", "name":"a","type":"int(2)","nullable":false},{"schema":"sqless_test","table":"sqless_base", @@ -63,7 +66,7 @@ def test_supplement_column_rewrite_rule_match(self): catalog_object = MetaDataUtils.json_to_catalog(json.loads(catalog_json)) match = RewriteSupplementColumnRule().match(statement, catalog_object) assert match is True - statement = parser.parse("SELECT a FROM sqless_base") + statement = parser.parse("SELECT a FROM sqless_base", lexer=lexer) match = RewriteSupplementColumnRule().match(statement, catalog_object) assert match is False @@ -89,7 +92,7 @@ def test_completion_column(self): } """ catalog_object = 
MetaDataUtils.json_to_catalog(json.loads(catalog_json)) - statement = parser.parse("SELECT * FROM d1") + statement = parser.parse("SELECT * FROM d1", lexer=lexer) RewriteSupplementColumnRule().match_action(statement, catalog_object) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -100,7 +103,7 @@ def test_completion_column(self): FROM d1""" ) - statement = parser.parse("SELECT a.* FROM d1 a") + statement = parser.parse("SELECT a.* FROM d1 a", lexer=lexer) RewriteSupplementColumnRule().match_action(statement, catalog_object) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -111,7 +114,7 @@ def test_completion_column(self): FROM d1 a""" ) - statement = parser.parse("SELECT a.* FROM d1 as a") + statement = parser.parse("SELECT a.* FROM d1 as a", lexer=lexer) RewriteSupplementColumnRule().match_action(statement, catalog_object) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -122,7 +125,7 @@ def test_completion_column(self): FROM d1 AS a""" ) - statement = parser.parse("SELECT c.* , d2.b FROM a.d1 c,d2") + statement = parser.parse("SELECT c.* , d2.b FROM a.d1 c,d2", lexer=lexer) RewriteSupplementColumnRule().match_action(statement, catalog_object) after_sql_rewrite_format = format_sql(statement, 0) assert ( @@ -142,22 +145,30 @@ def test_or_same_column(self): T1 WHERE C1 IN (20000, 30)""" - statement = parser.parse("SELECT * FROM T1 WHERE C1 = 20000 OR C1 = 30") + statement = parser.parse( + "SELECT * FROM T1 WHERE C1 = 20000 OR C1 = 30", lexer=lexer + ) RewriteMySQLORRule().match_action(statement) result = format_sql(statement, 0) assert result == after_sql_rewrite_format - statement = parser.parse("SELECT * FROM T1 WHERE C1 in (20000) OR C1 = 30") + statement = parser.parse( + "SELECT * FROM T1 WHERE C1 in (20000) OR C1 = 30", lexer=lexer + ) RewriteMySQLORRule().match_action(statement) result = format_sql(statement, 0) assert result == after_sql_rewrite_format - statement = parser.parse("SELECT * FROM T1 WHERE C1 in 
(20000) OR C1 in (30)") + statement = parser.parse( + "SELECT * FROM T1 WHERE C1 in (20000) OR C1 in (30)", lexer=lexer + ) RewriteMySQLORRule().match_action(statement) result = format_sql(statement, 0) assert result == after_sql_rewrite_format - statement = parser.parse("SELECT * FROM T1 WHERE C1 in (20000) OR C2 in (30)") + statement = parser.parse( + "SELECT * FROM T1 WHERE C1 in (20000) OR C2 in (30)", lexer=lexer + ) RewriteMySQLORRule().match_action(statement) result = format_sql(statement, 0) assert ( @@ -173,7 +184,9 @@ def test_or_same_column(self): ) def test_like(self): - statement = parser.parse("select * from sqless_base where d like 'a%'") + statement = parser.parse( + "select * from sqless_base where d like 'a%'", lexer=lexer + ) catalog_json = """ { "columns": @@ -209,7 +222,8 @@ def test_like(self): def test_qm(self): statement = parser.parse( """select * FROM cm_relation WHERE status = ? - AND primary_id = ? AND rel_type = ? AND rel_biz_type = ?""" + AND primary_id = ? AND rel_type = ? AND rel_biz_type = ?""", + lexer=lexer, ) catalog_json = """ {"columns": [{"schema":"luli1","table":"cm_relation", @@ -285,20 +299,22 @@ def test_qm(self): ) def test_delete_update_order(self): - statement = parser.parse("""delete from tbl where col1 = ? order by col""") + statement = parser.parse( + """delete from tbl where col1 = ? order by col""", lexer=lexer + ) match = RemoveOrderByInDeleteUpdateRule().match(statement, None) assert match RemoveOrderByInDeleteUpdateRule().match_action(statement, None) after_sql_rewrite_format = format_sql(statement, 0) assert after_sql_rewrite_format == """DELETE FROM tbl WHERE col1 = ?""" statement = parser.parse( - """delete from tbl where col1 = ? order by col limit 1""" + """delete from tbl where col1 = ? order by col limit 1""", lexer=lexer ) match = RemoveOrderByInDeleteUpdateRule().match(statement, None) assert not match statement = parser.parse( - """update tbl set col1 = ? where col2 = ? 
order by col""" + """update tbl set col1 = ? where col2 = ? order by col""", lexer=lexer ) match = RemoveOrderByInDeleteUpdateRule().match(statement, None) assert match @@ -306,14 +322,16 @@ def test_delete_update_order(self): after_sql_rewrite_format = format_sql(statement, 0) assert after_sql_rewrite_format == """UPDATE tbl SET col1 = ? WHERE col2 = ?""" statement = parser.parse( - """update tbl set col1 = ? where col2 = ? order by col limit 1""" + """update tbl set col1 = ? where col2 = ? order by col limit 1""", + lexer=lexer, ) match = RemoveOrderByInDeleteUpdateRule().match(statement, None) assert not match def test_subquery_or(self): statement = parser.parse( - "SELECT t1.* FROM t1 WHERE t1.c1 IN (?) AND t1.c2 = ? AND t1.c3 > ? and c4 not in (select t2.c5 from t2 where t2.c5 = ? or t2.c5 = ?)" + "SELECT t1.* FROM t1 WHERE t1.c1 IN (?) AND t1.c2 = ? AND t1.c3 > ? and c4 not in (select t2.c5 from t2 where t2.c5 = ? or t2.c5 = ?)", + lexer=lexer, ) is_match = RewriteMySQLORRule().match(statement) assert not is_match diff --git a/test/parser/test_parser_ddl.py b/test/parser/test_parser_ddl.py index 0d6f700..28e356f 100644 --- a/test/parser/test_parser_ddl.py +++ b/test/parser/test_parser_ddl.py @@ -15,6 +15,7 @@ import unittest from src.parser.mysql_parser.parser import parser +from src.parser.mysql_parser.lexer import lexer class MyTestCase(unittest.TestCase): @@ -39,7 +40,8 @@ def test_create_table(self): UNIQUE KEY `sql_id` (`cluster`, `tenant_name`, `sql_id`), KEY `pure_dbname` (`pure_dbname`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8 - """ + """, + lexer=lexer, ) assert result['index_list'][0][0].value == '1.primary' assert result['index_list'][0][1] == 'PRIMARY' @@ -173,7 +175,8 @@ def test_create_table2(self): KEY `i_type_state_id` (`occupy_type`,`order_state`,`ticket_machine_id`,`merchant_id`,`env`,`gmt_distribute`), KEY `I_machine_id_state_type_env_created` (`ticket_machine_id`,`merchant_id`,`order_state`,`occupy_type`,`env`,`gmt_created`) ) ENGINE=InnoDB 
DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin AUTO_INCREMENT=3529344021213246782 COMMENT='订单表:订单相关信息存储' - """ + """, + lexer=lexer, ) assert len(result['index_list']) == 29 assert result['index_list'][0][0].value == '1.primary' diff --git a/test/parser/test_parser_dml.py b/test/parser/test_parser_dml.py index f091367..6fed7d0 100644 --- a/test/parser/test_parser_dml.py +++ b/test/parser/test_parser_dml.py @@ -177,6 +177,17 @@ def test_mysql_contact_function(self): result = mysql_parser.parse(sql, lexer=mysql_lexer) assert isinstance(result, Statement) + def test_mysql_cast_function(self): + test_sqls = [ + "SELECT CAST(CAST(1+2 AS TIME) AS JSON)", + "SELECT CAST((8+1) AS SIGNED)", + "SELECT CAST(9 AS TIME)", + "SELECT CAST(1 BETWEEN 1 AND 2 AS SIGNED)", + ] + for sql in test_sqls: + result = mysql_parser.parse(sql, lexer=mysql_lexer, debug=True) + assert isinstance(result, Statement) + if __name__ == "__main__": unittest.main() diff --git a/test/parser/test_parser_utils.py b/test/parser/test_parser_utils.py index caf7641..6e6aef7 100644 --- a/test/parser/test_parser_utils.py +++ b/test/parser/test_parser_utils.py @@ -16,6 +16,7 @@ from src.optimizer.formatter import format_sql from src.parser.mysql_parser.parser import parser +from src.parser.mysql_parser.lexer import lexer from src.parser.parser_utils import ParserUtils @@ -26,7 +27,7 @@ def test_get_filter_column(self): "where a.b = 1 and b.c = 2 group by name,age " "having count(*)>2 and avg(age)<20 order by a asc,b desc limit 1,10" ) - visitor = ParserUtils.format_statement(parser.parse(sql)) + visitor = ParserUtils.format_statement(parser.parse(sql, lexer=lexer)) table_list = visitor.table_list projection_column_list = visitor.projection_column_list order_list = visitor.order_list @@ -117,7 +118,7 @@ def test_get_filter_column2(self): tars_sqldiag_all.tenant_name, tars_sqldiag_all.sql_id, tars_sqldiag_all.diag_type """ - visitor = ParserUtils.format_statement(parser.parse(sql)) + visitor = 
ParserUtils.format_statement(parser.parse(sql, lexer=lexer)) table_list = visitor.table_list projection_column_list = visitor.projection_column_list order_list = visitor.order_list @@ -188,7 +189,7 @@ def test_parameterized_query(self): "where a.b = 1 and b.c = 2 and a.d in ('2','3','6') group by name,age " "having count(*)>2 and avg(age)<20 order by a asc,b desc limit 3,10" ) - statement_node = ParserUtils.parameterized_query(parser.parse(sql)) + statement_node = ParserUtils.parameterized_query(parser.parse(sql, lexer=lexer)) format_sql(statement_node, 0) def test_parameterized_query2(self): @@ -204,19 +205,19 @@ def test_parameterized_query2(self): when -2 then 15 else merge_record.merge_result END as merge_result, server_release_repo.completed, server_release_repo.create_time, server_release_repo.update_time FROM server_release_repo left join merge_record on server_release_repo.merge_record_id = merge_record.id WHERE 1 = 1 and integrate = 0 and completed = 1 and deleted = 0 and merge_record_id != -1 """ - statement_node = ParserUtils.parameterized_query(parser.parse(sql)) + statement_node = ParserUtils.parameterized_query(parser.parse(sql, lexer=lexer)) format_sql(statement_node, 0) def test_parameterized_query3(self): sql = """select id,gmt_create,gmt_modified,proj_code,matter_code,'ATUSER' act_type,content,operator,operator_no,status,biz_code,biz_id,content_detail from lc_opr_biz_activity where id in ( select max(id) id from lc_opr_biz_activity t1 join ( select act_type ,biz_activity_id,task_id from lc_opr_schedule where user_id = '291909' and status = '00' and act_type = 'ATUSER' and matter_code = 'M210713P0689I00007' and task_id is not null ) t2 on t1.id = t2.biz_activity_id group by t2.task_id ) union select id,gmt_create,gmt_modified,proj_code,matter_code,act_type,content,operator,operator_no,status,biz_code,biz_id,content_detail from lc_opr_biz_activity where id in( select max(id) id from lc_opr_biz_activity where matter_code = 'M210713P0689I00007' and 
biz_code = 'TASK' and biz_id not in ( select distinct task_id from lc_opr_schedule where user_id = '291909' and status = '00' and act_type = 'ATUSER' and matter_code = 'M210713P0689I00007' and task_id is not null ) group by biz_id )""" - statement_node = ParserUtils.parameterized_query(parser.parse(sql)) + statement_node = ParserUtils.parameterized_query(parser.parse(sql, lexer=lexer)) format_sql(statement_node, 0) def test_subquery_expression(self): sql = """ SELECT COUNT(*) FROM ( SELECT * FROM customs_script_match_history LIMIT ? ) a """ - ParserUtils.format_statement(parser.parse(sql)) + ParserUtils.format_statement(parser.parse(sql, lexer=lexer)) def test_sql_1(self): sql = """ @@ -251,11 +252,11 @@ def test_sql_1(self): AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? 
) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? 
AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) UNION SELECT t1.idc, t2.ds AS ds, SUM(t2.yhat) AS disk FROM `sync_mt_mysql_meta` t1, space_used_forecast_per_inst t2 WHERE t2.node_name = t1.node AND t2.gmt_create = ? AND t1.idc IS NOT NULL AND t1.cluster_name NOT IN (?) AND t2.ds = ? AND t1.idc IN (?) AND t1.nc_ip NOT IN ( SELECT DISTINCT ip FROM yusuan_unires_docker_nc_host WHERE pool LIKE ? ) """ - ParserUtils.format_statement(parser.parse(sql)) + ParserUtils.format_statement(parser.parse(sql, lexer=lexer)) def test_recursion_error(self): sql = """SELECT id, `table_name`, version, primary_id, template , template_md5, security_level, `nullable`, status, `param_group` , description, operator, global_id, govern_type, utc_create , utc_modified FROM param_template WHERE (table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? 
AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? 
OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? 
AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ? OR table_name = ? AND version = ?) AND status IN (?) 
ORDER BY table_name ASC, utc_create DESC LIMIT ?, ?""" - statement = parser.parse(sql) + statement = parser.parse(sql, lexer=lexer) ParserUtils.format_statement(statement) def test_in_subquery(self): @@ -263,7 +264,7 @@ def test_in_subquery(self): 'select sum(cost) from costs where eventtype = \'treatment\' and eventid in ' '(select treatmentid from treatment where treatmentname = \'bleeding scan\')' ) - visitor = ParserUtils.format_statement(parser.parse(sql)) + visitor = ParserUtils.format_statement(parser.parse(sql, lexer=lexer)) table_list = visitor.table_list assert table_list == [ { @@ -285,7 +286,7 @@ def test_update_parameterize(self): sql = """ UPDATE `t1` SET `c`='11' WHERE (`id`='1111111') """ - statement_node = ParserUtils.parameterized_query(parser.parse(sql)) + statement_node = ParserUtils.parameterized_query(parser.parse(sql, lexer=lexer)) statement = format_sql(statement_node, 0) assert statement == """UPDATE t1 SET c = ? WHERE id = ?""" @@ -293,7 +294,7 @@ def test_between(self): sql = """ select max(`successRate`) AS `successRate` from `table_850d` where `period` between '2022-07-11 00:00:00' and '2022-07-11 23:59:59' and `successRate` < 0.35; """ - visitor = ParserUtils.format_statement(parser.parse(sql)) + visitor = ParserUtils.format_statement(parser.parse(sql, lexer=lexer)) table_list = visitor.table_list assert table_list == [ {