Skip to content

Commit

Permalink
Fix JOIN parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
gwenn committed Dec 3, 2023
1 parent 82cc079 commit a26f13e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 171 deletions.
243 changes: 76 additions & 167 deletions src/parser/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1632,155 +1632,29 @@ impl ToTokens for SelectTable {
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum JoinOperator {
Comma,
TypedJoin {
natural: bool,
join_type: Option<JoinType>,
},
TypedJoin(Option<JoinType>),
}

impl JoinOperator {
pub(crate) fn from_single(token: Token) -> Result<JoinOperator, ParserError> {
Ok(if let Some(ref jt) = token.1 {
if "CROSS".eq_ignore_ascii_case(jt) {
JoinOperator::TypedJoin {
natural: false,
join_type: Some(JoinType::Cross),
}
} else if "INNER".eq_ignore_ascii_case(jt) {
JoinOperator::TypedJoin {
natural: false,
join_type: Some(JoinType::Inner),
}
} else if "LEFT".eq_ignore_ascii_case(jt) {
JoinOperator::TypedJoin {
natural: false,
join_type: Some(JoinType::Left),
}
} else if "RIGHT".eq_ignore_ascii_case(jt) {
JoinOperator::TypedJoin {
natural: false,
join_type: Some(JoinType::Right),
}
} else if "FULL".eq_ignore_ascii_case(jt) {
JoinOperator::TypedJoin {
natural: false,
join_type: Some(JoinType::Full),
}
} else if "NATURAL".eq_ignore_ascii_case(jt) {
JoinOperator::TypedJoin {
natural: true,
join_type: None,
}
} else {
return Err(ParserError::Custom(format!(
"unsupported JOIN type: {}",
jt
)));
}
} else {
unreachable!()
})
}
pub(crate) fn from_couple(token: Token, name: Name) -> Result<JoinOperator, ParserError> {
Ok(if let Some(ref jt) = token.1 {
if "NATURAL".eq_ignore_ascii_case(jt) {
let join_type = if "INNER".eq_ignore_ascii_case(&name.0) {
JoinType::Inner
} else if "LEFT".eq_ignore_ascii_case(&name.0) {
JoinType::Left
} else if "RIGHT".eq_ignore_ascii_case(&name.0) {
JoinType::Right
} else if "FULL".eq_ignore_ascii_case(&name.0) {
JoinType::Full
} else if "CROSS".eq_ignore_ascii_case(&name.0) {
JoinType::Cross
} else {
return Err(ParserError::Custom(format!(
"unsupported JOIN type: {} {}",
jt, &name.0
)));
};
JoinOperator::TypedJoin {
natural: true,
join_type: Some(join_type),
}
} else if "OUTER".eq_ignore_ascii_case(&name.0) {
// If "OUTER" is present then there must also be one of "LEFT", "RIGHT", or "FULL"
let join_type = if "LEFT".eq_ignore_ascii_case(jt) {
JoinType::LeftOuter
} else if "RIGHT".eq_ignore_ascii_case(jt) {
JoinType::RightOuter
} else if "FULL".eq_ignore_ascii_case(jt) {
JoinType::FullOuter
} else {
return Err(ParserError::Custom(format!(
"unsupported JOIN type: {} {}",
jt, &name.0
)));
};
JoinOperator::TypedJoin {
natural: false,
join_type: Some(join_type),
}
} else if "LEFT".eq_ignore_ascii_case(jt) && "RIGHT".eq_ignore_ascii_case(&name.0) {
JoinOperator::TypedJoin {
natural: false,
join_type: Some(JoinType::Full),
}
} else if "OUTER".eq_ignore_ascii_case(jt) && "LEFT".eq_ignore_ascii_case(&name.0) {
// OUTER LEFT JOIN -> same as LEFT JOIN
JoinOperator::TypedJoin {
natural: false,
join_type: Some(JoinType::LeftOuter),
}
} else {
return Err(ParserError::Custom(format!(
"unsupported JOIN type: {} {}",
jt, &name.0
)));
}
} else {
unreachable!()
})
}
pub(crate) fn from_triple(
pub(crate) fn from(
token: Token,
n1: Name,
n2: Name,
n1: Option<Name>,
n2: Option<Name>,
) -> Result<JoinOperator, ParserError> {
Ok(if let Some(ref jt) = token.1 {
if "NATURAL".eq_ignore_ascii_case(jt) && "OUTER".eq_ignore_ascii_case(&n2.0) {
// If "OUTER" is present then there must also be one of "LEFT", "RIGHT", or "FULL"
let join_type = if "LEFT".eq_ignore_ascii_case(&n1.0) {
JoinType::LeftOuter
} else if "RIGHT".eq_ignore_ascii_case(&n1.0) {
JoinType::RightOuter
} else if "FULL".eq_ignore_ascii_case(&n1.0) {
JoinType::FullOuter
} else {
return Err(ParserError::Custom(format!(
"unsupported JOIN type: {} {} {}",
jt, &n1.0, &n2.0
)));
};
JoinOperator::TypedJoin {
natural: true,
join_type: Some(join_type),
}
} else if "OUTER".eq_ignore_ascii_case(jt)
&& "LEFT".eq_ignore_ascii_case(&n1.0)
&& "NATURAL".eq_ignore_ascii_case(&n2.0)
Ok(if let Some(ref t) = token.1 {
let mut jt = JoinType::try_from(t.as_ref())?;
for n in [&n1, &n2].into_iter().flatten() {
jt |= JoinType::try_from(n.0.as_ref())?;
}
if (jt & (JoinType::INNER | JoinType::OUTER)) == (JoinType::INNER | JoinType::OUTER)
|| (jt & (JoinType::OUTER | JoinType::LEFT | JoinType::RIGHT)) == JoinType::OUTER
{
JoinOperator::TypedJoin {
natural: true,
join_type: Some(JoinType::LeftOuter),
}
} else {
return Err(ParserError::Custom(format!(
"unsupported JOIN type: {} {} {}",
jt, &n1.0, &n2.0
"unsupported JOIN type: {} {:?} {:?}",
t, n1, n2
)));
}
JoinOperator::TypedJoin(Some(jt))
} else {
unreachable!()
})
Expand All @@ -1790,10 +1664,7 @@ impl ToTokens for JoinOperator {
fn to_tokens<S: TokenStream>(&self, s: &mut S) -> Result<(), S::Error> {
match self {
JoinOperator::Comma => s.append(TK_COMMA, None),
JoinOperator::TypedJoin { natural, join_type } => {
if *natural {
s.append(TK_JOIN_KW, Some("NATURAL"))?;
}
JoinOperator::TypedJoin(join_type) => {
if let Some(ref join_type) = join_type {
join_type.to_tokens(s)?;
}
Expand All @@ -1803,32 +1674,70 @@ impl ToTokens for JoinOperator {
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum JoinType {
Left, // same as LeftOuter
LeftOuter,
Inner,
Cross,
Right, // same as RightOuter
RightOuter,
Full, // same as FullOuter
FullOuter,
// https://github.com/sqlite/sqlite/blob/80511f32f7e71062026edd471913ef0455563964/src/select.c#L197-L257
bitflags::bitflags! {
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct JoinType: u8 {
const INNER = 0x01;
/// cross => INNER|CROSS
const CROSS = 0x02;
const NATURAL = 0x04;
/// left => LEFT|OUTER
const LEFT = 0x08;
/// right => RIGHT|OUTER
const RIGHT = 0x10;
const OUTER = 0x20;
}
}

impl TryFrom<&str> for JoinType {
type Error = ParserError;
fn try_from(s: &str) -> Result<JoinType, ParserError> {
if "CROSS".eq_ignore_ascii_case(s) {
Ok(JoinType::INNER | JoinType::CROSS)
} else if "FULL".eq_ignore_ascii_case(s) {
Ok(JoinType::LEFT | JoinType::RIGHT | JoinType::OUTER)
} else if "INNER".eq_ignore_ascii_case(s) {
Ok(JoinType::INNER)
} else if "LEFT".eq_ignore_ascii_case(s) {
Ok(JoinType::LEFT | JoinType::OUTER)
} else if "NATURAL".eq_ignore_ascii_case(s) {
Ok(JoinType::NATURAL)
} else if "RIGHT".eq_ignore_ascii_case(s) {
Ok(JoinType::RIGHT | JoinType::OUTER)
} else if "OUTER".eq_ignore_ascii_case(s) {
Ok(JoinType::OUTER)
} else {
Err(ParserError::Custom(format!("unsupported JOIN type: {}", s)))
}
}
}

impl ToTokens for JoinType {
fn to_tokens<S: TokenStream>(&self, s: &mut S) -> Result<(), S::Error> {
s.append(
TK_JOIN_KW,
match self {
JoinType::Left => Some("LEFT"),
JoinType::LeftOuter => Some("LEFT OUTER"),
JoinType::Inner => Some("INNER"),
JoinType::Cross => Some("CROSS"),
JoinType::Right => Some("RIGHT"),
JoinType::RightOuter => Some("RIGHT OUTER"),
JoinType::Full => Some("FULL"),
JoinType::FullOuter => Some("FULL OUTER"),
},
)
if self.contains(JoinType::NATURAL) {
s.append(TK_JOIN_KW, Some("NATURAL"))?;
}
if self.contains(JoinType::INNER) {
if self.contains(JoinType::CROSS) {
s.append(TK_JOIN_KW, Some("CROSS"))?;
}
s.append(TK_JOIN_KW, Some("INNER"))?;
} else {
if self.contains(JoinType::LEFT) {
if self.contains(JoinType::RIGHT) {
s.append(TK_JOIN_KW, Some("FULL"))?;
} else {
s.append(TK_JOIN_KW, Some("LEFT"))?;
}
} else if self.contains(JoinType::RIGHT) {
s.append(TK_JOIN_KW, Some("RIGHT"))?;
}
if self.contains(JoinType::OUTER) {
s.append(TK_JOIN_KW, Some("OUTER"))?;
}
}
Ok(())
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/parser/parse.y
Original file line number Diff line number Diff line change
Expand Up @@ -657,13 +657,13 @@ xfullname(A) ::= nm(X) AS nm(Z). {

%type joinop {JoinOperator}
joinop(X) ::= COMMA. { X = JoinOperator::Comma; }
joinop(X) ::= JOIN. { X = JoinOperator::TypedJoin{ natural: false, join_type: None }; }
joinop(X) ::= JOIN. { X = JoinOperator::TypedJoin(None); }
joinop(X) ::= JOIN_KW(A) JOIN.
{X = JoinOperator::from_single(A)?; /*X-overwrites-A*/}
{X = JoinOperator::from(A, None, None)?; /*X-overwrites-A*/}
joinop(X) ::= JOIN_KW(A) nm(B) JOIN.
{X = JoinOperator::from_couple(A, B)?; /*X-overwrites-A*/}
{X = JoinOperator::from(A, Some(B), None)?; /*X-overwrites-A*/}
joinop(X) ::= JOIN_KW(A) nm(B) nm(C) JOIN.
{X = JoinOperator::from_triple(A, B, C)?;/*X-overwrites-A*/}
{X = JoinOperator::from(A, Some(B), Some(C))?;/*X-overwrites-A*/}

// There is a parsing abiguity in an upsert statement that uses a
// SELECT on the RHS of a the INSERT:
Expand Down

0 comments on commit a26f13e

Please sign in to comment.