diff --git a/mo_sql_parsing/sql_parser.py b/mo_sql_parsing/sql_parser.py index b9524c7..4c0fe31 100644 --- a/mo_sql_parsing/sql_parser.py +++ b/mo_sql_parsing/sql_parser.py @@ -745,7 +745,9 @@ def mult(tokens): ) use_schema = assign("use", identifier) - + open_cursor = assign("open", identifier) + close_cursor = assign("close", identifier) + fetch_cursor = assign("fetch", identifier) + INTO + delimited_list(ident) cache_options = Optional(( keyword("options").suppress() + LB @@ -766,7 +768,8 @@ def mult(tokens): drops = assign( "drop", - MatchFirst([ + temporary + +MatchFirst([ keyword(item).suppress() + Optional(flag("if exists")) + Group(identifier)(item) for item in ["table", "view", "index", "schema"] ]), @@ -777,6 +780,7 @@ def mult(tokens): insert = ( Optional(assign("with", with_clause)) + keyword("insert").suppress() + + Optional(flag("ignore")) + ( flag("overwrite") + keyword("table").suppress() | keyword("into").suppress() + Optional(keyword("table").suppress()) @@ -788,6 +792,15 @@ def mult(tokens): + returning ) / to_insert_call + replace = ( + Optional(assign("with", with_clause)) + + keyword("replace").suppress() + + Optional(keyword("into").suppress()) + + identifier("table") + + Optional(LB + delimited_list(identifier)("columns") + RB) + + (values | query)("query") + ) / to_replace_call + update = ( keyword("update")("op") + (delimited_list(table_source) + ZeroOrMore(join))("params") @@ -892,6 +905,7 @@ def mult(tokens): ############################################################# # GET/SET ############################################################# + statement = Forward() special_ident = keyword("masking policy") | identifier / (lambda t: t[0].lower()) declare_variable = assign("declare", column_definition) set_one_variable = SET + ( @@ -951,7 +965,6 @@ def mult(tokens): )) # EXPLAIN - statement = Forward() explain_option = MatchFirst([ ( Keyword(option, caseless=True) @@ -1034,7 +1047,7 @@ def mult(tokens): ############################################################# many_command = Forward() - block = Group(Optional(identifier("label") + ":") + BEGIN + Group(many_command)("block") + END) + block = BEGIN + Group(many_command)("block") + END if_block = ( assign("if", expression) + assign("then", many_command) @@ -1042,6 +1055,16 @@ def mult(tokens): + keyword("end if").suppress() ) leave = assign("leave", identifier) + while_block = ( + assign("while", expression) + + assign("do", many_command) + + keyword("end while").suppress() + ) + loop_block = ( + keyword("loop").suppress() + + many_command("loop") + + keyword("end loop").suppress() + ) create_trigger = assign( "create trigger", @@ -1074,7 +1097,7 @@ def mult(tokens): + keyword("procedure") + identifier("name") + LB - + Group(delimited_list(proc_param))("params") + + Group(Optional(delimited_list(proc_param)))("params") + RB + characteristic + statement("body") @@ -1110,13 +1133,26 @@ def mult(tokens): + statement("body") )("declare_handler") + declare_cursor = Group( + keyword("declare").suppress() + + identifier("name") + + keyword("cursor for").suppress() + + query("query") + )("declare_cursor") + + transact = ( Group(keyword("start transaction")("op")) / to_json_call | Group(keyword("commit")("op")) / to_json_call | Group(keyword("rollback")("op")) / to_json_call ) - flow = block | if_block | leave | assign("return", expression) + blocks = Group(Optional(identifier("label") + ":") + ( + block + | if_block + | while_block + | loop_block + )) ############################################################# # FINALLY ASSEMBLE THE PARSER @@ -1131,7 +1167,7 @@ def mult(tokens): statement << Group( query - | (insert | update | delete | merge | truncate | use_schema) + | (insert | replace | update | delete | merge | truncate | use_schema) | (create_table | create_view | create_cache | create_index | create_schema) | drops | (copy | alter) @@ -1141,8 +1177,14 @@ def mult(tokens): | explain | delimiter_command | declare_hanlder - | flow + | declare_cursor + | leave + | assign("return", expression) | transact + | open_cursor + | close_cursor + | fetch_cursor + | blocks | (Optional(keyword("alter session")).suppress() + (set_variables | unset_one_variable | declare_variable)) ) diff --git a/mo_sql_parsing/utils.py b/mo_sql_parsing/utils.py index 2a0f87d..9fe6e93 100644 --- a/mo_sql_parsing/utils.py +++ b/mo_sql_parsing/utils.py @@ -690,6 +690,24 @@ def to_insert_call(tokens): return Call("insert", [tokens["table"]], {"columns": columns, "query": query, **options}) +def to_replace_call(tokens): + options = {k: v for k, v in tokens.items() if k not in ["columns", "table", "query"]} + query = tokens["query"] + columns = tokens["columns"] + try: + values = query["from"]["literal"] + if values: + if columns: + data = [dict(zip(columns, row)) for row in values] + return Call("replace", [tokens["table"]], {"values": data, **options}) + else: + return Call("replace", [tokens["table"]], {"values": values, **options}) + except Exception: + pass + + return Call("replace", [tokens["table"]], {"columns": columns, "query": query, **options}) + + def to_update_call(tokens): value = tokens["value"] name = tokens["name"] diff --git a/packaging/setup.py b/packaging/setup.py index 4c51418..f7d48cc 100644 --- a/packaging/setup.py +++ b/packaging/setup.py @@ -15,6 +15,6 @@ name='mo-sql-parsing', packages=["mo_sql_parsing"], url='https://github.com/klahnakoski/mo-sql-parsing', - version='10.640.24140', + version='10.641.24143', zip_safe=True ) \ No newline at end of file diff --git a/packaging/setuptools.json b/packaging/setuptools.json index 9ff72a4..23f5654 100644 --- a/packaging/setuptools.json +++ b/packaging/setuptools.json @@ -311,6 +311,6 @@ "name": "mo-sql-parsing", "packages": ["mo_sql_parsing"], "url": "https://github.com/klahnakoski/mo-sql-parsing", - "version": "10.640.24140", + "version": "10.641.24143", "zip_safe": true } \ No newline at end of file diff --git a/tests/test_mysql.py b/tests/test_mysql.py index d35bf9b..ef767c2 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -6,9 +6,9 @@ # Author: Kyle Lahnakoski (kyle@lahnakoski.com) # from itertools import zip_longest +from unittest import skip from mo_files import File -from mo_parsing.debug import Debugger from mo_testing.fuzzytestcase import add_error_reporting, FuzzyTestCase, assertAlmostEqual from mo_sql_parsing import parse, parse_mysql, parse_delimiters @@ -937,4 +937,243 @@ def test_issue_218_transaction(self): {"commit": {}}, {"rollback": {}}, ] - self.assertEqual(result, expected) \ No newline at end of file + self.assertEqual(result, expected) + + """ + Tests to cover some permutations on the REPLACE statement with this grammer + 15.2.12 REPLACE Statement + + REPLACE [LOW_PRIORITY | DELAYED] + [INTO] tbl_name + [PARTITION (partition_name [, partition_name] ...)] + [(col_name [, col_name] ...)] + { {VALUES | VALUE} (value_list) [, (value_list)] ... + | + VALUES row_constructor_list + } + + REPLACE [LOW_PRIORITY | DELAYED] + [INTO] tbl_name + [PARTITION (partition_name [, partition_name] ...)] + SET assignment_list + + REPLACE [LOW_PRIORITY | DELAYED] + [INTO] tbl_name + [PARTITION (partition_name [, partition_name] ...)] + [(col_name [, col_name] ...)] + {SELECT ... | TABLE table_name} + + value: + {expr | DEFAULT} + + value_list: + value [, value] ... + + row_constructor_list: + ROW(value_list)[, ROW(value_list)][, ...] + + assignment: + col_name = value + + assignment_list: + assignment [, assignment] ... + + """ + + def test_pr_236_replace1(self): + sql = """ + REPLACE INTO test_table (id, name) VALUES (1, 'New Name') + """ + result = parse(sql) + expected = { + "replace": "test_table", + "columns": ["id", "name"], + "query": {"select": [{"value": 1}, {"value": {"literal": "New Name"}}]}, + } + self.assertEqual(result, expected) + + def test_pr_236_replace2(self): + sql = """REPLACE INTO tab (name) VALUES(42)""" + result = parse(sql) + expected = { + "columns": "name", + "replace": "tab", + "query": {"select": {"value": 42}}, + } + self.assertEqual(result, expected) + + def test_pr_236_replace3(self): + sql = """replace into t (a, b, c) select x, y, z from f""" + result = parse(sql) + expected = { + "columns": ["a", "b", "c"], + "replace": "t", + "query": {"from": "f", "select": [{"value": "x"}, {"value": "y"}, {"value": "z"}]}, + } + self.assertEqual(result, expected) + + def test_pr_236_replace4(self): + sql = """REPLACE Person(Id, Name, DateOfBirth, Gender) + VALUES (1, 'John Lennon', '1940-10-09', 'M'), (2, 'Paul McCartney', '1942-06-18', 'M'), + (3, 'George Harrison', '1943-02-25', 'M'), (4, 'Ringo Starr', '1940-07-07', 'M')""" + result = parse(sql) + expected = { + "replace": "Person", + "values": [ + {"DateOfBirth": "1940-10-09", "Gender": "M", "Id": 1, "Name": "John Lennon"}, + {"DateOfBirth": "1942-06-18", "Gender": "M", "Id": 2, "Name": "Paul McCartney"}, + {"DateOfBirth": "1943-02-25", "Gender": "M", "Id": 3, "Name": "George Harrison"}, + {"DateOfBirth": "1940-07-07", "Gender": "M", "Id": 4, "Name": "Ringo Starr"}, + ], + } + self.assertEqual(result, expected) + + def test_pr_236_while1(self): + sql = """ + CREATE PROCEDURE dowhile() + BEGIN + DECLARE v1 INT DEFAULT 5; + + WHILE v1 > 0 DO + SET v1 = v1 - 1; + END WHILE; + END;""" + + result = parse(sql) + expected = {"create_procedure": { + "name": "dowhile", + "body": {"block": [ + {"declare": {"name": "v1", "default": 5, "type": {"int": {}}}}, + {"while": {"gt": ["v1", 0]}, "do": {"set": {"v1": {"sub": ["v1", 1]}}},}, + ]}, + }} + self.assertEqual(result, expected) + + def test_pr_236_while2(self): + sql = """ + CREATE PROCEDURE dowhile() + BEGIN + DECLARE v1 INT DEFAULT 5; + + WHILE v1 > 0 DO + SET v1 = v1 - 1; + SET v2 = v2 + 1; + END WHILE; + END;""" + + result = parse(sql) + expected = {"create_procedure": { + "name": "dowhile", + "body": {"block": [ + {"declare": {"name": "v1", "default": 5, "type": {"int": {}}}}, + { + "while": {"gt": ["v1", 0]}, + "do": [{"set": {"v1": {"sub": ["v1", 1]}}}, {"set": {"v2": {"add": ["v2", 1]}}},], + }, + ]}, + }} + self.assertEqual(result, expected) + + def test_pr_236_cursor(self): + sql = """ + CREATE PROCEDURE demo() + BEGIN + DECLARE done INT DEFAULT FALSE; + DECLARE a CHAR(16); + DECLARE cur1 CURSOR FOR SELECT id,data FROM test.t1; + DECLARE cur2 CURSOR FOR SELECT i FROM test.t2; + DECLARE CONTINUE HANDLER FOR NOT FOUND SET done = TRUE; + + OPEN cur1; + OPEN cur2; + + read_loop: LOOP + FETCH cur1 INTO a, b; + FETCH cur2 INTO c; + IF done THEN + LEAVE read_loop; + END IF; + IF b < c THEN + INSERT INTO test.t3 VALUES (a,b); + ELSE + INSERT INTO test.t3 VALUES (a,c); + END IF; + END LOOP; + + CLOSE cur1; + CLOSE cur2; + END; + """ + result = parse(sql) + expected = {"create_procedure": { + "name": "demo", + "body": {"block": [ + {"declare": {"name": "done", "type": {"int": {}}, "default": False}}, + {"declare": {"name": "a", "type": {"char": 16}}}, + {"declare_cursor": { + "name": "cur1", + "query": {"select": [{"value": "id"}, {"value": "data"}], "from": "test.t1"}, + }}, + {"declare_cursor": {"name": "cur2", "query": {"select": {"value": "i"}, "from": "test.t2"}}}, + {"declare_handler": { + "action": "continue", + "conditions": "not_found", + "body": {"set": {"done": True}}, + }}, + {"open": "cur1"}, + {"open": "cur2"}, + { + "label": "read_loop", + "loop": [ + {"fetch": "cur1"}, + {"fetch": "cur2"}, + {"if": "done", "then": {"leave": "read_loop"}}, + { + "if": {"lt": ["b", "c"]}, + "then": {"query": {"select": [{"value": "a"}, {"value": "b"}]}, "insert": "test.t3"}, + "else": {"query": {"select": [{"value": "a"}, {"value": "c"}]}, "insert": "test.t3"}, + }, + ], + }, + {"close": "cur1"}, + {"close": "cur2"}, + ]}, + }} + self.assertEqual(result, expected) + + @skip("Not implemented yet") + def test_multiple_vars_in_declare(self): + sql = """ + CREATE PROCEDURE curdemo() + BEGIN + DECLARE b, c INT; + END; + """ + result = parse(sql) + expected = {"create_procedure": { + "name": "curdemo", + "body": {"block": [ + {"declare": {"name": "b", "type": {"int": {}}}}, + {"declare": {"name": "c", "type": {"int": {}}}}, + ]}, + }} + self.assertEqual(result, expected) + + def test_declare_cursor(self): + sql = """ + CREATE PROCEDURE demo() + BEGIN + DECLARE cur1 CURSOR FOR SELECT id,data FROM test.t1; + END; + """ + result = parse(sql) + expected = {"create_procedure": { + "name": "demo", + "body": { + "block": {"declare_cursor": { + "name": "cur1", + "query": {"select": [{"value": "id"}, {"value": "data"}], "from": "test.t1"}, + }}, + }, + }} + self.assertEqual(result, expected) diff --git a/tests/test_sqlglot.py b/tests/test_sqlglot.py index 69a2c02..dc3b59f 100644 --- a/tests/test_sqlglot.py +++ b/tests/test_sqlglot.py @@ -931,3 +931,8 @@ def test_issue_46_sqlglot_100(self): "where": {"eq": ["tbl_name.bar", 234]}, } self.assertEqual(result, expected) + def test_issue_46_sqlglot_101(self): + sql = """DROP temporary TABLE a""" + result = parse(sql) + expected = {"drop": { "temporary": True, "table": "a"}} + self.assertEqual(result, expected)