Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(parser):fix cast error #95

Merged
merged 1 commit into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/optimizer/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def visit_long_literal(self, node, unmangle_names):
def visit_double_literal(self, node, unmangle_names):
    """Render a double literal node as the string form of its ``value``.

    ``unmangle_names`` is accepted for visitor-interface uniformity and is
    not used here.
    """
    literal = node.value
    return str(literal)

def visit_field_type(self, node, unmangle_names):
    """Render a field-type node by delegating to the node's own ``str()``.

    ``unmangle_names`` is accepted for visitor-interface uniformity and is
    not used here.
    """
    rendered = str(node)
    return rendered

def visit_time_literal(self, node, unmangle_names):
    """Render a time literal node as ``TIME '<value>'``.

    ``unmangle_names`` is accepted for visitor-interface uniformity and is
    not used here.
    """
    template = "TIME '%s'"
    return template % node.value

Expand Down
7 changes: 4 additions & 3 deletions src/parser/mysql_parser/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
'SLASH',
'ASTERISK',
'NON_RESERVED',
'NUMBER',
'QM',
'SCONST',
]
Expand Down Expand Up @@ -99,13 +100,13 @@ def t_DOUBLE(t):
if 'e' in t.value or 'E' in t.value or '.' in t.value:
t.type = 'DOUBLE'
else:
t.type = "INTEGER"
t.type = "NUMBER"
return t


def t_INTEGER(t):
def t_NUMBER(t):
r'\d+'
t.type = "INTEGER"
t.type = "NUMBER"
return t


Expand Down
158 changes: 118 additions & 40 deletions src/parser/mysql_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from src.parser.tree.statement import Delete, Insert, Query, Update
from src.parser.tree.table import Table, TableSubquery
from src.parser.tree.values import Values
from src.parser.tree.field_type import UNSPECIFIEDLENGTH, FieldType, MySQLType

from ply import yacc
from src.optimizer.optimizer_enum import IndexType
Expand Down Expand Up @@ -125,7 +126,7 @@ def p_create_table_end(p):
r"""create_table_end : ENGINE EQ identifier create_table_end
| DEFAULT CHARSET EQ identifier create_table_end
| COLLATE EQ identifier create_table_end
| AUTO_INCREMENT EQ integer create_table_end
| AUTO_INCREMENT EQ number create_table_end
| COMMENT EQ SCONST create_table_end
| COMPRESSION EQ SCONST create_table_end
| empty
Expand Down Expand Up @@ -160,17 +161,17 @@ def p_column(p):
def p_column_type(p):
r"""
column_type : INT column_end
| INT LPAREN integer RPAREN column_end
| INT LPAREN number RPAREN column_end
| FLOAT column_end
| BIGINT column_end
| BIGINT LPAREN integer RPAREN column_end
| TINYINT LPAREN integer RPAREN column_end
| BIGINT LPAREN number RPAREN column_end
| TINYINT LPAREN number RPAREN column_end
| DATETIME column_end
| DATETIME LPAREN integer RPAREN column_end
| VARCHAR LPAREN integer RPAREN column_end
| CHAR LPAREN integer RPAREN column_end
| DATETIME LPAREN number RPAREN column_end
| VARCHAR LPAREN number RPAREN column_end
| CHAR LPAREN number RPAREN column_end
| TIMESTAMP column_end
| DECIMAL LPAREN integer COMMA integer RPAREN column_end
| DECIMAL LPAREN number COMMA number RPAREN column_end
"""
p[0] = p[1].lower()

Expand Down Expand Up @@ -412,7 +413,7 @@ def p_subquery(p):
def p_for_update_opt(p):
r"""for_update_opt : FOR UPDATE
| FOR UPDATE NOWAIT
| FOR UPDATE WAIT integer
| FOR UPDATE WAIT number
| LOCK IN SHARE MODE
| empty"""
if len(p) == 3:
Expand Down Expand Up @@ -484,12 +485,12 @@ def p_null_ordering_opt(p):

# LIMIT
def p_limit_opt(p):
r"""limit_opt : LIMIT integer
| LIMIT integer COMMA integer
r"""limit_opt : LIMIT number
| LIMIT number COMMA number
| LIMIT QM
| LIMIT QM COMMA QM
| LIMIT ALL
| LIMIT integer OFFSET integer
| LIMIT number OFFSET number
| empty"""
if len(p) < 5:
p[0] = (0, p[2]) if p[1] else None
Expand All @@ -500,8 +501,8 @@ def p_limit_opt(p):
p[0] = (p[4], p[2])


