diff --git a/mo_sql_parsing/sql_parser.py b/mo_sql_parsing/sql_parser.py index aa28fb1..40c3eaa 100644 --- a/mo_sql_parsing/sql_parser.py +++ b/mo_sql_parsing/sql_parser.py @@ -61,7 +61,7 @@ def parser(literal_string, simple_ident, all_columns=None, sqlserver=False): # EXPRESSIONS expression = Forward() - (column_type, column_definition, column_def_references, column_option,) = get_column_type( + (column_type, column_definition, column_def_references, column_option, declare_variable) = get_column_type( expression, identifier, literal_string ) proc_param = Group( @@ -912,7 +912,6 @@ def mult(tokens): ############################################################# statement = Forward() special_ident = keyword("masking policy") | identifier / (lambda t: t[0].lower()) - declare_variable = assign("declare", column_definition) set_one_variable = SET + ( (special_ident + Optional(EQ) + expression) / (lambda t: {t[0].lower(): t[1].lower() if isinstance(t[1], str) else t[1]}) diff --git a/mo_sql_parsing/types.py b/mo_sql_parsing/types.py index cd9b7a3..b41b947 100644 --- a/mo_sql_parsing/types.py +++ b/mo_sql_parsing/types.py @@ -21,7 +21,7 @@ RIGHT_ASSOC, LEFT_ASSOC, Keyword, - Combine, + Combine, Empty, ) from mo_sql_parsing.keywords import ( @@ -266,13 +266,20 @@ def get_column_type(expr, identifier, literal_string): | assign("check", LB + expr + RB) | assign("default", expr) | assign("on update", expr) - | (EQ + expr)("default") ) column_definition << Group( identifier("name") + (column_type | identifier("type")) + ZeroOrMore(column_options) ) / to_flat_column_type + variable_options = ( + assign("default", expr) + | EQ + expr("default") + ) + declare_variable = assign("declare", delimited_list(Group( + identifier("name") + Optional(AS) + simple_types("type") + ZeroOrMore(variable_options) + ))) + set_parser_names() - return column_type, column_definition, column_def_references, column_options + return column_type, column_definition, column_def_references, column_options, declare_variable diff --git a/tests/test_sql_server.py b/tests/test_sql_server.py index 4780756..be6d5b2 100644 --- a/tests/test_sql_server.py +++ b/tests/test_sql_server.py @@ -9,8 +9,6 @@ from unittest import TestCase -from mo_parsing.debug import Debugger - from mo_sql_parsing import parse_sqlserver as parse @@ -307,7 +305,7 @@ def test_issue_237_declare_var(self): result = parse(sql) expected = {"create_procedure": { "name": "k", - "body": {"block": {"declare": {"default": 42, "name": "@MYVARIABLE", "type": {"int": {}},}}}, + "body": {"block": {"declare": {"default": 42, "name": "@MYVARIABLE", "type": {"int": {}}, }}}, }} self.assertEqual(result, expected) @@ -318,7 +316,7 @@ def test_issue_237_declare_var_with_as(self): result = parse(sql) expected = {"create_procedure": { "name": "k", - "body": {"block": {"declare": {"default": 42, "name": "@MYVARIABLE", "type": {"int": {}},}}}, + "body": {"block": {"declare": {"default": 42, "name": "@MYVARIABLE", "type": {"int": {}}}}}, }} self.assertEqual(result, expected) @@ -327,12 +325,26 @@ def test_issue_237_declare_multiple_vars(self): DECLARE @MYVAR INT, @MYOTHERVAR DATE; END""" result = parse(sql) - expected = { - "create_procedure": { + expected = {"create_procedure": { + "name": "k", + "body": {"block": {"declare": [ + {"name": "@MYVAR", "type": {"int": {}}}, + {"name": "@MYOTHERVAR", "type": {"date": {}}}, + ]}}, + }} + self.assertEqual(result, expected) + + def test_issue_237_declare_multiple_vars_with_code(self): + sql = """create procedure k() BEGIN + DECLARE @MYVAR INT, @MYOTHERVAR DATE; + SELECT a FROM rental; + END""" + result = parse(sql) + expected = {"create_procedure": { "name": "k", "body": {"block": [ - {"declare": {"name": "@MYFIRSTVAR", "type": {"int": {}}}}, - {"declare": {"name": "@MYOTHERVAR", "type": {"int": {}}}}, + {"declare": [{"name": "@MYVAR", "type": {"int": {}}}, {"name": "@MYOTHERVAR", "type": {"date": {}}}]}, + {"select": {"value": "a"}, "from": "rental"}, ]}, }} self.assertEqual(result, expected)