Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Token type #63

Merged
merged 1 commit into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 12 additions & 174 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,21 @@ mod token;
pub use token::TokenType;

/// Token value (lexeme)
pub struct Token(pub usize, pub Option<String>, pub usize);
#[derive(Clone, Copy)]
pub struct Token<'i>(pub usize, pub &'i [u8], pub usize);

pub(crate) fn sentinel(start: usize) -> Token {
Token(start, None, start)
pub(crate) fn sentinel(start: usize) -> Token<'static> {
Token(start, b"", start)
}

impl Token {
impl Token<'_> {
/// Access token value
pub fn unwrap(self) -> String {
self.1.unwrap()
}
/// Take token value
pub fn take(&mut self) -> Self {
Token(self.0, self.1.take(), self.2)
from_bytes(self.1)
}
}

impl std::fmt::Debug for Token {
impl std::fmt::Debug for Token<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Token").field(&self.1).finish()
}
Expand All @@ -35,31 +32,12 @@ impl TokenType {
// TODO try Cow<&'static, str> (Borrowed<&'static str> for keyword and Owned<String> for below),
// => Syntax error on keyword will be better
// => `from_token` will become unnecessary
pub(crate) fn to_token(self, start: usize, value: &[u8], end: usize) -> Token {
Token(
start,
match self {
TokenType::TK_CTIME_KW => Some(from_bytes(value)),
TokenType::TK_JOIN_KW => Some(from_bytes(value)),
TokenType::TK_LIKE_KW => Some(from_bytes(value)),
TokenType::TK_PTR => Some(from_bytes(value)),
// Identifiers
TokenType::TK_STRING => Some(from_bytes(value)),
TokenType::TK_ID => Some(from_bytes(value)),
TokenType::TK_VARIABLE => Some(from_bytes(value)),
// Values
TokenType::TK_ANY => Some(from_bytes(value)),
TokenType::TK_BLOB => Some(from_bytes(value)),
TokenType::TK_INTEGER => Some(from_bytes(value)),
TokenType::TK_FLOAT => Some(from_bytes(value)),
_ => None,
},
end,
)
pub(crate) fn to_token(self, start: usize, value: &[u8], end: usize) -> Token<'_> {
Token(start, value, end)
}
}

