diff --git a/src/lexer/sql/test.rs b/src/lexer/sql/test.rs index 7119bff..068e2bf 100644 --- a/src/lexer/sql/test.rs +++ b/src/lexer/sql/test.rs @@ -8,8 +8,7 @@ use crate::parser::{ #[test] fn count_placeholders() -> Result<(), Error> { - let sql = "SELECT ? WHERE 1 = ?"; - let mut parser = Parser::new(sql.as_bytes()); + let mut parser = Parser::new(b"SELECT ? WHERE 1 = ?"); let ast = parser.next()?.unwrap(); let mut info = ParameterInfo::default(); ast.to_tokens(&mut info).unwrap(); @@ -19,8 +18,7 @@ fn count_placeholders() -> Result<(), Error> { #[test] fn count_numbered_placeholders() -> Result<(), Error> { - let sql = "SELECT ?1 WHERE 1 = ?2 AND 0 = ?1"; - let mut parser = Parser::new(sql.as_bytes()); + let mut parser = Parser::new(b"SELECT ?1 WHERE 1 = ?2 AND 0 = ?1"); let ast = parser.next()?.unwrap(); let mut info = ParameterInfo::default(); ast.to_tokens(&mut info).unwrap(); @@ -30,8 +28,7 @@ fn count_numbered_placeholders() -> Result<(), Error> { #[test] fn count_unused_placeholders() -> Result<(), Error> { - let sql = "SELECT ?1 WHERE 1 = ?3"; - let mut parser = Parser::new(sql.as_bytes()); + let mut parser = Parser::new(b"SELECT ?1 WHERE 1 = ?3"); let ast = parser.next()?.unwrap(); let mut info = ParameterInfo::default(); ast.to_tokens(&mut info).unwrap(); @@ -41,8 +38,7 @@ fn count_unused_placeholders() -> Result<(), Error> { #[test] fn count_named_placeholders() -> Result<(), Error> { - let sql = "SELECT :x, :y WHERE 1 = :y"; - let mut parser = Parser::new(sql.as_bytes()); + let mut parser = Parser::new(b"SELECT :x, :y WHERE 1 = :y"); let ast = parser.next()?.unwrap(); let mut info = ParameterInfo::default(); ast.to_tokens(&mut info).unwrap(); @@ -55,8 +51,7 @@ fn count_named_placeholders() -> Result<(), Error> { #[test] fn duplicate_column() { - let sql = "CREATE TABLE t (x TEXT, x TEXT)"; - let mut parser = Parser::new(sql.as_bytes()); + let mut parser = Parser::new(b"CREATE TABLE t (x TEXT, x TEXT)"); let r = parser.next(); let Error::ParserError(ParserError::Custom(msg), _) = r.unwrap_err() else { panic!("unexpected error type") @@ -64,6 +59,22 @@ fn duplicate_column() { assert!(msg.contains("duplicate column name")); } +#[test] +fn create_table_without_column() { + let mut parser = Parser::new(b"CREATE TABLE t ()"); + let r = parser.next(); + let Error::ParserError( + ParserError::SyntaxError { + token_type: "RP", + found: None, + }, + _, + ) = r.unwrap_err() + else { + panic!("unexpected error type") + }; +} + #[test] fn vtab_args() -> Result<(), Error> { let sql = r#"CREATE VIRTUAL TABLE mail USING fts3( diff --git a/src/parser/ast/mod.rs b/src/parser/ast/mod.rs index 2be5369..c35d6ac 100644 --- a/src/parser/ast/mod.rs +++ b/src/parser/ast/mod.rs @@ -154,6 +154,26 @@ impl Display for Cmd { } } +impl Cmd { + /// Like `sqlite3_column_count` but more limited + pub fn column_count(&self) -> ColumnCount { + match self { + Cmd::Explain(_) => ColumnCount::Fixed(8), + Cmd::ExplainQueryPlan(_) => ColumnCount::Fixed(4), + Cmd::Stmt(stmt) => stmt.column_count(), + } + } + + /// Like `sqlite3_stmt_readonly` + pub fn readonly(&self) -> bool { + match self { + Cmd::Explain(stmt) => stmt.readonly(), + Cmd::ExplainQueryPlan(stmt) => stmt.readonly(), + Cmd::Stmt(stmt) => stmt.readonly(), + } + } +} + pub(crate) enum ExplainKind { Explain, QueryPlan, @@ -691,6 +711,64 @@ impl ToTokens for Stmt { } } +/// Column count +pub enum ColumnCount { + /// With `SELECT *` / EXPLAIN / PRAGMA + Dynamic, + /// + Fixed(usize), + /// No column + None, +} + +impl ColumnCount { + fn incr(&mut self) { + if let ColumnCount::Fixed(n) = self { + *n += 1; + } + } +} + +impl Stmt { + /// Like `sqlite3_column_count` but more limited + pub fn column_count(&self) -> ColumnCount { + match self { + Stmt::Delete { + returning: Some(returning), + .. + } => column_count(returning), + Stmt::Insert { + returning: Some(returning), + .. + } => column_count(returning), + Stmt::Pragma(..) => ColumnCount::Dynamic, + Stmt::Select(s) => s.body.select.column_count(), + Stmt::Update { + returning: Some(returning), + .. + } => column_count(returning), + _ => ColumnCount::None, + } + } + + /// Like `sqlite3_stmt_readonly` + pub fn readonly(&self) -> bool { + match self { + Stmt::Attach { .. } => true, + Stmt::Begin(..) => true, + Stmt::Commit(..) => true, + Stmt::Detach(..) => true, + Stmt::Pragma(..) => true, // TODO check all + Stmt::Reindex { .. } => true, + Stmt::Release(..) => true, + Stmt::Rollback { .. } => true, + Stmt::Savepoint(..) => true, + Stmt::Select(..) => true, + _ => false, + } + } +} + // https://sqlite.org/syntax/expr.html #[derive(Clone, Debug, PartialEq, Eq)] pub enum Expr { @@ -1440,6 +1518,19 @@ impl ToTokens for OneSelect { } } +impl OneSelect { + /// Like `sqlite3_column_count` but more limited + pub fn column_count(&self) -> ColumnCount { + match self { + OneSelect::Select { columns, .. } => column_count(columns), + OneSelect::Values(values) => { + assert!(!values.is_empty()); // TODO Validate + ColumnCount::Fixed(values[0].len()) + } + } + } +} + // https://sqlite.org/syntax/join-clause.html #[derive(Clone, Debug, PartialEq, Eq)] pub struct FromClause { @@ -1538,6 +1629,26 @@ impl ToTokens for ResultColumn { } } +impl ResultColumn { + fn column_count(&self) -> ColumnCount { + match self { + ResultColumn::Expr(..) => ColumnCount::Fixed(1), + _ => ColumnCount::Dynamic, + } + } +} +fn column_count(cols: &[ResultColumn]) -> ColumnCount { + assert!(!cols.is_empty()); + let mut count = ColumnCount::Fixed(0); + for col in cols { + match col.column_count() { + ColumnCount::Fixed(_) => count.incr(), + _ => return ColumnCount::Dynamic, + } + } + count +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum As { As(Name),