Skip to content

Commit

Permalink
fix more declares
Browse files Browse the repository at this point in the history
  • Loading branch information
klahnakoski committed May 31, 2024
1 parent 7866a00 commit fe45f61
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
3 changes: 1 addition & 2 deletions mo_sql_parsing/sql_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def parser(literal_string, simple_ident, all_columns=None, sqlserver=False):

# EXPRESSIONS
expression = Forward()
(column_type, column_definition, column_def_references, column_option,) = get_column_type(
(column_type, column_definition, column_def_references, column_option, declare_variable) = get_column_type(
expression, identifier, literal_string
)
proc_param = Group(
Expand Down Expand Up @@ -912,7 +912,6 @@ def mult(tokens):
#############################################################
statement = Forward()
special_ident = keyword("masking policy") | identifier / (lambda t: t[0].lower())
declare_variable = assign("declare", column_definition)
set_one_variable = SET + (
(special_ident + Optional(EQ) + expression)
/ (lambda t: {t[0].lower(): t[1].lower() if isinstance(t[1], str) else t[1]})
Expand Down
13 changes: 10 additions & 3 deletions mo_sql_parsing/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
RIGHT_ASSOC,
LEFT_ASSOC,
Keyword,
Combine,
Combine, Empty,
)

from mo_sql_parsing.keywords import (
Expand Down Expand Up @@ -266,13 +266,20 @@ def get_column_type(expr, identifier, literal_string):
| assign("check", LB + expr + RB)
| assign("default", expr)
| assign("on update", expr)
| (EQ + expr)("default")
)

column_definition << Group(
identifier("name") + (column_type | identifier("type")) + ZeroOrMore(column_options)
) / to_flat_column_type

variable_options = (
assign("default", expr)
| EQ + expr("default")
)
declare_variable = assign("declare", delimited_list(Group(
identifier("name") + Optional(AS) + simple_types("type") + ZeroOrMore(variable_options)
)))

set_parser_names()

return column_type, column_definition, column_def_references, column_options
return column_type, column_definition, column_def_references, column_options, declare_variable
28 changes: 20 additions & 8 deletions tests/test_sql_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

from unittest import TestCase

from mo_parsing.debug import Debugger

from mo_sql_parsing import parse_sqlserver as parse


Expand Down Expand Up @@ -307,7 +305,7 @@ def test_issue_237_declare_var(self):
result = parse(sql)
expected = {"create_procedure": {
"name": "k",
"body": {"block": {"declare": {"default": 42, "name": "@MYVARIABLE", "type": {"int": {}},}}},
"body": {"block": {"declare": {"default": 42, "name": "@MYVARIABLE", "type": {"int": {}}, }}},
}}
self.assertEqual(result, expected)

Expand All @@ -318,7 +316,7 @@ def test_issue_237_declare_var_with_as(self):
result = parse(sql)
expected = {"create_procedure": {
"name": "k",
"body": {"block": {"declare": {"default": 42, "name": "@MYVARIABLE", "type": {"int": {}},}}},
"body": {"block": {"declare": {"default": 42, "name": "@MYVARIABLE", "type": {"int": {}}}}},
}}
self.assertEqual(result, expected)

Expand All @@ -327,12 +325,26 @@ def test_issue_237_declare_multiple_vars(self):
DECLARE @MYVAR INT, @MYOTHERVAR DATE;
END"""
result = parse(sql)
expected = {
"create_procedure": {
expected = {"create_procedure": {
"name": "k",
"body": {"block": {"declare": [
{"name": "@MYVAR", "type": {"int": {}}},
{"name": "@MYOTHERVAR", "type": {"date": {}}},
]}},
}}
self.assertEqual(result, expected)

def test_issue_237_declare_multiple_vars_with_code(self):
sql = """create procedure k() BEGIN
DECLARE @MYVAR INT, @MYOTHERVAR DATE;
SELECT a FROM rental;
END"""
result = parse(sql)
expected = {"create_procedure": {
"name": "k",
"body": {"block": [
{"declare": {"name": "@MYFIRSTVAR", "type": {"int": {}}}},
{"declare": {"name": "@MYOTHERVAR", "type": {"int": {}}}},
{"declare": [{"name": "@MYVAR", "type": {"int": {}}}, {"name": "@MYOTHERVAR", "type": {"date": {}}}]},
{"select": {"value": "a"}, "from": "rental"},
]},
}}
self.assertEqual(result, expected)

0 comments on commit fe45f61

Please sign in to comment.