def p_integer(p):
r"""integer : INTEGER"""
def p_number(p):
r"""number : NUMBER"""
p[0] = p[1]


Expand Down Expand Up @@ -1083,7 +1084,7 @@ def p_base_primary_expression(p):
def p_value(p):
r"""value : NULL
| SCONST
| number
| figure
| boolean_value
| QUOTED_IDENTIFIER
| QM"""
Expand Down Expand Up @@ -1149,7 +1150,7 @@ def p_searched_case(p):


def p_cast_specification(p):
r"""cast_specification : CAST LPAREN value_expression AS data_type RPAREN"""
r"""cast_specification : CAST LPAREN expression AS cast_field RPAREN"""
p[0] = Cast(p.lineno(1), p.lexpos(1), expression=p[3], data_type=p[5], safe=False)


Expand Down Expand Up @@ -1182,45 +1183,122 @@ def p_call_list(p):
_item_list(p)


def p_data_type(p):
    r"""data_type : base_data_type type_param_list_opt"""
    # The docstring above is the PLY grammar rule; do not edit it as prose.
    # Result: the base type name, followed by "(a,b,...)" when a non-empty
    # parameter list was parsed.
    base, params = p[1], p[2]
    if not params:
        p[0] = base
        return
    # Normalize param list
    rendered = ','.join(map(str, params))
    p[0] = base + "(" + rendered + ")"
def p_cast_field(p):
    r"""cast_field : BINARY field_len_opt
                   | char_type field_len_opt field_param_list_opt
                   | DATE
                   | YEAR
                   | DATETIME field_len_opt
                   | DECIMAL float_opt
                   | TIME field_len_opt
                   | SIGNED integer_opt
                   | UNSIGNED integer_opt
                   | JSON
                   | DOUBLE
                   | FLOAT float_opt
                   | REAL"""
    # NOTE: the docstring above is the PLY grammar rule, not prose; keep it
    # in sync with the branches below.
    #
    # Builds the FieldType node for the target type of a CAST and stores it
    # in p[0] so that the enclosing `cast_specification` rule receives it.
    field = FieldType(p.lineno(1), p.lexpos(1))
    kind = p.slice[1].type

    if kind in ("DATE", "YEAR", "JSON", "DOUBLE", "REAL"):
        # Types without any parameter; the token name doubles as both the
        # MySQLType attribute name and the rendered type name.
        field.set_tp(getattr(MySQLType, kind), kind)
    elif kind in ("BINARY", "DATETIME", "TIME"):
        # Types with an optional single length / precision argument.
        field.set_tp(getattr(MySQLType, kind), kind)
        field.set_length(p[2])
    elif kind in ("DECIMAL", "FLOAT"):
        # float_opt yields a dict with 'length' and 'decimal' keys.
        precision = p[2]
        field.set_tp(getattr(MySQLType, kind), kind)
        field.set_length(precision["length"])
        field.set_decimal(precision["decimal"])
    elif kind == "char_type":
        # p[1] is the keyword as written (CHAR or CHARACTER).
        field.set_tp(MySQLType.CHAR, p[1])
        field.set_length(p[2])
        if p[3] is not None:
            # NOTE(review): FieldType.__str__ wraps this value in parentheses
            # again; confirm whether the double wrapping is intended.
            field.set_charset_and_collation(f"({','.join(p[3])})")
    elif kind in ("SIGNED", "UNSIGNED"):
        # p[2] is the optional INTEGER / INT keyword (None when omitted).
        field.set_tp(MySQLType.INTEGER, p[2])
        # Record signedness explicitly; FieldType.__str__ reads it.
        field.set_is_signed(kind == "SIGNED")

    p[0] = field


def p_field_len_opt(p):
    r"""field_len_opt : LPAREN NUMBER RPAREN
                      | empty"""
    # NOTE: the docstring above is the PLY grammar rule, not prose.
    #
    # Yields the explicit length as an unsigned 32-bit int, or
    # UNSPECIFIEDLENGTH when no "(N)" suffix was given.
    if len(p) == 4:
        # p[2] is the NUMBER token's value (text as produced by the lexer);
        # tolerate an object exposing `.value` as well.
        raw = getattr(p[2], "value", p[2])
        p[0] = int(raw) & 0xFFFFFFFF  # convert to unsigned int
    else:
        p[0] = UNSPECIFIEDLENGTH


