diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index d95d7c70a..a732aa5a0 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -633,6 +633,11 @@ pub trait Dialect: Debug + Any { fn supports_comment_on(&self) -> bool { false } + + /// Returns true if the dialect supports the `CREATE TABLE SELECT` statement + fn supports_create_table_select(&self) -> bool { + false + } } /// This represents the operators for which precedence must be defined diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index d1bf33345..197ce48d4 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -97,6 +97,11 @@ impl Dialect for MySqlDialect { fn supports_limit_comma(&self) -> bool { true } + + /// see + fn supports_create_table_select(&self) -> bool { + true + } } /// `LOCK TABLES` diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 756f4d68b..4115bbc9a 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5990,6 +5990,11 @@ impl<'a> Parser<'a> { // Parse optional `AS ( query )` let query = if self.parse_keyword(Keyword::AS) { Some(self.parse_query()?) + } else if self.dialect.supports_create_table_select() && self.parse_keyword(Keyword::SELECT) + { + // rewind the SELECT keyword + self.prev_token(); + Some(self.parse_query()?) } else { None }; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 25bf306ad..daf65edf1 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -6501,7 +6501,17 @@ fn parse_multiple_statements() { ); test_with("DELETE FROM foo", "SELECT", " bar"); test_with("INSERT INTO foo VALUES (1)", "SELECT", " bar"); - test_with("CREATE TABLE foo (baz INT)", "SELECT", " bar"); + // Since MySQL supports the `CREATE TABLE SELECT` syntax, this needs to be handled separately + let res = parse_sql_statements("CREATE TABLE foo (baz INT); SELECT bar"); + assert_eq!( + vec![ + one_statement_parses_to("CREATE TABLE foo (baz INT)", ""), + one_statement_parses_to("SELECT bar", ""), + ], + res.unwrap() + ); + // Check that extra semicolon at the end is stripped by normalization: + one_statement_parses_to("CREATE TABLE foo (baz INT);", "CREATE TABLE foo (baz INT)"); // Make sure that empty statements do not cause an error: let res = parse_sql_statements(";;"); assert_eq!(0, res.unwrap().len()); @@ -11717,3 +11727,24 @@ fn parse_comments() { ParserError::ParserError("Expected: comment object_type, found: UNKNOWN".to_string()) ); } + +#[test] +fn parse_create_table_select() { + let dialects = all_dialects_where(|d| d.supports_create_table_select()); + let sql_1 = r#"CREATE TABLE foo (baz INT) SELECT bar"#; + let expected = r#"CREATE TABLE foo (baz INT) AS SELECT bar"#; + let _ = dialects.one_statement_parses_to(sql_1, expected); + + let sql_2 = r#"CREATE TABLE foo (baz INT, name STRING) SELECT bar, oth_name FROM test.table_a"#; + let expected = + r#"CREATE TABLE foo (baz INT, name STRING) AS SELECT bar, oth_name FROM test.table_a"#; + let _ = dialects.one_statement_parses_to(sql_2, expected); + + let dialects = all_dialects_where(|d| !d.supports_create_table_select()); + for sql in [sql_1, sql_2] { + assert_eq!( + dialects.parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError("Expected: end of statement, found: SELECT".to_string()) + ); + } +}