fn from_bytes(bytes: &[u8]) -> String {
pub(crate) fn from_bytes(bytes: &[u8]) -> String {
unsafe { str::from_utf8_unchecked(bytes).to_owned() }
}

Expand Down Expand Up @@ -97,148 +75,8 @@ pub(crate) fn is_identifier_continue(b: u8) -> bool {

// keyword may become an identifier
// see %fallback in parse.y
pub(crate) fn from_token(ty: u16, value: Token) -> String {
use TokenType::*;
if let Some(str) = value.1 {
return str;
}
match ty {
x if x == TK_ABORT as u16 => "ABORT".to_owned(),
x if x == TK_ACTION as u16 => "ACTION".to_owned(),
//x if x == TK_ADD as u16 => "ADD".to_owned(),
x if x == TK_AFTER as u16 => "AFTER".to_owned(),
//x if x == TK_ALL as u16 => "ALL".to_owned(),
//x if x == TK_ALTER as u16 => "ALTER".to_owned(),
x if x == TK_ALWAYS as u16 => "ALWAYS".to_owned(),
x if x == TK_ANALYZE as u16 => "ANALYZE".to_owned(),
//x if x == TK_AND as u16 => "AND".to_owned(),
//x if x == TK_AS as u16 => "AS".to_owned(),
x if x == TK_ASC as u16 => "ASC".to_owned(),
x if x == TK_ATTACH as u16 => "ATTACH".to_owned(),
//x if x == TK_AUTOINCR as u16 => "AUTOINCREMENT".to_owned(),
x if x == TK_BEFORE as u16 => "BEFORE".to_owned(),
x if x == TK_BEGIN as u16 => "BEGIN".to_owned(),
//x if x == TK_BETWEEN as u16 => "BETWEEN".to_owned(),
x if x == TK_BY as u16 => "BY".to_owned(),
x if x == TK_CASCADE as u16 => "CASCADE".to_owned(),
//x if x == TK_CASE as u16 => "CASE".to_owned(),
x if x == TK_CAST as u16 => "CAST".to_owned(),
//x if x == TK_CHECK as u16 => "CHECK".to_owned(),
//x if x == TK_COLLATE as u16 => "COLLATE".to_owned(),
x if x == TK_COLUMNKW as u16 => "COLUMN".to_owned(),
//x if x == TK_COMMIT as u16 => "COMMIT".to_owned(),
x if x == TK_CONFLICT as u16 => "CONFLICT".to_owned(),
//x if x == TK_CONSTRAINT as u16 => "CONSTRAINT".to_owned(),
//x if x == TK_CREATE as u16 => "CREATE".to_owned(),
x if x == TK_CURRENT as u16 => "CURRENT".to_owned(),
x if x == TK_DATABASE as u16 => "DATABASE".to_owned(),
x if x == TK_DEFAULT as u16 => "DEFAULT".to_owned(),
//x if x == TK_DEFERRABLE as u16 => "DEFERRABLE".to_owned(),
x if x == TK_DEFERRED as u16 => "DEFERRED".to_owned(),
x if x == TK_DELETE as u16 => "DELETE".to_owned(),
x if x == TK_DESC as u16 => "DESC".to_owned(),
x if x == TK_DETACH as u16 => "DETACH".to_owned(),
//x if x == TK_DISTINCT as u16 => "DISTINCT".to_owned(),
x if x == TK_DO as u16 => "DO".to_owned(),
//x if x == TK_DROP as u16 => "DROP".to_owned(),
x if x == TK_EACH as u16 => "EACH".to_owned(),
//x if x == TK_ELSE as u16 => "ELSE".to_owned(),
x if x == TK_END as u16 => "END".to_owned(),
//x if x == TK_ESCAPE as u16 => "ESCAPE".to_owned(),
//x if x == TK_EXCEPT as u16 => "EXCEPT".to_owned(),
x if x == TK_EXCLUDE as u16 => "EXCLUDE".to_owned(),
x if x == TK_EXCLUSIVE as u16 => "EXCLUSIVE".to_owned(),
//x if x == TK_EXISTS as u16 => "EXISTS".to_owned(),
x if x == TK_EXPLAIN as u16 => "EXPLAIN".to_owned(),
x if x == TK_FAIL as u16 => "FAIL".to_owned(),
//x if x == TK_FILTER as u16 => "FILTER".to_owned(),
x if x == TK_FIRST as u16 => "FIRST".to_owned(),
x if x == TK_FOLLOWING as u16 => "FOLLOWING".to_owned(),
x if x == TK_FOR as u16 => "FOR".to_owned(),
//x if x == TK_FOREIGN as u16 => "FOREIGN".to_owned(),
//x if x == TK_FROM as u16 => "FROM".to_owned(),
x if x == TK_GENERATED as u16 => "GENERATED".to_owned(),
//x if x == TK_GROUP as u16 => "GROUP".to_owned(),
x if x == TK_GROUPS as u16 => "GROUPS".to_owned(),
//x if x == TK_HAVING as u16 => "HAVING".to_owned(),
x if x == TK_IF as u16 => "IF".to_owned(),
x if x == TK_IGNORE as u16 => "IGNORE".to_owned(),
x if x == TK_IMMEDIATE as u16 => "IMMEDIATE".to_owned(),
//x if x == TK_IN as u16 => "IN".to_owned(),
//x if x == TK_INDEX as u16 => "INDEX".to_owned(),
x if x == TK_INDEXED as u16 => "INDEXED".to_owned(),
x if x == TK_INITIALLY as u16 => "INITIALLY".to_owned(),
//x if x == TK_INSERT as u16 => "INSERT".to_owned(),
x if x == TK_INSTEAD as u16 => "INSTEAD".to_owned(),
//x if x == TK_INTERSECT as u16 => "INTERSECT".to_owned(),
//x if x == TK_INTO as u16 => "INTO".to_owned(),
//x if x == TK_IS as u16 => "IS".to_owned(),
//x if x == TK_ISNULL as u16 => "ISNULL".to_owned(),
//x if x == TK_JOIN as u16 => "JOIN".to_owned(),
x if x == TK_KEY as u16 => "KEY".to_owned(),
x if x == TK_LAST as u16 => "LAST".to_owned(),
//x if x == TK_LIMIT as u16 => "LIMIT".to_owned(),
x if x == TK_MATCH as u16 => "MATCH".to_owned(),
x if x == TK_MATERIALIZED as u16 => "MATERIALIZED".to_owned(),
x if x == TK_NO as u16 => "NO".to_owned(),
//x if x == TK_NOT as u16 => "NOT".to_owned(),
//x if x == TK_NOTHING as u16 => "NOTHING".to_owned(),
//x if x == TK_NOTNULL as u16 => "NOTNULL".to_owned(),
//x if x == TK_NULL as u16 => "NULL".to_owned(),
x if x == TK_NULLS as u16 => "NULLS".to_owned(),
x if x == TK_OF as u16 => "OF".to_owned(),
x if x == TK_OFFSET as u16 => "OFFSET".to_owned(),
x if x == TK_ON as u16 => "ON".to_owned(),
//x if x == TK_OR as u16 => "OR".to_owned(),
//x if x == TK_ORDER as u16 => "ORDER".to_owned(),
x if x == TK_OTHERS as u16 => "OTHERS".to_owned(),
//x if x == TK_OVER as u16 => "OVER".to_owned(),
x if x == TK_PARTITION as u16 => "PARTITION".to_owned(),
x if x == TK_PLAN as u16 => "PLAN".to_owned(),
x if x == TK_PRAGMA as u16 => "PRAGMA".to_owned(),
x if x == TK_PRECEDING as u16 => "PRECEDING".to_owned(),
//x if x == TK_PRIMARY as u16 => "PRIMARY".to_owned(),
x if x == TK_QUERY as u16 => "QUERY".to_owned(),
x if x == TK_RAISE as u16 => "RAISE".to_owned(),
x if x == TK_RANGE as u16 => "RANGE".to_owned(),
x if x == TK_RECURSIVE as u16 => "RECURSIVE".to_owned(),
//x if x == TK_REFERENCES as u16 => "REFERENCES".to_owned(),
x if x == TK_REINDEX as u16 => "REINDEX".to_owned(),
x if x == TK_RELEASE as u16 => "RELEASE".to_owned(),
x if x == TK_RENAME as u16 => "RENAME".to_owned(),
x if x == TK_REPLACE as u16 => "REPLACE".to_owned(),
//x if x == TK_RETURNING as u16 => "RETURNING".to_owned(),
x if x == TK_RESTRICT as u16 => "RESTRICT".to_owned(),
x if x == TK_ROLLBACK as u16 => "ROLLBACK".to_owned(),
x if x == TK_ROW as u16 => "ROW".to_owned(),
x if x == TK_ROWS as u16 => "ROWS".to_owned(),
x if x == TK_SAVEPOINT as u16 => "SAVEPOINT".to_owned(),
//x if x == TK_SELECT as u16 => "SELECT".to_owned(),
//x if x == TK_SET as u16 => "SET".to_owned(),
//x if x == TK_TABLE as u16 => "TABLE".to_owned(),
x if x == TK_TEMP as u16 => "TEMP".to_owned(),
//x if x == TK_TEMP as u16 => "TEMPORARY".to_owned(),
//x if x == TK_THEN as u16 => "THEN".to_owned(),
x if x == TK_TIES as u16 => "TIES".to_owned(),
//x if x == TK_TO as u16 => "TO".to_owned(),
//x if x == TK_TRANSACTION as u16 => "TRANSACTION".to_owned(),
x if x == TK_TRIGGER as u16 => "TRIGGER".to_owned(),
x if x == TK_UNBOUNDED as u16 => "UNBOUNDED".to_owned(),
//x if x == TK_UNION as u16 => "UNION".to_owned(),
//x if x == TK_UNIQUE as u16 => "UNIQUE".to_owned(),
//x if x == TK_UPDATE as u16 => "UPDATE".to_owned(),
//x if x == TK_USING as u16 => "USING".to_owned(),
x if x == TK_VACUUM as u16 => "VACUUM".to_owned(),
x if x == TK_VALUES as u16 => "VALUES".to_owned(),
x if x == TK_VIEW as u16 => "VIEW".to_owned(),
x if x == TK_VIRTUAL as u16 => "VIRTUAL".to_owned(),
//x if x == TK_WHEN as u16 => "WHEN".to_owned(),
//x if x == TK_WHERE as u16 => "WHERE".to_owned(),
//x if x == TK_WINDOW as u16 => "WINDOW".to_owned(),
x if x == TK_WITH as u16 => "WITH".to_owned(),
x if x == TK_WITHOUT as u16 => "WITHOUT".to_owned(),
_ => unreachable!(),
}
pub(crate) fn from_token(_ty: u16, value: Token) -> String {
from_bytes(value.1)
}

impl TokenType {
Expand Down
5 changes: 1 addition & 4 deletions src/lexer/sql/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ fn duplicate_column() {
fn create_table_without_column() {
expect_parser_err(
b"CREATE TABLE t ()",
ParserError::SyntaxError {
token_type: "RP",
found: None,
},
ParserError::SyntaxError(")".to_owned()),
);
}

Expand Down
72 changes: 33 additions & 39 deletions src/parser/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod fmt;

use std::num::ParseIntError;
use std::ops::Deref;
use std::str::{Bytes, FromStr};
use std::str::{self, Bytes, FromStr};

use fmt::{ToTokens, TokenStream};
use indexmap::{IndexMap, IndexSet};
Expand Down Expand Up @@ -418,10 +418,8 @@ impl Expr {
/// Constructor
pub fn ptr(left: Expr, op: Token, right: Expr) -> Expr {
let mut ptr = Operator::ArrowRight;
if let Some(ref op) = op.1 {
if op == "->>" {
ptr = Operator::ArrowRightShift;
}
if op.1 == b"->>" {
ptr = Operator::ArrowRightShift;
}
Expr::Binary(Box::new(left), ptr, Box::new(right))
}
Expand Down Expand Up @@ -515,16 +513,12 @@ pub enum Literal {
impl Literal {
/// Constructor
pub fn from_ctime_kw(token: Token) -> Literal {
if let Some(ref token) = token.1 {
if "CURRENT_DATE".eq_ignore_ascii_case(token) {
Literal::CurrentDate
} else if "CURRENT_TIME".eq_ignore_ascii_case(token) {
Literal::CurrentTime
} else if "CURRENT_TIMESTAMP".eq_ignore_ascii_case(token) {
Literal::CurrentTimestamp
} else {
unreachable!()
}
if b"CURRENT_DATE".eq_ignore_ascii_case(token.1) {
Literal::CurrentDate
} else if b"CURRENT_TIME".eq_ignore_ascii_case(token.1) {
Literal::CurrentTime
} else if b"CURRENT_TIMESTAMP".eq_ignore_ascii_case(token.1) {
Literal::CurrentTimestamp
} else {
unreachable!()
}
Expand All @@ -550,14 +544,13 @@ impl LikeOperator {
if token_type == TK_MATCH as YYCODETYPE {
return LikeOperator::Match;
} else if token_type == TK_LIKE_KW as YYCODETYPE {
if let Some(ref token) = token.1 {
if "LIKE".eq_ignore_ascii_case(token) {
return LikeOperator::Like;
} else if "GLOB".eq_ignore_ascii_case(token) {
return LikeOperator::Glob;
} else if "REGEXP".eq_ignore_ascii_case(token) {
return LikeOperator::Regexp;
}
let token = token.1;
if b"LIKE".eq_ignore_ascii_case(token) {
return LikeOperator::Like;
} else if b"GLOB".eq_ignore_ascii_case(token) {
return LikeOperator::Glob;
} else if b"REGEXP".eq_ignore_ascii_case(token) {
return LikeOperator::Regexp;
}
}
unreachable!()
Expand Down Expand Up @@ -887,24 +880,22 @@ impl JoinOperator {
n1: Option<Name>,
n2: Option<Name>,
) -> Result<JoinOperator, ParserError> {
Ok(if let Some(ref t) = token.1 {
let mut jt = JoinType::try_from(t.as_ref())?;
Ok({
let mut jt = JoinType::try_from(token.1)?;
for n in [&n1, &n2].into_iter().flatten() {
jt |= JoinType::try_from(n.0.as_ref())?;
}
if (jt & (JoinType::INNER | JoinType::OUTER)) == (JoinType::INNER | JoinType::OUTER)
|| (jt & (JoinType::OUTER | JoinType::LEFT | JoinType::RIGHT)) == JoinType::OUTER
{
return Err(custom_err!(
"unsupported JOIN type: {} {:?} {:?}",
t,
"unsupported JOIN type: {:?} {:?} {:?}",
str::from_utf8(token.1),
n1,
n2
));
}
JoinOperator::TypedJoin(Some(jt))
} else {
unreachable!()
})
}
fn is_natural(&self) -> bool {
Expand Down Expand Up @@ -935,25 +926,28 @@ bitflags::bitflags! {
}
}

impl TryFrom<&str> for JoinType {
impl TryFrom<&[u8]> for JoinType {
type Error = ParserError;
fn try_from(s: &str) -> Result<JoinType, ParserError> {
if "CROSS".eq_ignore_ascii_case(s) {
fn try_from(s: &[u8]) -> Result<JoinType, ParserError> {
if b"CROSS".eq_ignore_ascii_case(s) {
Ok(JoinType::INNER | JoinType::CROSS)
} else if "FULL".eq_ignore_ascii_case(s) {
} else if b"FULL".eq_ignore_ascii_case(s) {
Ok(JoinType::LEFT | JoinType::RIGHT | JoinType::OUTER)
} else if "INNER".eq_ignore_ascii_case(s) {
} else if b"INNER".eq_ignore_ascii_case(s) {
Ok(JoinType::INNER)
} else if "LEFT".eq_ignore_ascii_case(s) {
} else if b"LEFT".eq_ignore_ascii_case(s) {
Ok(JoinType::LEFT | JoinType::OUTER)
} else if "NATURAL".eq_ignore_ascii_case(s) {
} else if b"NATURAL".eq_ignore_ascii_case(s) {
Ok(JoinType::NATURAL)
} else if "RIGHT".eq_ignore_ascii_case(s) {
} else if b"RIGHT".eq_ignore_ascii_case(s) {
Ok(JoinType::RIGHT | JoinType::OUTER)
} else if "OUTER".eq_ignore_ascii_case(s) {
} else if b"OUTER".eq_ignore_ascii_case(s) {
Ok(JoinType::OUTER)
} else {
Err(custom_err!("unsupported JOIN type: {}", s))
Err(custom_err!(
"unsupported JOIN type: {:?}",
str::from_utf8(s)
))
}
}
}
Expand Down
11 changes: 3 additions & 8 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ use ast::{Cmd, ExplainKind, Name, Stmt};
#[derive(Debug, PartialEq)]
pub enum ParserError {
/// Syntax error
SyntaxError {
/// token type
token_type: &'static str,
/// token value
found: Option<String>,
},
SyntaxError(String),
/// Unexpected EOF
UnexpectedEof,
/// Custom error
Expand All @@ -36,8 +31,8 @@ pub enum ParserError {
impl std::fmt::Display for ParserError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
ParserError::SyntaxError { token_type, found } => {
write!(f, "near {}, \"{:?}\": syntax error", token_type, found)
ParserError::SyntaxError(s) => {
write!(f, "near \"{}\": syntax error", s)
}
ParserError::UnexpectedEof => f.write_str("unexpected end of input"),
ParserError::Custom(s) => f.write_str(s),
Expand Down
Loading