def p_type_param_list_opt(p):
r"""type_param_list_opt : LPAREN type_param_list RPAREN
def p_field_param_list_opt(p):
r"""field_param_list_opt : LPAREN field_param_list RPAREN
| empty"""
p[0] = p[2] if p[1] else p[1]


def p_type_param_list(p):
r"""type_param_list : type_param_list COMMA type_parameter
| type_parameter"""
def p_field_param_list(p):
r"""field_param_list : field_param_list COMMA field_parameter
| field_parameter"""
_item_list(p)


def p_type_parameter(p):
r"""type_parameter : integer
def p_field_parameter(p):
r"""field_parameter : number
| base_data_type"""
p[0] = p[1]


def p_float_opt(p):
    r"""float_opt : LPAREN NUMBER RPAREN
                  | LPAREN NUMBER COMMA NUMBER RPAREN
                  | empty"""
    # NOTE: the docstring above is the PLY grammar rule, not prose.
    #
    # Yields a dict with 'length' (first number) and 'decimal' (second
    # number); whichever part is absent is reported as UNSPECIFIEDLENGTH.
    symbol_count = len(p)
    if symbol_count == 6:
        p[0] = dict(length=p[2], decimal=p[4])
    elif symbol_count == 4:
        p[0] = dict(length=p[2], decimal=UNSPECIFIEDLENGTH)
    elif symbol_count == 2:
        p[0] = dict(length=UNSPECIFIEDLENGTH, decimal=UNSPECIFIEDLENGTH)


def p_char_type(p):
    r"""char_type : CHARACTER
                  | CHAR"""
    # NOTE: the docstring above is the PLY grammar rule, not prose.
    # Pass the matched keyword through unchanged.
    keyword = p[1]
    p[0] = keyword


def p_integer_opt(p):
    r"""integer_opt : INTEGER
                    | INT
                    | empty"""
    # NOTE: the docstring above is the PLY grammar rule, not prose.
    # Pass through whatever the single right-hand-side symbol produced.
    matched = p[1]
    p[0] = matched


def p_base_data_type(p):
    r"""base_data_type : identifier"""
    # NOTE: the docstring above is the PLY grammar rule, not prose.
    # A base data type is simply the identifier that was parsed.
    type_name = p[1]
    p[0] = type_name


def p_date_time(p):
r"""date_time : CURRENT_DATE
| CURRENT_TIME integer_param_opt
| CURRENT_TIMESTAMP integer_param_opt
| LOCALTIME integer_param_opt
| LOCALTIMESTAMP integer_param_opt"""
| CURRENT_TIME number_param_opt
| CURRENT_TIMESTAMP number_param_opt
| LOCALTIME number_param_opt
| LOCALTIMESTAMP number_param_opt"""
precision = p[2] if len(p) == 3 else None
p[0] = CurrentTime(p.lineno(1), p.lexpos(1), type=p[1], precision=precision)

Expand All @@ -1247,8 +1325,8 @@ def p_boolean_value(p):
p[0] = BooleanLiteral(p.lineno(1), p.lexpos(1), value=p[1])


def p_integer_param_opt(p):
"""integer_param_opt : LPAREN integer RPAREN
def p_number_param_opt(p):
"""number_param_opt : LPAREN number RPAREN
| LPAREN RPAREN
| empty"""
p[0] = int(p[2]) if len(p) == 4 else None
Expand Down Expand Up @@ -1598,9 +1676,9 @@ def p_quoted_identifier(p):
p[0] = p[1][1:-1]


def p_number(p):
r"""number : DOUBLE
| integer"""
def p_figure(p):
r"""figure : DOUBLE
| NUMBER"""
if p.slice[1].type == "DOUBLE":
p[0] = DoubleLiteral(p.lineno(1), p.lexpos(1), p[1])
else:
Expand Down
6 changes: 0 additions & 6 deletions src/parser/tree/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,6 @@ def __init__(self, line=None, pos=None, type=None, left=None, right=None):
def accept(self, visitor, context):
    """Dispatch to the visitor's ``visit_logical_binary_expression`` hook.

    Returns whatever the visitor returns for this node and ``context``.
    """
    handler = visitor.visit_logical_binary_expression
    return handler(self, context)

