From abd971dbb86e0ee18d26a7dace6243cd047f02a4 Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Tue, 19 Nov 2024 15:37:52 -0800 Subject: [PATCH 01/11] Tokenize at signs separately from subsequent strings This is needed for parsing MySQL-style `'user'@'host'` grantee syntax. As far as I can tell, no dialect allows quotes or backticks as part of an identifier (regardless of whether it starts with `@`) without other specific syntax (e.g. nested in another quote style and thus not starting with `@`), so this shouldn't adversely affect non-MySQL dialects. --- src/tokenizer.rs | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 05aaf1e28..c5751c361 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -1202,6 +1202,9 @@ impl<'a> Tokenizer<'a> { } } Some(' ') => Ok(Some(Token::AtSign)), + Some('\'') => Ok(Some(Token::AtSign)), + Some('\"') => Ok(Some(Token::AtSign)), + Some('`') => Ok(Some(Token::AtSign)), Some(sch) if self.dialect.is_identifier_start('@') => { self.tokenize_identifier_or_keyword([ch, *sch], chars) } @@ -1229,7 +1232,7 @@ impl<'a> Tokenizer<'a> { } '$' => Ok(Some(self.tokenize_dollar_preceded_value(chars)?)), - //whitespace check (including unicode chars) should be last as it covers some of the chars above + // whitespace check (including unicode chars) should be last as it covers some of the chars above ch if ch.is_whitespace() => { self.consume_and_return(chars, Token::Whitespace(Whitespace::Space)) } @@ -2976,4 +2979,22 @@ mod tests { let expected = vec![Token::SingleQuotedString("''".to_string())]; compare(expected, tokens); } + + #[test] + fn test_mysql_users_grantees() { + let dialect = MySqlDialect {}; + + let sql = "CREATE USER 'root'@'localhost'"; + let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap(); + let expected = vec![ + Token::make_keyword("CREATE"), + Token::Whitespace(Whitespace::Space), + Token::make_keyword("USER"), + Token::Whitespace(Whitespace::Space), + Token::SingleQuotedString("root".to_string()), + Token::AtSign, + Token::SingleQuotedString("localhost".to_string()), + ]; + compare(expected, tokens); + } } From 9dc0d9e3638a08612b01a6f3dd384c8bec34a156 Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Tue, 19 Nov 2024 17:47:27 -0800 Subject: [PATCH 02/11] Fix parsing GRANT/REVOKE for MySQL Introduces a new `Grantee` enum to differentiate bare identifiers (every other database: `root`) from user/host pairs (MySQL: `'root'@'%'`). While we're here, make the CASCADE/RESTRICT syntax for REVOKE optional since Postgres doesn't require it and MySQL doesn't allow it. Add support for MySQL wildcard object syntax: `GRANT ALL ON *.* ...` --- src/ast/mod.rs | 32 ++++++++++++++--- src/parser/mod.rs | 74 +++++++++++++++++++++++++++++++-------- tests/sqlparser_common.rs | 30 ++++++++++++++-- tests/sqlparser_mysql.rs | 52 +++++++++++++++++++++++++++ 4 files changed, 167 insertions(+), 21 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 9185c9df4..3957e4570 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -3121,7 +3121,7 @@ pub enum Statement { Grant { privileges: Privileges, objects: GrantObjects, - grantees: Vec, + grantees: Vec, with_grant_option: bool, granted_by: Option, }, @@ -3131,9 +3131,9 @@ pub enum Statement { Revoke { privileges: Privileges, objects: GrantObjects, - grantees: Vec, + grantees: Vec, granted_by: Option, - cascade: bool, + cascade: Option, }, /// ```sql /// DEALLOCATE [ PREPARE ] { name | ALL } @@ -4660,7 +4660,9 @@ impl fmt::Display for Statement { if let Some(grantor) = granted_by { write!(f, " GRANTED BY {grantor}")?; } - write!(f, " {}", if *cascade { "CASCADE" } else { "RESTRICT" })?; + if let Some(cascade) = cascade { + write!(f, " {}", if *cascade { "CASCADE" } else { "RESTRICT" })?; + } Ok(()) } Statement::Deallocate { name, prepare } => write!( @@ -5381,6 +5383,28 @@ impl fmt::Display for GrantObjects { } } +/// Users/roles designated in a GRANT/REVOKE +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum Grantee { + /// A bare identifier + Ident(Ident), + /// A MySQL user/host pair such as 'root'@'%' + UserHost { user: Ident, host: Ident }, +} + +impl fmt::Display for Grantee { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Grantee::Ident(ident) => ident.fmt(f), + Grantee::UserHost { user, host } => { + write!(f, "{}@{}", user, host) + } + } + } +} + /// SQL assignment `foo = expr` as used in SQLUpdate #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 35c763e93..5bfed587a 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -8353,6 +8353,31 @@ impl<'a> Parser<'a> { } } + /// Parse a possibly qualified, possibly quoted identifier, optionally allowing for wildcards, + /// e.g. *, `foo`.*, or "foo"."bar" + pub fn parse_object_name_with_wildcards( + &mut self, + in_table_clause: bool, + allow_wildcards: bool, + ) -> Result { + let mut idents = vec![]; + loop { + let ident = if allow_wildcards && self.consume_token(&Token::Mul) { + Ident { + value: "*".to_owned(), + quote_style: None, + } + } else { + self.parse_identifier(in_table_clause)? + }; + idents.push(ident); + if !self.consume_token(&Token::Period) { + break; + } + } + Ok(ObjectName(idents)) + } + /// Parse a possibly qualified, possibly quoted identifier, e.g. /// `foo` or `myschema."table" /// @@ -8360,13 +8385,8 @@ impl<'a> Parser<'a> { /// or similar table clause. Currently, this is used only to support unquoted hyphenated identifiers /// in this context on BigQuery. pub fn parse_object_name(&mut self, in_table_clause: bool) -> Result { - let mut idents = vec![]; - loop { - idents.push(self.parse_identifier(in_table_clause)?); - if !self.consume_token(&Token::Period) { - break; - } - } + let ObjectName(mut idents) = + self.parse_object_name_with_wildcards(in_table_clause, false)?; // BigQuery accepts any number of quoted identifiers of a table name. // https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#quoted_identifiers @@ -10897,7 +10917,7 @@ impl<'a> Parser<'a> { let (privileges, objects) = self.parse_grant_revoke_privileges_objects()?; self.expect_keyword(Keyword::TO)?; - let grantees = self.parse_comma_separated(|p| p.parse_identifier(false))?; + let grantees = self.parse_comma_separated(|p| p.parse_grantee())?; let with_grant_option = self.parse_keywords(&[Keyword::WITH, Keyword::GRANT, Keyword::OPTION]); @@ -10979,7 +10999,11 @@ impl<'a> Parser<'a> { } else { let object_type = self.parse_one_of_keywords(&[Keyword::SEQUENCE, Keyword::SCHEMA, Keyword::TABLE]); - let objects = self.parse_comma_separated(|p| p.parse_object_name(false)); + let objects = if dialect_of!(self is MySqlDialect | GenericDialect) { + self.parse_comma_separated(|p| p.parse_object_name_with_wildcards(false, true)) + } else { + self.parse_comma_separated(|p| p.parse_object_name(false)) + }; match object_type { Some(Keyword::SCHEMA) => GrantObjects::Schemas(objects?), Some(Keyword::SEQUENCE) => GrantObjects::Sequences(objects?), @@ -11023,23 +11047,43 @@ impl<'a> Parser<'a> { } } + pub fn parse_grantee(&mut self) -> Result { + let user = self.parse_identifier(false)?; + if dialect_of!(self is MySqlDialect | GenericDialect) && self.consume_token(&Token::AtSign) + { + let host = self.parse_identifier(false)?; + Ok(Grantee::UserHost { user, host }) + } else { + Ok(Grantee::Ident(user)) + } + } + /// Parse a REVOKE statement pub fn parse_revoke(&mut self) -> Result { let (privileges, objects) = self.parse_grant_revoke_privileges_objects()?; self.expect_keyword(Keyword::FROM)?; - let grantees = self.parse_comma_separated(|p| p.parse_identifier(false))?; + let grantees = self.parse_comma_separated(|p| p.parse_grantee())?; let granted_by = self .parse_keywords(&[Keyword::GRANTED, Keyword::BY]) .then(|| self.parse_identifier(false).unwrap()); let loc = self.peek_token().location; - let cascade = self.parse_keyword(Keyword::CASCADE); - let restrict = self.parse_keyword(Keyword::RESTRICT); - if cascade && restrict { - return parser_err!("Cannot specify both CASCADE and RESTRICT in REVOKE", loc); - } + let cascade = if !dialect_of!(self is MySqlDialect) { + let cascade = self.parse_keyword(Keyword::CASCADE); + let restrict = self.parse_keyword(Keyword::RESTRICT); + if cascade && restrict { + return parser_err!("Cannot specify both CASCADE and RESTRICT in REVOKE", loc); + } + if cascade || restrict { + Some(cascade) + } else { + None + } + } else { + None + }; Ok(Statement::Revoke { privileges, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 3d9ba5da2..b951aa7d2 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8197,14 +8197,40 @@ fn parse_grant() { #[test] fn test_revoke() { - let sql = "REVOKE ALL PRIVILEGES ON users, auth FROM analyst CASCADE"; + let sql = "REVOKE ALL PRIVILEGES ON users, auth FROM analyst"; match verified_stmt(sql) { Statement::Revoke { privileges, objects: GrantObjects::Tables(tables), grantees, + granted_by, cascade, + } => { + assert_eq!( + Privileges::All { + with_privileges_keyword: true + }, + privileges + ); + assert_eq_vec(&["users", "auth"], &tables); + assert_eq_vec(&["analyst"], &grantees); + assert_eq!(cascade, None); + assert_eq!(None, granted_by); + } + _ => unreachable!(), + } +} + +#[test] +fn test_revoke_with_cascade() { + let sql = "REVOKE ALL PRIVILEGES ON users, auth FROM analyst CASCADE"; + match all_dialects_except(|d| d.is::()).verified_stmt(sql) { + Statement::Revoke { + privileges, + objects: GrantObjects::Tables(tables), + grantees, granted_by, + cascade, } => { assert_eq!( Privileges::All { @@ -8214,7 +8240,7 @@ fn test_revoke() { ); assert_eq_vec(&["users", "auth"], &tables); assert_eq_vec(&["analyst"], &grantees); - assert!(cascade); + assert_eq!(cascade, Some(true)); assert_eq!(None, granted_by); } _ => unreachable!(), diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index ce3296737..9cff2e6c7 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -2971,3 +2971,55 @@ fn parse_bitstring_literal() { ))] ); } + +#[test] +fn parse_grant() { + let sql = "GRANT ALL ON *.* TO 'jeffrey'@'%'"; + assert_eq!( + mysql().verified_stmt(sql), + Statement::Grant { + privileges: Privileges::All { + with_privileges_keyword: false + }, + objects: GrantObjects::Tables(vec![ObjectName(vec!["*".into(), "*".into()])]), + grantees: vec![Grantee::UserHost { + user: Ident { + value: "jeffrey".to_owned(), + quote_style: Some('\'') + }, + host: Ident { + value: "%".to_owned(), + quote_style: Some('\'') + } + }], + with_grant_option: false, + granted_by: None + } + ) +} + +#[test] +fn parse_revoke() { + let sql = "REVOKE ALL ON db1.* FROM 'jeffrey'@'%'"; + assert_eq!( + mysql().verified_stmt(sql), + Statement::Revoke { + privileges: Privileges::All { + with_privileges_keyword: false + }, + objects: GrantObjects::Tables(vec![ObjectName(vec!["db1".into(), "*".into()])]), + grantees: vec![Grantee::UserHost { + user: Ident { + value: "jeffrey".to_owned(), + quote_style: Some('\'') + }, + host: Ident { + value: "%".to_owned(), + quote_style: Some('\'') + } + }], + granted_by: None, + cascade: None, + } + ) +} From 64669ede463f62b7b016e3298d7799bef8bdd516 Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Tue, 19 Nov 2024 18:03:23 -0800 Subject: [PATCH 03/11] Parse MySQL CREATE VIEW parameters --- src/ast/mod.rs | 87 ++++++++++++++++++++++++++++++++++-- src/keywords.rs | 5 +++ src/parser/mod.rs | 58 +++++++++++++++++++++++- tests/sqlparser_common.rs | 28 +++++++++--- tests/sqlparser_mysql.rs | 94 +++++++++++++++++++++++++++++++++++---- 5 files changed, 252 insertions(+), 20 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 3957e4570..bae0e8e75 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -819,7 +819,7 @@ pub enum Expr { /// Example: /// /// ```sql - /// SELECT (SELECT ',' + name FROM sys.objects FOR XML PATH(''), TYPE).value('.','NVARCHAR(MAX)') + /// SELECT (SELECT ',' + name FROM sys.objects FOR XML PATH(''), TYPE).value('.','NVARCHAR(MAX)') /// SELECT CONVERT(XML,'abc').value('.','NVARCHAR(MAX)').value('.','NVARCHAR(MAX)') /// ``` /// @@ -2427,6 +2427,8 @@ pub enum Statement { /// if not None, has Clickhouse `TO` clause, specify the table into which to insert results /// to: Option, + /// MySQL: Optional parameters for the view algorithm, definer, and security context + params: Option, }, /// ```sql /// CREATE TABLE @@ -3939,11 +3941,19 @@ impl fmt::Display for Statement { if_not_exists, temporary, to, + params, } => { write!( f, - "CREATE {or_replace}{materialized}{temporary}VIEW {if_not_exists}{name}{to}", + "CREATE {or_replace}", or_replace = if *or_replace { "OR REPLACE " } else { "" }, + )?; + if let Some(params) = params { + write!(f, "{params} ")?; + } + write!( + f, + "{materialized}{temporary}VIEW {if_not_exists}{name}{to}", materialized = if *materialized { "MATERIALIZED " } else { "" }, name = name, temporary = if *temporary { "TEMPORARY " } else { "" }, @@ -7335,15 +7345,84 @@ pub enum MySQLColumnPosition { impl Display for MySQLColumnPosition { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - MySQLColumnPosition::First => Ok(write!(f, "FIRST")?), + MySQLColumnPosition::First => write!(f, "FIRST"), MySQLColumnPosition::After(ident) => { let column_name = &ident.value; - Ok(write!(f, "AFTER {column_name}")?) + write!(f, "AFTER {column_name}") } } } } +/// MySQL `CREATE VIEW` algorithm parameter: [ALGORITHM = {UNDEFINED | MERGE | TEMPTABLE}] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum MySQLViewAlgorithm { + Undefined, + Merge, + TempTable, +} + +impl Display for MySQLViewAlgorithm { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + MySQLViewAlgorithm::Undefined => write!(f, "UNDEFINED"), + MySQLViewAlgorithm::Merge => write!(f, "MERGE"), + MySQLViewAlgorithm::TempTable => write!(f, "TEMPTABLE"), + } + } +} +/// MySQL `CREATE VIEW` security parameter: [SQL SECURITY { DEFINER | INVOKER }] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum MySQLViewSecurity { + Definer, + Invoker, +} + +impl Display for MySQLViewSecurity { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + MySQLViewSecurity::Definer => write!(f, "DEFINER"), + MySQLViewSecurity::Invoker => write!(f, "INVOKER"), + } + } +} + +/// [MySQL] `CREATE VIEW` additional parameters +/// +/// [MySQL]: https://dev.mysql.com/doc/refman/9.1/en/create-view.html +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct MySQLViewParams { + pub algorithm: Option, + pub definer: Option, + pub security: Option, +} + +impl Display for MySQLViewParams { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let parts = [ + self.algorithm + .as_ref() + .map(|algorithm| format!("ALGORITHM = {algorithm}")), + self.definer + .as_ref() + .map(|definer| format!("DEFINER = {definer}")), + self.security + .as_ref() + .map(|security| format!("SQL SECURITY {security}")), + ] + .into_iter() + .flat_map(|part| part) + .collect::>(); + display_separated(&parts, " ").fmt(f) + } +} + /// Engine of DB. Some warehouse has parameters of engine, e.g. [clickhouse] /// /// [clickhouse]: https://clickhouse.com/docs/en/engines/table-engines diff --git a/src/keywords.rs b/src/keywords.rs index fc2a2927c..b6dc93323 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -84,6 +84,7 @@ define_keywords!( AFTER, AGAINST, AGGREGATION, + ALGORITHM, ALIAS, ALL, ALLOCATE, @@ -241,6 +242,7 @@ define_keywords!( DEFERRED, DEFINE, DEFINED, + DEFINER, DELAYED, DELETE, DELIMITED, @@ -412,6 +414,7 @@ define_keywords!( INTERSECTION, INTERVAL, INTO, + INVOKER, IS, ISODOW, ISOLATION, @@ -750,6 +753,7 @@ define_keywords!( TBLPROPERTIES, TEMP, TEMPORARY, + TEMPTABLE, TERMINATED, TERSE, TEXT, @@ -795,6 +799,7 @@ define_keywords!( UNBOUNDED, UNCACHE, UNCOMMITTED, + UNDEFINED, UNFREEZE, UNION, UNIQUE, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 5bfed587a..986a1855f 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -3720,11 +3720,16 @@ impl<'a> Parser<'a> { .is_some(); let persistent = dialect_of!(self is DuckDbDialect) && self.parse_one_of_keywords(&[Keyword::PERSISTENT]).is_some(); + let mysql_create_view_params = if dialect_of!(self is MySqlDialect | GenericDialect) { + self.parse_mysql_create_view_params()? + } else { + None + }; if self.parse_keyword(Keyword::TABLE) { self.parse_create_table(or_replace, temporary, global, transient) } else if self.parse_keyword(Keyword::MATERIALIZED) || self.parse_keyword(Keyword::VIEW) { self.prev_token(); - self.parse_create_view(or_replace, temporary) + self.parse_create_view(or_replace, temporary, mysql_create_view_params) } else if self.parse_keyword(Keyword::POLICY) { self.parse_create_policy() } else if self.parse_keyword(Keyword::EXTERNAL) { @@ -4616,6 +4621,7 @@ impl<'a> Parser<'a> { &mut self, or_replace: bool, temporary: bool, + mysql_create_view_params: Option, ) -> Result { let materialized = self.parse_keyword(Keyword::MATERIALIZED); self.expect_keyword(Keyword::VIEW)?; @@ -4693,9 +4699,59 @@ impl<'a> Parser<'a> { if_not_exists, temporary, to, + params: mysql_create_view_params, }) } + /// Parse optional algorithm, definer, and security context parameters for [MySQL] + /// + /// [MySQL]: https://dev.mysql.com/doc/refman/9.1/en/create-view.html + fn parse_mysql_create_view_params(&mut self) -> Result, ParserError> { + let algorithm = if self.parse_keyword(Keyword::ALGORITHM) { + self.expect_token(&Token::Eq)?; + Some( + match self.expect_one_of_keywords(&[ + Keyword::UNDEFINED, + Keyword::MERGE, + Keyword::TEMPTABLE, + ])? { + Keyword::UNDEFINED => MySQLViewAlgorithm::Undefined, + Keyword::MERGE => MySQLViewAlgorithm::Merge, + Keyword::TEMPTABLE => MySQLViewAlgorithm::TempTable, + _ => unreachable!(), + }, + ) + } else { + None + }; + let definer = if self.parse_keyword(Keyword::DEFINER) { + self.expect_token(&Token::Eq)?; + Some(self.parse_grantee()?) + } else { + None + }; + let security = if self.parse_keywords(&[Keyword::SQL, Keyword::SECURITY]) { + Some( + match self.expect_one_of_keywords(&[Keyword::DEFINER, Keyword::INVOKER])? { + Keyword::DEFINER => MySQLViewSecurity::Definer, + Keyword::INVOKER => MySQLViewSecurity::Invoker, + _ => unreachable!(), + }, + ) + } else { + None + }; + if algorithm.is_some() || definer.is_some() || security.is_some() { + Ok(Some(MySQLViewParams { + algorithm, + definer, + security, + })) + } else { + Ok(None) + } + } + pub fn parse_create_role(&mut self) -> Result { let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let names = self.parse_comma_separated(|p| p.parse_object_name(false))?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index b951aa7d2..6cd1b3b34 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -6839,6 +6839,7 @@ fn parse_create_view() { if_not_exists, temporary, to, + params, } => { assert_eq!("myschema.myview", name.to_string()); assert_eq!(Vec::::new(), columns); @@ -6851,7 +6852,8 @@ fn parse_create_view() { assert!(!late_binding); assert!(!if_not_exists); assert!(!temporary); - assert!(to.is_none()) + assert!(to.is_none()); + assert!(params.is_none()); } _ => unreachable!(), } @@ -6899,6 +6901,7 @@ fn parse_create_view_with_columns() { if_not_exists, temporary, to, + params, } => { assert_eq!("v", name.to_string()); assert_eq!( @@ -6921,7 +6924,8 @@ fn parse_create_view_with_columns() { assert!(!late_binding); assert!(!if_not_exists); assert!(!temporary); - assert!(to.is_none()) + assert!(to.is_none()); + assert!(params.is_none()); } _ => unreachable!(), } @@ -6944,6 +6948,7 @@ fn parse_create_view_temporary() { if_not_exists, temporary, to, + params, } => { assert_eq!("myschema.myview", name.to_string()); assert_eq!(Vec::::new(), columns); @@ -6956,7 +6961,8 @@ fn parse_create_view_temporary() { assert!(!late_binding); assert!(!if_not_exists); assert!(temporary); - assert!(to.is_none()) + assert!(to.is_none()); + assert!(params.is_none()); } _ => unreachable!(), } @@ -6979,6 +6985,7 @@ fn parse_create_or_replace_view() { if_not_exists, temporary, to, + params, } => { assert_eq!("v", name.to_string()); assert_eq!(columns, vec![]); @@ -6991,7 +6998,8 @@ fn parse_create_or_replace_view() { assert!(!late_binding); assert!(!if_not_exists); assert!(!temporary); - assert!(to.is_none()) + assert!(to.is_none()); + assert!(params.is_none()); } _ => unreachable!(), } @@ -7018,6 +7026,7 @@ fn parse_create_or_replace_materialized_view() { if_not_exists, temporary, to, + params, } => { assert_eq!("v", name.to_string()); assert_eq!(columns, vec![]); @@ -7030,7 +7039,8 @@ fn parse_create_or_replace_materialized_view() { assert!(!late_binding); assert!(!if_not_exists); assert!(!temporary); - assert!(to.is_none()) + assert!(to.is_none()); + assert!(params.is_none()); } _ => unreachable!(), } @@ -7053,6 +7063,7 @@ fn parse_create_materialized_view() { if_not_exists, temporary, to, + params, } => { assert_eq!("myschema.myview", name.to_string()); assert_eq!(Vec::::new(), columns); @@ -7065,7 +7076,8 @@ fn parse_create_materialized_view() { assert!(!late_binding); assert!(!if_not_exists); assert!(!temporary); - assert!(to.is_none()) + assert!(to.is_none()); + assert!(params.is_none()); } _ => unreachable!(), } @@ -7088,6 +7100,7 @@ fn parse_create_materialized_view_with_cluster_by() { if_not_exists, temporary, to, + params, } => { assert_eq!("myschema.myview", name.to_string()); assert_eq!(Vec::::new(), columns); @@ -7100,7 +7113,8 @@ fn parse_create_materialized_view_with_cluster_by() { assert!(!late_binding); assert!(!if_not_exists); assert!(!temporary); - assert!(to.is_none()) + assert!(to.is_none()); + assert!(params.is_none()); } _ => unreachable!(), } diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 9cff2e6c7..1b4cca2b4 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -31,6 +31,14 @@ use test_utils::*; #[macro_use] mod test_utils; +fn mysql() -> TestedDialects { + TestedDialects::new(vec![Box::new(MySqlDialect {})]) +} + +fn mysql_and_generic() -> TestedDialects { + TestedDialects::new(vec![Box::new(MySqlDialect {}), Box::new(GenericDialect {})]) +} + #[test] fn parse_identifiers() { mysql().verified_stmt("SELECT $a$, àà"); @@ -2704,14 +2712,6 @@ fn parse_create_table_with_fulltext_definition_should_not_accept_constraint_name mysql_and_generic().verified_stmt("CREATE TABLE tb (c1 INT, CONSTRAINT cons FULLTEXT (c1))"); } -fn mysql() -> TestedDialects { - TestedDialects::new(vec![Box::new(MySqlDialect {})]) -} - -fn mysql_and_generic() -> TestedDialects { - TestedDialects::new(vec![Box::new(MySqlDialect {}), Box::new(GenericDialect {})]) -} - #[test] fn parse_values() { mysql().verified_stmt("VALUES ROW(1, true, 'a')"); @@ -3023,3 +3023,81 @@ fn parse_revoke() { } ) } + +#[test] +fn parse_create_view_algorithm_param() { + let sql = "CREATE ALGORITHM = MERGE VIEW foo AS SELECT 1"; + let stmt = mysql().verified_stmt(sql); + if let Statement::CreateView { + params: + Some(MySQLViewParams { + algorithm, + definer, + security, + }), + .. + } = stmt + { + assert_eq!(algorithm, Some(MySQLViewAlgorithm::Merge)); + assert!(definer.is_none()); + assert!(security.is_none()); + } else { + unreachable!() + } +} + +#[test] +fn parse_create_view_definer_param() { + let sql = "CREATE DEFINER = 'jeffrey'@'localhost' VIEW foo AS SELECT 1"; + let stmt = mysql().verified_stmt(sql); + if let Statement::CreateView { + params: + Some(MySQLViewParams { + algorithm, + definer, + security, + }), + .. + } = stmt + { + assert!(algorithm.is_none()); + assert_eq!( + definer, + Some(Grantee::UserHost { + user: Ident { + value: "jeffrey".to_owned(), + quote_style: Some('\'') + }, + host: Ident { + value: "localhost".to_owned(), + quote_style: Some('\'') + }, + }) + ); + assert!(security.is_none()); + } else { + unreachable!() + } +} + +#[test] +fn parse_create_view_security_param() { + let sql = "CREATE SQL SECURITY DEFINER VIEW foo AS SELECT 1"; + let stmt = mysql().verified_stmt(sql); + if let Statement::CreateView { + params: + Some(MySQLViewParams { + algorithm, + definer, + security, + }), + .. + } = stmt + { + assert!(algorithm.is_none()); + assert!(definer.is_none()); + assert_eq!(security, Some(MySQLViewSecurity::Definer)); + } else { + unreachable!() + } +} From 69bb0d824b360e60c358aad1367d0d897d0c5fea Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Wed, 20 Nov 2024 09:02:55 -0800 Subject: [PATCH 04/11] Fix lint error and no-std build --- src/ast/mod.rs | 2 +- src/parser/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index bae0e8e75..033f89c22 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -7417,7 +7417,7 @@ impl Display for MySQLViewParams { .map(|security| format!("SQL SECURITY {security}")), ] .into_iter() - .flat_map(|part| part) + .flatten() .collect::>(); display_separated(&parts, " ").fmt(f) } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 986a1855f..0aeae7b1d 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -8420,7 +8420,7 @@ impl<'a> Parser<'a> { loop { let ident = if allow_wildcards && self.consume_token(&Token::Mul) { Ident { - value: "*".to_owned(), + value: Token::Mul.to_string(), quote_style: None, } } else { From 797c510a8d79324165b4d7ee91ecaff32d9451fa Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Wed, 27 Nov 2024 09:47:40 -0800 Subject: [PATCH 05/11] Add more test cases for create view --- tests/sqlparser_mysql.rs | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 1b4cca2b4..c4cfefd13 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -3002,7 +3002,7 @@ fn parse_grant() { fn parse_revoke() { let sql = "REVOKE ALL ON db1.* FROM 'jeffrey'@'%'"; assert_eq!( - mysql().verified_stmt(sql), + mysql_and_generic().verified_stmt(sql), Statement::Revoke { privileges: Privileges::All { with_privileges_keyword: false @@ -3044,6 +3044,8 @@ fn parse_create_view_algorithm_param() { } else { unreachable!() } + mysql().verified_stmt("CREATE ALGORITHM = UNDEFINED VIEW foo AS SELECT 1"); + mysql().verified_stmt("CREATE ALGORITHM = TEMPTABLE VIEW foo AS SELECT 1"); } #[test] @@ -3100,4 +3102,39 @@ fn parse_create_view_security_param() { } else { unreachable!() } + mysql().verified_stmt("CREATE SQL SECURITY INVOKER VIEW foo AS SELECT 1"); +} + +#[test] +fn parse_create_view_multiple_params() { + let sql = "CREATE ALGORITHM = UNDEFINED DEFINER = `root`@`%` SQL SECURITY INVOKER VIEW foo AS SELECT 1"; + let stmt = mysql().verified_stmt(sql); + if let Statement::CreateView { + params: + Some(MySQLViewParams { + algorithm, + definer, + security, + }), + .. + } = stmt + { + assert_eq!(algorithm, Some(MySQLViewAlgorithm::Undefined)); + assert_eq!( + definer, + Some(Grantee::UserHost { + user: Ident { + value: "root".to_owned(), + quote_style: Some('`') + }, + host: Ident { + value: "%".to_owned(), + quote_style: Some('`') + }, + }) + ); + assert_eq!(security, Some(MySQLViewSecurity::Invoker)); + } else { + unreachable!() + } } From c896e10161ff5ed748362f49605349b4170f263d Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Wed, 27 Nov 2024 09:52:40 -0800 Subject: [PATCH 06/11] Use generic names for optional `CREATE VIEW` params Though as far as I know they are entirely MySQL specific. --- src/ast/mod.rs | 28 ++++++++++++++-------------- src/parser/mod.rs | 24 ++++++++++++------------ tests/sqlparser_mysql.rs | 16 ++++++++-------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 033f89c22..9a897521c 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2428,7 +2428,7 @@ pub enum Statement { /// to: Option, /// MySQL: Optional parameters for the view algorithm, definer, and security context - params: Option, + params: Option, }, /// ```sql /// CREATE TABLE @@ -7358,18 +7358,18 @@ impl Display for MySQLColumnPosition { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub enum MySQLViewAlgorithm { +pub enum CreateViewAlgorithm { Undefined, Merge, TempTable, } -impl Display for MySQLViewAlgorithm { +impl Display for CreateViewAlgorithm { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - MySQLViewAlgorithm::Undefined => write!(f, "UNDEFINED"), - MySQLViewAlgorithm::Merge => write!(f, "MERGE"), - MySQLViewAlgorithm::TempTable => write!(f, "TEMPTABLE"), + CreateViewAlgorithm::Undefined => write!(f, "UNDEFINED"), + CreateViewAlgorithm::Merge => write!(f, "MERGE"), + CreateViewAlgorithm::TempTable => write!(f, "TEMPTABLE"), } } } @@ -7377,16 +7377,16 @@ impl Display for MySQLViewAlgorithm { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub enum MySQLViewSecurity { +pub enum CreateViewSecurity { Definer, Invoker, } -impl Display for MySQLViewSecurity { +impl Display for CreateViewSecurity { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - MySQLViewSecurity::Definer => write!(f, "DEFINER"), - MySQLViewSecurity::Invoker => write!(f, "INVOKER"), + CreateViewSecurity::Definer => write!(f, "DEFINER"), + CreateViewSecurity::Invoker => write!(f, "INVOKER"), } } } @@ -7397,13 +7397,13 @@ impl Display for MySQLViewSecurity { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct MySQLViewParams { - pub algorithm: Option, +pub struct CreateViewParams { + pub algorithm: Option, pub definer: Option, - pub security: Option, + pub security: Option, } -impl Display for MySQLViewParams { +impl Display for CreateViewParams { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let parts = [ self.algorithm diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 0aeae7b1d..f4d2ceef1 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -3720,8 +3720,8 @@ impl<'a> Parser<'a> { .is_some(); let persistent = dialect_of!(self is DuckDbDialect) && self.parse_one_of_keywords(&[Keyword::PERSISTENT]).is_some(); - let mysql_create_view_params = if dialect_of!(self is MySqlDialect | GenericDialect) { - self.parse_mysql_create_view_params()? + let create_view_params = if dialect_of!(self is MySqlDialect | GenericDialect) { + self.parse_create_view_params()? } else { None }; @@ -3729,7 +3729,7 @@ impl<'a> Parser<'a> { self.parse_create_table(or_replace, temporary, global, transient) } else if self.parse_keyword(Keyword::MATERIALIZED) || self.parse_keyword(Keyword::VIEW) { self.prev_token(); - self.parse_create_view(or_replace, temporary, mysql_create_view_params) + self.parse_create_view(or_replace, temporary, create_view_params) } else if self.parse_keyword(Keyword::POLICY) { self.parse_create_policy() } else if self.parse_keyword(Keyword::EXTERNAL) { @@ -4621,7 +4621,7 @@ impl<'a> Parser<'a> { &mut self, or_replace: bool, temporary: bool, - mysql_create_view_params: Option, + create_view_params: Option, ) -> Result { let materialized = self.parse_keyword(Keyword::MATERIALIZED); self.expect_keyword(Keyword::VIEW)?; @@ -4699,14 +4699,14 @@ impl<'a> Parser<'a> { if_not_exists, temporary, to, - params: mysql_create_view_params, + params: create_view_params, }) } /// Parse optional algorithm, definer, and security context parameters for [MySQL] /// /// [MySQL]: https://dev.mysql.com/doc/refman/9.1/en/create-view.html - fn parse_mysql_create_view_params(&mut self) -> Result, ParserError> { + fn parse_create_view_params(&mut self) -> Result, ParserError> { let algorithm = if self.parse_keyword(Keyword::ALGORITHM) { self.expect_token(&Token::Eq)?; Some( @@ -4715,9 +4715,9 @@ impl<'a> Parser<'a> { Keyword::MERGE, Keyword::TEMPTABLE, ])? { - Keyword::UNDEFINED => MySQLViewAlgorithm::Undefined, - Keyword::MERGE => MySQLViewAlgorithm::Merge, - Keyword::TEMPTABLE => MySQLViewAlgorithm::TempTable, + Keyword::UNDEFINED => CreateViewAlgorithm::Undefined, + Keyword::MERGE => CreateViewAlgorithm::Merge, + Keyword::TEMPTABLE => CreateViewAlgorithm::TempTable, _ => unreachable!(), }, ) @@ -4733,8 +4733,8 @@ impl<'a> Parser<'a> { let security = if self.parse_keywords(&[Keyword::SQL, Keyword::SECURITY]) { Some( match self.expect_one_of_keywords(&[Keyword::DEFINER, Keyword::INVOKER])? { - Keyword::DEFINER => MySQLViewSecurity::Definer, - Keyword::INVOKER => MySQLViewSecurity::Invoker, + Keyword::DEFINER => CreateViewSecurity::Definer, + Keyword::INVOKER => CreateViewSecurity::Invoker, _ => unreachable!(), }, ) @@ -4742,7 +4742,7 @@ impl<'a> Parser<'a> { None }; if algorithm.is_some() || definer.is_some() || security.is_some() { - Ok(Some(MySQLViewParams { + Ok(Some(CreateViewParams { algorithm, definer, security, diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index c4cfefd13..7c35af82b 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -3030,7 +3030,7 @@ fn parse_create_view_algorithm_param() { let stmt = mysql().verified_stmt(sql); if let Statement::CreateView { params: - Some(MySQLViewParams { + Some(CreateViewParams { algorithm, definer, security, @@ -3038,7 +3038,7 @@ fn parse_create_view_algorithm_param() { .. } = stmt { - assert_eq!(algorithm, Some(MySQLViewAlgorithm::Merge)); + assert_eq!(algorithm, Some(CreateViewAlgorithm::Merge)); assert!(definer.is_none()); assert!(security.is_none()); } else { @@ -3054,7 +3054,7 @@ fn parse_create_view_definer_param() { let stmt = mysql().verified_stmt(sql); if let Statement::CreateView { params: - Some(MySQLViewParams { + Some(CreateViewParams { algorithm, definer, security, @@ -3088,7 +3088,7 @@ fn parse_create_view_security_param() { let stmt = mysql().verified_stmt(sql); if let Statement::CreateView { params: - Some(MySQLViewParams { + Some(CreateViewParams { algorithm, definer, security, @@ -3098,7 +3098,7 @@ fn parse_create_view_security_param() { { assert!(algorithm.is_none()); assert!(definer.is_none()); - assert_eq!(security, Some(MySQLViewSecurity::Definer)); + assert_eq!(security, Some(CreateViewSecurity::Definer)); } else { unreachable!() } @@ -3111,7 +3111,7 @@ fn parse_create_view_multiple_params() { let stmt = mysql().verified_stmt(sql); if let Statement::CreateView { params: - Some(MySQLViewParams { + Some(CreateViewParams { algorithm, definer, security, @@ -3119,7 +3119,7 @@ fn parse_create_view_multiple_params() { .. } = stmt { - assert_eq!(algorithm, Some(MySQLViewAlgorithm::Undefined)); + assert_eq!(algorithm, Some(CreateViewAlgorithm::Undefined)); assert_eq!( definer, Some(Grantee::UserHost { @@ -3133,7 +3133,7 @@ fn parse_create_view_multiple_params() { }, }) ); - assert_eq!(security, Some(MySQLViewSecurity::Invoker)); + assert_eq!(security, Some(CreateViewSecurity::Invoker)); } else { unreachable!() } From 935777b400c836fb7fa1ae6c82276b686ec70ace Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Wed, 27 Nov 2024 10:04:34 -0800 Subject: [PATCH 07/11] Return errors instead of panicking when expect_one_keyword returns something unexpected --- src/parser/mod.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index f4d2ceef1..c74d07f73 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -4718,7 +4718,12 @@ impl<'a> Parser<'a> { Keyword::UNDEFINED => CreateViewAlgorithm::Undefined, Keyword::MERGE => CreateViewAlgorithm::Merge, Keyword::TEMPTABLE => CreateViewAlgorithm::TempTable, - _ => unreachable!(), + _ => { + self.prev_token(); + let found = self.next_token(); + return self + .expected("UNDEFINED or MERGE or TEMPTABLE after ALGORITHM =", found); + } }, ) } else { @@ -4735,7 +4740,11 @@ impl<'a> Parser<'a> { match self.expect_one_of_keywords(&[Keyword::DEFINER, Keyword::INVOKER])? { Keyword::DEFINER => CreateViewSecurity::Definer, Keyword::INVOKER => CreateViewSecurity::Invoker, - _ => unreachable!(), + _ => { + self.prev_token(); + let found = self.next_token(); + return self.expected("DEFINER or INVOKER after SQL SECURITY", found); + } }, ) } else { From 99bf3f614262e256f0b1b9a1c129986ac2a8e6b8 Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Wed, 27 Nov 2024 13:47:13 -0800 Subject: [PATCH 08/11] Add dialect support toggle for MySQL grantee syntax --- src/dialect/generic.rs | 4 ++++ src/dialect/mod.rs | 5 +++++ src/dialect/mysql.rs | 4 ++++ src/parser/mod.rs | 3 +-- 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs index e3beeae7f..f83430bf2 100644 --- a/src/dialect/generic.rs +++ b/src/dialect/generic.rs @@ -123,4 +123,8 @@ impl Dialect for GenericDialect { fn supports_named_fn_args_with_assignment_operator(&self) -> bool { true } + + fn supports_user_host_grantee(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 985cad749..6120880b5 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -365,6 +365,11 @@ pub trait Dialect: Debug + Any { self.supports_trailing_commas() } + /// Does the dialect support MySQL-style `'user'@'host'` grantee syntax? + fn supports_user_host_grantee(&self) -> bool { + false + } + /// Dialect-specific infix parser override /// /// This method is called to parse the next infix expression. diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index 197ce48d4..aae2fc390 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -102,6 +102,10 @@ impl Dialect for MySqlDialect { fn supports_create_table_select(&self) -> bool { true } + + fn supports_user_host_grantee(&self) -> bool { + true + } } /// `LOCK TABLES` diff --git a/src/parser/mod.rs b/src/parser/mod.rs index c74d07f73..42c8d6d09 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11114,8 +11114,7 @@ impl<'a> Parser<'a> { pub fn parse_grantee(&mut self) -> Result { let user = self.parse_identifier(false)?; - if dialect_of!(self is MySqlDialect | GenericDialect) && self.consume_token(&Token::AtSign) - { + if self.dialect.supports_user_host_grantee() && self.consume_token(&Token::AtSign) { let host = self.parse_identifier(false)?; Ok(Grantee::UserHost { user, host }) } else { From 309165330e971fef10a598747f7902920234bf66 Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Wed, 27 Nov 2024 14:04:26 -0800 Subject: [PATCH 09/11] Refactor parsing revoke cascade/restrict option to reuse truncate code --- src/ast/mod.rs | 23 ++++++++++++++++------- src/parser/mod.rs | 34 ++++++++++++---------------------- tests/sqlparser_common.rs | 2 +- tests/sqlparser_postgres.rs | 4 ++-- 4 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 9a897521c..422f7e482 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2286,7 +2286,7 @@ pub enum Statement { identity: Option, /// Postgres-specific option /// [ CASCADE | RESTRICT ] - cascade: Option, + cascade: Option, /// ClickHouse-specific option /// [ ON CLUSTER cluster_name ] /// @@ -3135,7 +3135,7 @@ pub enum Statement { objects: GrantObjects, grantees: Vec, granted_by: Option, - cascade: Option, + cascade: Option, }, /// ```sql /// DEALLOCATE [ PREPARE ] { name | ALL } @@ -3551,8 +3551,8 @@ impl fmt::Display for Statement { } if let Some(cascade) = cascade { match cascade { - TruncateCascadeOption::Cascade => write!(f, " CASCADE")?, - TruncateCascadeOption::Restrict => write!(f, " RESTRICT")?, + CascadeOption::Cascade => write!(f, " CASCADE")?, + CascadeOption::Restrict => write!(f, " RESTRICT")?, } } @@ -4671,7 +4671,7 @@ impl fmt::Display for Statement { write!(f, " GRANTED BY {grantor}")?; } if let Some(cascade) = cascade { - write!(f, " {}", if *cascade { "CASCADE" } else { "RESTRICT" })?; + write!(f, " {}", cascade)?; } Ok(()) } @@ -5069,16 +5069,25 @@ pub enum TruncateIdentityOption { Continue, } -/// PostgreSQL cascade option for TRUNCATE table +/// Cascade/restrict option for Postgres TRUNCATE table, MySQL GRANT/REVOKE, etc. /// [ CASCADE | RESTRICT ] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub enum TruncateCascadeOption { +pub enum CascadeOption { Cascade, Restrict, } +impl Display for CascadeOption { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + CascadeOption::Cascade => write!(f, "CASCADE"), + CascadeOption::Restrict => write!(f, "RESTRICT"), + } + } +} + /// Can use to describe options in create sequence or table column type identity /// [ MINVALUE minvalue | NO MINVALUE ] [ MAXVALUE maxvalue | NO MAXVALUE ] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 42c8d6d09..fc23fb386 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -749,13 +749,7 @@ impl<'a> Parser<'a> { None }; - cascade = if self.parse_keyword(Keyword::CASCADE) { - Some(TruncateCascadeOption::Cascade) - } else if self.parse_keyword(Keyword::RESTRICT) { - Some(TruncateCascadeOption::Restrict) - } else { - None - }; + cascade = self.parse_cascade_option(); }; let on_cluster = self.parse_optional_on_cluster()?; @@ -771,6 +765,16 @@ impl<'a> Parser<'a> { }) } + fn parse_cascade_option(&mut self) -> Option { + if self.parse_keyword(Keyword::CASCADE) { + Some(CascadeOption::Cascade) + } else if self.parse_keyword(Keyword::RESTRICT) { + Some(CascadeOption::Restrict) + } else { + None + } + } + pub fn parse_attach_duckdb_database_options( &mut self, ) -> Result, ParserError> { @@ -11133,21 +11137,7 @@ impl<'a> Parser<'a> { .parse_keywords(&[Keyword::GRANTED, Keyword::BY]) .then(|| self.parse_identifier(false).unwrap()); - let loc = self.peek_token().location; - let cascade = if !dialect_of!(self is MySqlDialect) { - let cascade = self.parse_keyword(Keyword::CASCADE); - let restrict = self.parse_keyword(Keyword::RESTRICT); - if cascade && restrict { - return parser_err!("Cannot specify both CASCADE and RESTRICT in REVOKE", loc); - } - if cascade || restrict { - Some(cascade) - } else { - None - } - } else { - None - }; + let cascade = self.parse_cascade_option(); Ok(Statement::Revoke { privileges, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 6cd1b3b34..3c5de7447 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8254,7 +8254,7 @@ fn test_revoke_with_cascade() { ); assert_eq_vec(&["users", "auth"], &tables); assert_eq_vec(&["analyst"], &grantees); - assert_eq!(cascade, Some(true)); + assert_eq!(cascade, Some(CascadeOption::Cascade)); assert_eq!(None, granted_by); } _ => unreachable!(), diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 2e2c4403c..c30c4f81f 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -3994,7 +3994,7 @@ fn parse_truncate_with_options() { table: true, only: true, identity: Some(TruncateIdentityOption::Restart), - cascade: Some(TruncateCascadeOption::Cascade), + cascade: Some(CascadeOption::Cascade), on_cluster: None, }, truncate @@ -4026,7 +4026,7 @@ fn parse_truncate_with_table_list() { table: true, only: false, identity: Some(TruncateIdentityOption::Restart), - cascade: Some(TruncateCascadeOption::Cascade), + cascade: Some(CascadeOption::Cascade), on_cluster: None, }, truncate From 50506c49275deea1458b8f670d7bb8327be22428 Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Wed, 27 Nov 2024 14:26:58 -0800 Subject: [PATCH 10/11] Fix a doc comment about BigQuery identifiers --- src/parser/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index fc23fb386..511ed2936 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -8584,9 +8584,9 @@ impl<'a> Parser<'a> { /// Parse a simple one-word identifier (possibly quoted, possibly a keyword) /// - /// The `in_table_clause` parameter indicates whether the identifier is a table in a FROM, JOIN, or - /// similar table clause. Currently, this is used only to support unquoted hyphenated identifiers in - // this context on BigQuery. + /// The `in_table_clause` parameter indicates whether the identifier is a table in a FROM, JOIN, + /// or similar table clause. Currently, this is used only to support unquoted hyphenated + /// identifiers in this context on BigQuery. pub fn parse_identifier(&mut self, in_table_clause: bool) -> Result { let next_token = self.next_token(); match next_token.token { From 3fc2164c538dfb6c157058d2f9d0669ff0e90523 Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Wed, 27 Nov 2024 16:36:19 -0800 Subject: [PATCH 11/11] Accommodate new span field on ident for grantees and wildcards --- src/ast/spans.rs | 1 + src/parser/mod.rs | 4 +- tests/sqlparser_mysql.rs | 142 +++++++++++++++++++++------------------ 3 files changed, 81 insertions(+), 66 deletions(-) diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 8e8c7b14a..4be4f7e18 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -325,6 +325,7 @@ impl Spanned for Statement { if_not_exists: _, temporary: _, to, + params: _, } => union_spans( core::iter::once(name.span()) .chain(columns.iter().map(|i| i.span())) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index f9fe6d02e..de62a86e0 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -8520,10 +8520,12 @@ impl<'a> Parser<'a> { ) -> Result { let mut idents = vec![]; loop { - let ident = if allow_wildcards && self.consume_token(&Token::Mul) { + let ident = if allow_wildcards && self.peek_token().token == Token::Mul { + let span = self.next_token().span; Ident { value: Token::Mul.to_string(), quote_style: None, + span, } } else { self.parse_identifier(in_table_clause)? diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 008e5f87d..a8763a46e 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -3018,53 +3018,75 @@ fn parse_bitstring_literal() { #[test] fn parse_grant() { let sql = "GRANT ALL ON *.* TO 'jeffrey'@'%'"; - assert_eq!( - mysql().verified_stmt(sql), - Statement::Grant { - privileges: Privileges::All { + let stmt = mysql().verified_stmt(sql); + if let Statement::Grant { + privileges, + objects, + grantees, + with_grant_option, + granted_by, + } = stmt + { + assert_eq!( + privileges, + Privileges::All { with_privileges_keyword: false - }, - objects: GrantObjects::Tables(vec![ObjectName(vec!["*".into(), "*".into()])]), - grantees: vec![Grantee::UserHost { - user: Ident { - value: "jeffrey".to_owned(), - quote_style: Some('\'') - }, - host: Ident { - value: "%".to_owned(), - quote_style: Some('\'') - } - }], - with_grant_option: false, - granted_by: None + } + ); + assert_eq!( + objects, + GrantObjects::Tables(vec![ObjectName(vec!["*".into(), "*".into()])]) + ); + assert!(!with_grant_option); + assert!(granted_by.is_none()); + if let [Grantee::UserHost { user, host }] = grantees.as_slice() { + assert_eq!(user.value, "jeffrey"); + assert_eq!(user.quote_style, Some('\'')); + assert_eq!(host.value, "%"); + assert_eq!(host.quote_style, Some('\'')); + } else { + unreachable!() } - ) + } else { + unreachable!() + } } #[test] fn parse_revoke() { let sql = "REVOKE ALL ON db1.* FROM 'jeffrey'@'%'"; - assert_eq!( - mysql_and_generic().verified_stmt(sql), - Statement::Revoke { - privileges: Privileges::All { + let stmt = mysql_and_generic().verified_stmt(sql); + if let Statement::Revoke { + privileges, + objects, + grantees, + granted_by, + cascade, + } = stmt + { + assert_eq!( + privileges, + Privileges::All { with_privileges_keyword: false - }, - objects: GrantObjects::Tables(vec![ObjectName(vec!["db1".into(), "*".into()])]), - grantees: vec![Grantee::UserHost { - user: Ident { - value: "jeffrey".to_owned(), - quote_style: Some('\'') - }, - host: Ident { - value: "%".to_owned(), - quote_style: Some('\'') - } - }], - granted_by: None, - cascade: None, - } - ) + } + ); + assert_eq!( + objects, + GrantObjects::Tables(vec![ObjectName(vec!["db1".into(), "*".into()])]) + ); + if let [Grantee::UserHost { user, host }] = grantees.as_slice() { + assert_eq!(user.value, "jeffrey"); + assert_eq!(user.quote_style, Some('\'')); + assert_eq!(host.value, "%"); + assert_eq!(host.quote_style, Some('\'')); + } else { + unreachable!() + } + assert!(granted_by.is_none()); + assert!(cascade.is_none()); + } else { + unreachable!() + } } #[test] @@ -3106,19 +3128,14 @@ fn parse_create_view_definer_param() { } = stmt { assert!(algorithm.is_none()); - assert_eq!( - definer, - Some(Grantee::UserHost { - user: Ident { - value: "jeffrey".to_owned(), - quote_style: Some('\'') - }, - host: Ident { - value: "localhost".to_owned(), - quote_style: Some('\'') - }, - }) - ); + if let Some(Grantee::UserHost { user, host }) = definer { + assert_eq!(user.value, "jeffrey"); + assert_eq!(user.quote_style, Some('\'')); + assert_eq!(host.value, "localhost"); + assert_eq!(host.quote_style, Some('\'')); + } else { + unreachable!() + } assert!(security.is_none()); } else { unreachable!() @@ -3163,19 +3180,14 @@ fn parse_create_view_multiple_params() { } = stmt { assert_eq!(algorithm, Some(CreateViewAlgorithm::Undefined)); - assert_eq!( - definer, - Some(Grantee::UserHost { - user: Ident { - value: "root".to_owned(), - quote_style: Some('`') - }, - host: Ident { - value: "%".to_owned(), - quote_style: Some('`') - }, - }) - ); + if let Some(Grantee::UserHost { user, host }) = definer { + assert_eq!(user.value, "root"); + assert_eq!(user.quote_style, Some('`')); + assert_eq!(host.value, "%"); + assert_eq!(host.quote_style, Some('`')); + } else { + unreachable!() + } assert_eq!(security, Some(CreateViewSecurity::Invoker)); } else { unreachable!()