# def and_op(self, left, right):
# return type == "AND"
#
# def or_op(self, left, right):
# return type == "OR"


class CoalesceExpression(Expression):
def __init__(self, line=None, pos=None, operands=None):
Expand Down
86 changes: 86 additions & 0 deletions src/parser/tree/field_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from .node import Node

UNSPECIFIEDLENGTH = -1


class MySQLType:
    """Integer tags identifying the kind of type a ``FieldType`` describes.

    The values are consecutive integers starting at 1, in declaration order.
    """

    (
        BINARY,
        CHAR,
        DATE,
        YEAR,
        DATETIME,
        DECIMAL,
        TIME,
        INTEGER,
        JSON,
        DOUBLE,
        FLOAT,
        REAL,
    ) = range(1, 13)


class FieldType(Node):
    """AST node describing a field (column / cast target) type.

    The node is created empty and filled in through the ``set_*`` methods.
    Every attribute has a default, so ``str(node)`` is safe to call on a
    partially populated node:

    * ``tp`` / ``type_name``: ``MySQLType`` tag and the name to render.
    * ``length`` / ``decimal``: ``UNSPECIFIEDLENGTH`` when not given.
    * ``is_signed``: ``None`` when signedness was never recorded.
    * ``charset_and_collation``, ``flag``: ``None`` when not given.
    """

    def __init__(
        self,
        line=None,
        pos=None,
    ) -> None:
        super(FieldType, self).__init__(line, pos)
        # Defaults keep __str__ from raising AttributeError when a setter
        # was never called for a given attribute.
        self.tp = None
        self.type_name = None
        self.length = UNSPECIFIEDLENGTH
        self.decimal = UNSPECIFIEDLENGTH
        self.flag = None
        self.is_signed = None
        self.charset_and_collation = None

    def set_tp(self, tp, type_name):
        """Set the ``MySQLType`` tag and the type name used when rendering."""
        self.tp = tp
        self.type_name = type_name

    def set_length(self, length):
        """Set the length / precision (``UNSPECIFIEDLENGTH`` means absent)."""
        self.length = length

    def set_decimal(self, decimal):
        """Set the decimal scale (``UNSPECIFIEDLENGTH`` means absent)."""
        self.decimal = decimal

    def set_flag(self, flag):
        """Store an opaque flag value; it is not used when rendering."""
        self.flag = flag

    def set_is_signed(self, is_signed):
        """Record whether an INTEGER type is signed (truthy) or unsigned."""
        self.is_signed = is_signed

    def set_charset_and_collation(self, charset_and_collation):
        """Set the charset/collation text rendered after a CHAR type."""
        self.charset_and_collation = charset_and_collation

    def accept(self, visitor, context):
        """Dispatch to ``visitor.visit_field_type`` and return its result."""
        return visitor.visit_field_type(self, context)

    def __str__(self):
        """Render the type as SQL text, e.g. ``CHAR (10)`` or ``DECIMAL (10,2)``."""
        result = self.type_name if self.type_name is not None else ""

        # Check the instance's values directly; inspecting dir(FieldType)
        # only sees class attributes and therefore never found them.
        has_length = self.length not in (None, UNSPECIFIEDLENGTH)
        has_decimal = self.decimal not in (None, UNSPECIFIEDLENGTH)

        length_only_types = (
            MySQLType.BINARY,
            MySQLType.CHAR,
            MySQLType.TIME,
            MySQLType.DATETIME,
        )
        if self.tp in length_only_types and has_length:
            result += f" ({self.length})"

        if self.tp == MySQLType.CHAR and self.charset_and_collation is not None:
            result += f" ({self.charset_and_collation})"

        if self.tp == MySQLType.INTEGER and self.is_signed is not None:
            prefix = "SIGNED" if self.is_signed else "UNSIGNED"
            # Avoid a trailing space when no INTEGER/INT keyword was given.
            result = f"{prefix} {result}" if result else prefix

        if self.tp in (MySQLType.FLOAT, MySQLType.DECIMAL) and has_length:
            if has_decimal:
                result += f" ({self.length},{self.decimal})"
            else:
                result += f" ({self.length})"

        return result
Loading
Loading