diff --git a/src/ast.rs b/src/ast.rs index f788dd3..6fbdf8e 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -7,7 +7,7 @@ //! - [`EmitDefault`]: Emit the AST as BigQuery SQL, extremely close to the //! original input. This is optional, and only required if that type's //! [`Emit`] implementation wants to use it. -//! - [`Drive`]: Provided by the [`derive-visitor` +//! - [`Drive`] and [`DriveMut`]: Provided by the [`derive-visitor` //! crate](https://github.com/nikis05/derive-visitor). This provides an API to //! traverse the AST generically, using the [`derive_visitor::Visitor`] trait. //! This is honestly deep Rust magic, but it prevents us from needing to @@ -27,7 +27,7 @@ use codespan_reporting::{ diagnostic::{Diagnostic, Label}, files::SimpleFile, }; -use derive_visitor::Drive; +use derive_visitor::{Drive, DriveMut}; use joinery_macros::{Emit, EmitDefault}; use crate::{ @@ -188,7 +188,7 @@ impl Emit for Vec { } /// Our basic token type. This is used for all punctuation and keywords. -#[derive(Clone, Debug, Drive)] +#[derive(Clone, Debug, Drive, DriveMut)] pub struct Token { #[drive(skip)] pub span: Span, @@ -342,6 +342,15 @@ impl NodeVec { } } +impl Clone for NodeVec { + fn clone(&self) -> Self { + NodeVec { + nodes: self.nodes.clone(), + separators: self.separators.clone(), + } + } +} + impl Default for NodeVec { fn default() -> Self { NodeVec { @@ -360,6 +369,16 @@ impl<'a, T: fmt::Debug> IntoIterator for &'a NodeVec { } } +// Mutable iterator for `DriveMut`. +impl<'a, T: fmt::Debug> IntoIterator for &'a mut NodeVec { + type Item = &'a mut T; + type IntoIter = std::slice::IterMut<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.nodes.iter_mut() + } +} + impl Emit for NodeVec { fn emit(&self, t: Target, f: &mut fmt::Formatter<'_>) -> fmt::Result { for (i, node) in self.nodes.iter().enumerate() { @@ -384,7 +403,7 @@ impl Emit for NodeVec { } /// An identifier, such as a column name. -#[derive(Clone, Debug, Drive)] +#[derive(Clone, Debug, Drive, DriveMut)] pub struct Identifier { /// Our original token. pub token: Token, @@ -450,7 +469,7 @@ impl Emit for Identifier { } /// A table name. -#[derive(Clone, Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub enum TableName { ProjectDatasetTable { project: Identifier, @@ -505,7 +524,7 @@ impl Emit for TableName { } /// A table and a column name. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct TableAndColumnName { pub table_name: TableName, pub dot: Token, @@ -513,7 +532,7 @@ pub struct TableAndColumnName { } /// An entire SQL program. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct SqlProgram { /// Any whitespace that appears before the first statement. This is represented /// as a token with an empty `token_str()`. @@ -525,7 +544,7 @@ pub struct SqlProgram { } /// A statement in our abstract syntax tree. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum Statement { Query(QueryStatement), DeleteFrom(DeleteFromStatement), @@ -536,7 +555,7 @@ pub enum Statement { } /// A query statement. This exists mainly because it's in the official grammar. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct QueryStatement { pub query_expression: QueryExpression, } @@ -547,7 +566,7 @@ pub struct QueryStatement { /// /// [official grammar]: /// https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#sql_syntax. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub enum QueryExpression { SelectExpression(SelectExpression), Nested { @@ -588,7 +607,7 @@ impl Emit for QueryExpression { } /// Common table expressions (CTEs). -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct CommonTableExpression { pub name: Identifier, pub as_token: Token, @@ -598,7 +617,7 @@ pub struct CommonTableExpression { } /// Set operators. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub enum SetOperator { UnionAll { union_token: Token, @@ -645,7 +664,7 @@ impl Emit for SetOperator { } /// A `SELECT` expression. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct SelectExpression { pub select_options: SelectOptions, pub select_list: SelectList, @@ -661,14 +680,14 @@ pub struct SelectExpression { } /// The head of a `SELECT`, including any modifiers. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct SelectOptions { pub select_token: Token, pub distinct: Option, } /// The `DISTINCT` modifier. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct Distinct { pub distinct_token: Token, } @@ -691,13 +710,13 @@ pub struct Distinct { /// select_expression: /// expression [ [ AS ] alias ] /// ``` -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct SelectList { pub items: NodeVec, } /// A single item in a `SELECT` list. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum SelectListItem { /// An expression, optionally with an alias. Expression { @@ -716,7 +735,7 @@ pub enum SelectListItem { } /// An `EXCEPT` clause. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct Except { pub except_token: Token, pub paren1: Token, @@ -740,7 +759,7 @@ impl Emit for Except { } /// An SQL expression. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub enum Expression { Literal { token: Token, @@ -894,7 +913,7 @@ impl Emit for Expression { } /// A literal value. -#[derive(Debug, Drive)] +#[derive(Clone, Debug, Drive, DriveMut)] pub enum LiteralValue { Bool(#[drive(skip)] bool), Int64(#[drive(skip)] i64), @@ -903,7 +922,7 @@ pub enum LiteralValue { } /// An `INTERVAL` expression. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct IntervalExpression { pub interval_token: Token, pub number: Token, // Not even bothering to parse this for now. @@ -926,7 +945,7 @@ impl Emit for IntervalExpression { } /// A date part in an `INTERVAL` expression. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct DatePart { pub date_part_token: Token, } @@ -944,7 +963,7 @@ impl Emit for DatePart { } /// A cast expression. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct Cast { cast_token: Token, paren1: Token, @@ -977,7 +996,7 @@ impl Emit for Cast { /// /// [official grammar]: /// https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#in_operators -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum InValueSet { QueryExpression { paren1: Token, @@ -1002,7 +1021,7 @@ pub enum InValueSet { /// /// Not all combinations of our fields are valid. For example, we can't have /// a missing `ARRAY` and a `delim1` of `(`. We'll let the parser handle that. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct ArrayExpression { pub array_token: Option, pub element_type: Option, @@ -1035,7 +1054,7 @@ impl Emit for ArrayExpression { /// An `ARRAY` definition. Either a `SELECT` expression or a list of /// expressions. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub enum ArrayDefinition { Query(Box), Elements(NodeVec), @@ -1056,7 +1075,7 @@ impl Emit for ArrayDefinition { } /// A struct expression. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct StructExpression { pub struct_token: Token, pub paren1: Token, @@ -1080,7 +1099,7 @@ impl Emit for StructExpression { } /// The type of the elements in an `ARRAY` expression. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct ArrayElementType { pub lt: Token, pub elem_type: DataType, @@ -1088,7 +1107,7 @@ pub struct ArrayElementType { } /// A `COUNT` expression. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum CountExpression { CountStar { count_token: Token, @@ -1106,7 +1125,7 @@ pub enum CountExpression { } /// A `CASE WHEN` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct CaseWhenClause { pub when_token: Token, pub condition: Box, @@ -1115,7 +1134,7 @@ pub struct CaseWhenClause { } /// A `CASE ELSE` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct CaseElseClause { pub else_token: Token, pub result: Box, @@ -1123,7 +1142,7 @@ pub struct CaseElseClause { /// `CURRENT_DATE` may appear as either `CURRENT_DATE` or `CURRENT_DATE()`. /// And different databases seem to support one or the other or both. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct CurrentDate { pub current_date_token: Token, pub empty_parens: Option, @@ -1140,7 +1159,7 @@ impl Emit for CurrentDate { } /// An empty `()` expression. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct EmptyParens { pub paren1: Token, pub paren2: Token, @@ -1149,7 +1168,7 @@ pub struct EmptyParens { /// Special "functions" that manipulate dates. These all take a [`DatePart`] /// as a final argument. So in Lisp sense, these are special forms or macros, /// not ordinary function calls. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct SpecialDateFunctionCall { pub function_name: Identifier, pub paren1: Token, @@ -1158,14 +1177,14 @@ pub struct SpecialDateFunctionCall { } /// An expression or a date part. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum ExpressionOrDatePart { Expression(Expression), DatePart(DatePart), } /// A function call. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct FunctionCall { pub name: FunctionName, pub paren1: Token, @@ -1175,7 +1194,7 @@ pub struct FunctionCall { } /// A function name. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub enum FunctionName { ProjectDatasetFunction { project: Identifier, @@ -1195,6 +1214,14 @@ pub enum FunctionName { } impl FunctionName { + pub fn function_identifier(&self) -> &Identifier { + match self { + FunctionName::ProjectDatasetFunction { function, .. } + | FunctionName::DatasetFunction { function, .. } + | FunctionName::Function { function } => function, + } + } + /// Get the unescaped function name, in the original BigQuery form. pub fn unescaped_bigquery(&self) -> String { match self { @@ -1219,11 +1246,7 @@ impl Emit for FunctionName { match t { Target::SQLite3 => { let name = self.unescaped_bigquery(); - let ws = match self { - FunctionName::ProjectDatasetFunction { function, .. } - | FunctionName::DatasetFunction { function, .. } - | FunctionName::Function { function } => function.token.ws_only(), - }; + let ws = self.function_identifier().token.ws_only(); write!(f, "{}{}", SQLite3Ident(&name), t.f(&ws)) } _ => self.emit_default(t, f), @@ -1236,7 +1259,7 @@ impl Emit for FunctionName { /// See the [official grammar][]. We only implement part of this. /// /// [official grammar]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls#syntax -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct OverClause { pub over_token: Token, pub paren1: Token, @@ -1247,7 +1270,7 @@ pub struct OverClause { } /// A `PARTITION BY` clause for a window function. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct PartitionBy { pub partition_token: Token, pub by_token: Token, @@ -1255,7 +1278,7 @@ pub struct PartitionBy { } /// An `ORDER BY` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct OrderBy { pub order_token: Token, pub by_token: Token, @@ -1263,28 +1286,28 @@ pub struct OrderBy { } /// An item in an `ORDER BY` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct OrderByItem { pub expression: Expression, pub asc_desc: Option, } /// An `ASC` or `DESC` modifier. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct AscDesc { direction: Token, nulls_clause: Option, } /// A `NULLS FIRST` or `NULLS LAST` modifier. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct NullsClause { nulls_token: Token, first_last_token: Token, } /// A `LIMIT` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct Limit { pub limit_token: Token, pub value: Box, @@ -1295,14 +1318,14 @@ pub struct Limit { /// See the [official grammar][]. We only implement part of this. /// /// [official grammar]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls#def_window_frame -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct WindowFrame { pub rows_token: Token, pub definition: WindowFrameDefinition, } /// A window frame definition. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum WindowFrameDefinition { Start(WindowFrameStart), Between { @@ -1314,7 +1337,7 @@ pub enum WindowFrameDefinition { } /// A window frame start. Keep this simple for now. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum WindowFrameStart { UnboundedPreceding { unbounded_token: Token, @@ -1323,7 +1346,7 @@ pub enum WindowFrameStart { } /// A window frame end. Keep this simple for now. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum WindowFrameEnd { CurrentRow { current_token: Token, @@ -1332,7 +1355,7 @@ pub enum WindowFrameEnd { } /// Data types. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub enum DataType { Bool(Token), Bytes(Token), @@ -1413,14 +1436,14 @@ impl Emit for DataType { } /// A field in a `STRUCT` type. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct StructField { pub name: Option, pub data_type: DataType, } /// An array index expression. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct IndexExpression { pub expression: Box, pub bracket1: Token, @@ -1445,7 +1468,7 @@ impl Emit for IndexExpression { } /// Different ways to index arrays. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub enum IndexOffset { Simple(Box), Offset { @@ -1484,14 +1507,14 @@ impl Emit for IndexOffset { } /// An `AS` alias. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct Alias { pub as_token: Option, pub ident: Identifier, } /// The `FROM` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct FromClause { pub from_token: Token, pub from_item: FromItem, @@ -1499,7 +1522,7 @@ pub struct FromClause { } /// Items which may appear in a `FROM` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum FromItem { /// A table name, optionally with an alias. TableName { @@ -1525,7 +1548,7 @@ pub enum FromItem { } /// A join operation. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum JoinOperation { /// A `JOIN` clause. ConditionJoin { @@ -1543,7 +1566,7 @@ pub enum JoinOperation { } /// The type of a join. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum JoinType { Inner { inner_token: Option, @@ -1563,7 +1586,7 @@ pub enum JoinType { } /// The condition used for a `JOIN`. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum ConditionJoinOperator { Using { using_token: Token, @@ -1578,14 +1601,14 @@ pub enum ConditionJoinOperator { } /// A `WHERE` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct WhereClause { pub where_token: Token, pub expression: Expression, } /// A `GROUP BY` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct GroupBy { pub group_token: Token, pub by_token: Token, @@ -1593,14 +1616,14 @@ pub struct GroupBy { } /// A `HAVING` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct Having { pub having_token: Token, pub expression: Expression, } /// A `QUALIFY` clause. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct Qualify { pub qualify_token: Token, pub expression: Expression, @@ -1620,7 +1643,7 @@ impl Emit for Qualify { } /// A `DELETE FROM` statement. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct DeleteFromStatement { pub delete_token: Token, pub from_token: Token, @@ -1630,7 +1653,7 @@ pub struct DeleteFromStatement { } /// A `INSERT INTO` statement. We only support the `SELECT` version. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct InsertIntoStatement { pub insert_token: Token, pub into_token: Token, @@ -1639,7 +1662,7 @@ pub struct InsertIntoStatement { } /// The data to be inserted into a table. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum InsertedData { /// A `SELECT` statement. Select { query: QueryExpression }, @@ -1651,7 +1674,7 @@ pub enum InsertedData { } /// A row in a `VALUES` clause. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct Row { pub paren1: Token, pub expressions: NodeVec, @@ -1659,7 +1682,7 @@ pub struct Row { } /// A `CREATE TABLE` statement. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct CreateTableStatement { pub create_token: Token, pub or_replace: Option, @@ -1683,7 +1706,7 @@ impl Emit for CreateTableStatement { } /// A `CREATE VIEW` statement. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct CreateViewStatement { pub create_token: Token, pub or_replace: Option, @@ -1708,7 +1731,7 @@ impl Emit for CreateViewStatement { } /// The `OR REPLACE` modifier. -#[derive(Debug, Drive, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)] pub struct OrReplace { pub or_token: Token, pub replace_token: Token, @@ -1724,13 +1747,13 @@ impl Emit for OrReplace { } /// The `TEMPORARY` modifier. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct Temporary { pub temporary_token: Token, } /// The part of a `CREATE TABLE` statement that defines the columns. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub enum CreateTableDefinition { /// ( column_definition [, ...] ) Columns { @@ -1746,14 +1769,14 @@ pub enum CreateTableDefinition { } /// A column definition. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct ColumnDefinition { pub name: Identifier, pub data_type: DataType, } /// A `DROP VIEW` statement. -#[derive(Debug, Drive, Emit, EmitDefault)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)] pub struct DropViewStatement { pub drop_token: Token, pub view_token: Token, diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs index ed41d7a..fb6580a 100644 --- a/src/drivers/mod.rs +++ b/src/drivers/mod.rs @@ -1,11 +1,11 @@ //! Database drivers. -use std::{collections::VecDeque, fmt, str::FromStr}; +use std::{borrow::Cow, collections::VecDeque, fmt, str::FromStr}; use async_trait::async_trait; use crate::{ - ast::{Emit, Target}, + ast::{self, Emit, Target}, errors::{format_err, Error, Result}, }; @@ -67,12 +67,26 @@ pub trait Driver: Send + Sync + 'static { /// Execute a query represented as an AST. This can execute multiple /// statements. - async fn execute_ast(&mut self, ast: &crate::ast::SqlProgram) -> Result<()> { - for statement in &ast.statements { - let sql = statement.emit_to_string(self.target()); + async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> { + let rewritten = self.rewrite_ast(ast)?; + for sql in rewritten.extra_native_sql { self.execute_native_sql_statement(&sql).await?; } - Ok(()) + let sql = rewritten.ast.emit_to_string(self.target()); + self.execute_native_sql_statement(&sql).await + } + + /// Rewrite an AST to convert function names, etc., into versions that can + /// be passed to [`Emitted::emit_to_string`] for this database. This allows + /// us to do less database-specific work in [`Emit::emit`], and more in the + /// database drivers themselves. This can't change lexical syntax, but it + /// can change the structure of the AST. + fn rewrite_ast<'ast>(&self, ast: &'ast ast::SqlProgram) -> Result> { + // Default implementation does nothing. + Ok(RewrittenAst { + extra_native_sql: vec![], + ast: Cow::Borrowed(ast), + }) } /// Drop a table if it exists. @@ -85,6 +99,16 @@ pub trait Driver: Send + Sync + 'static { async fn compare_tables(&mut self, result_table: &str, expected_table: &str) -> Result<()>; } +/// The output of [`Driver::rewrite_ast`]. +pub struct RewrittenAst<'a> { + /// Extra native SQL statements to execute before the AST. Probably + /// temporary UDFs and things like that. + pub extra_native_sql: Vec, + + /// The new AST. + pub ast: Cow<'a, ast::SqlProgram>, +} + /// Extensions to [`Driver`] that are not ["object safe"][safe]. /// /// [safe]: https://doc.rust-lang.org/reference/items/traits.html#object-safety diff --git a/src/drivers/snowflake/mod.rs b/src/drivers/snowflake/mod.rs index 144c424..c82aa79 100644 --- a/src/drivers/snowflake/mod.rs +++ b/src/drivers/snowflake/mod.rs @@ -1,21 +1,24 @@ //! A Snowflake driver. -use std::{env, fmt, str::FromStr}; +use std::{borrow::Cow, env, fmt, str::FromStr}; use arrow_json::writer::record_batches_to_json_rows; use async_trait::async_trait; +use derive_visitor::DriveMut; use once_cell::sync::Lazy; use regex::Regex; use serde_json::Value; use snowflake_api::{QueryResult, SnowflakeApi}; -use tracing::instrument; +use tracing::{debug, instrument}; use crate::{ - ast::Target, + ast::{self, Emit, Target}, errors::{format_err, Context, Error, Result}, }; -use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator}; +use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator, RewrittenAst}; + +mod rename_functions; /// Locator prefix for Snowflake. pub const SNOWFLAKE_LOCATOR_PREFIX: &str = "snowflake:"; @@ -190,6 +193,7 @@ impl Driver for SnowflakeDriver { #[instrument(skip(self, sql), err)] async fn execute_native_sql_statement(&mut self, sql: &str) -> Result<()> { + debug!(%sql, "executing SQL"); self.connection .exec(sql) .await @@ -197,6 +201,40 @@ impl Driver for SnowflakeDriver { Ok(()) } + async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> { + let rewritten = self.rewrite_ast(ast)?; + for sql in rewritten.extra_native_sql { + self.execute_native_sql_statement(&sql).await?; + } + + // We can only execute one statement at a time. + for statement in &rewritten.ast.statements { + let sql = statement.emit_to_string(self.target()); + self.execute_native_sql_statement(&sql).await?; + } + + // Reset session to drop `TEMP` tables and UDFs. + self.connection + .close_session() + .await + .context("could not end Snowflake session") + } + + fn rewrite_ast<'ast>(&self, ast: &'ast ast::SqlProgram) -> Result> { + let mut ast = ast.clone(); + let mut renamer = rename_functions::RenameFunctions::default(); + ast.drive_mut(&mut renamer); + let extra_native_sql = renamer + .udfs + .values() + .map(|udf| udf.to_sql()) + .collect::>(); + Ok(RewrittenAst { + extra_native_sql, + ast: Cow::Owned(ast), + }) + } + #[instrument(skip(self))] async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> { self.execute_native_sql_statement(&format!( diff --git a/src/drivers/snowflake/rename_functions.rs b/src/drivers/snowflake/rename_functions.rs new file mode 100644 index 0000000..a14ff4f --- /dev/null +++ b/src/drivers/snowflake/rename_functions.rs @@ -0,0 +1,65 @@ +//! A simple tree-walker that renames functions to their Snowflake equivalents. + +use std::collections::HashMap; + +use derive_visitor::VisitorMut; + +use crate::ast::{FunctionName, Identifier}; + +// A `phf_map!` of BigQuery function names to Snowflake function names. Use +// this for simple renaming. +static FUNCTION_NAMES: phf::Map<&'static str, &'static str> = phf::phf_map! { + "REGEXP_EXTRACT" => "REGEXP_SUBSTR", + "SHA256" => "SHA2_BINARY", // Second argument defaults to SHA256. +}; + +/// A Snowflake UDF (user-defined function). +pub struct Udf { + pub decl: &'static str, + pub sql: &'static str, +} + +impl Udf { + /// Generate the SQL to create this UDF. + pub fn to_sql(&self) -> String { + format!( + "CREATE OR REPLACE TEMP FUNCTION {} AS $$\n{}\n$$\n", + self.decl, self.sql + ) + } +} + +/// A `phf_map!` of BigQuery UDF names to Snowflake UDFs. Use this when we +/// actually need to create a UDF as a helper function. +static UDFS: phf::Map<&'static str, &'static Udf> = phf::phf_map! { + "TO_HEX" => &Udf { decl: "TO_HEX(b BINARY) RETURNS STRING", sql: "HEX_ENCODE(b, 0)" }, +}; + +#[derive(Default, VisitorMut)] +#[visitor(FunctionName(enter))] +pub struct RenameFunctions { + // UDFs that we need to create, if we haven't already. + pub udfs: HashMap, +} + +impl RenameFunctions { + fn enter_function_name(&mut self, function_name: &mut FunctionName) { + if let FunctionName::Function { function } = function_name { + let name = function.unescaped_bigquery().to_ascii_uppercase(); + if let Some(snowflake_name) = FUNCTION_NAMES.get(&name) { + // Rename the function. + let orig_ident = function_name.function_identifier(); + *function_name = FunctionName::Function { + function: Identifier { + token: orig_ident.token.with_token_str(snowflake_name), + text: snowflake_name.to_string(), + }, + }; + } else if let Some(udf) = UDFS.get(&name) { + // We'll need a UDF, so add it to our list it if isn't already + // there. + self.udfs.insert(name, udf); + } + } + } +} diff --git a/src/drivers/sqlite3/mod.rs b/src/drivers/sqlite3/mod.rs index 9a94f72..29fb35c 100644 --- a/src/drivers/sqlite3/mod.rs +++ b/src/drivers/sqlite3/mod.rs @@ -4,7 +4,10 @@ use std::{fmt, str::FromStr, vec}; use async_rusqlite::Connection; use async_trait::async_trait; -use rusqlite::{functions::FunctionFlags, types}; +use rusqlite::{ + functions::{self, FunctionFlags}, + types, +}; use crate::{ ast::Target, @@ -102,6 +105,10 @@ impl SQLite3Driver { // yet, but it allows us to parse and execute queries that use `UNNEST`. register_unnest(conn).expect("failed to register UNNEST"); + // Install real functions, where we have them. + conn.create_scalar_function("concat", -1, FunctionFlags::SQLITE_UTF8, func_concat) + .expect("failed to create concat function"); + // Install some dummy functions that always return NULL. let dummy_fns = &[ ("array", -1), @@ -159,6 +166,37 @@ impl Driver for SQLite3Driver { } } +/// Concatenate a list of values into a string. Mimics BigQuery's `CONCAT`. +fn func_concat(ctx: &functions::Context<'_>) -> rusqlite::Result { + let mut result = String::new(); + for idx in 0..ctx.len() { + match ctx.get_raw(idx) { + types::ValueRef::Null => return Ok(types::Value::Null), + types::ValueRef::Integer(i) => { + result.push_str(&i.to_string()); + } + types::ValueRef::Real(r) => { + result.push_str(&r.to_string()); + } + types::ValueRef::Text(s) => { + let s = String::from_utf8(s.to_owned()).map_err(|_| { + rusqlite::Error::InvalidFunctionParameterType(idx, types::Type::Text) + })?; + result.push_str(&s); + } + types::ValueRef::Blob(_) => { + // `CONCAT` should also support being called with _all_ `BINARY` + // values, but you can't mix them with other types. + return Err(rusqlite::Error::InvalidFunctionParameterType( + idx, + types::Type::Blob, + )); + } + } + } + Ok(types::Value::Text(result)) +} + #[async_trait] impl DriverImpl for SQLite3Driver { type Type = String; diff --git a/tests/sql/functions/README.md b/tests/sql/functions/README.md index a7a3b2b..6f477c4 100644 --- a/tests/sql/functions/README.md +++ b/tests/sql/functions/README.md @@ -9,22 +9,22 @@ Here is a list of functions that are high priorities to implement. You can generate your own version of this list by running `joinery parse --count-function-calls queries.csv`. -- [ ] REGEXP_REPLACE(_,_,_) -- [ ] REGEXP_EXTRACT(_,_) -- [ ] COALESCE(*) -- [ ] LOWER(_) -- [ ] TO_HEX(_) -- [ ] SHA256(_) +- [x] REGEXP_REPLACE(_,_,_) +- [x] REGEXP_EXTRACT(_,_) +- [x] COALESCE(*) +- [x] LOWER(_) +- [x] TO_HEX(_) +- [x] SHA256(_) - [ ] LENGTH(_) -- [ ] CONCAT(*) -- [ ] TRIM(_) +- [x] CONCAT(*) +- [x] TRIM(_) - [ ] ARRAY_TO_STRING(_,_) - [ ] SUM(_) - [ ] FARM_FINGERPRINT(_) - [ ] ANY_VALUE(_) - [ ] ROW_NUMBER() OVER(..) - [ ] COUNTIF(_) -- [ ] UPPER(_) +- [x] UPPER(_) - [ ] ARRAY_AGG(_) - [ ] DATE_TRUNC(_,_) (special) - [ ] MIN(_) diff --git a/tests/sql/functions/simple/coalesce.sql b/tests/sql/functions/simple/coalesce.sql new file mode 100644 index 0000000..e8d766c --- /dev/null +++ b/tests/sql/functions/simple/coalesce.sql @@ -0,0 +1,24 @@ +-- COALESCE + +CREATE OR REPLACE TABLE __result1 AS +SELECT + -- This is not supported on BigQuery. + -- + -- COALESCE() AS coalesce_empty, + -- + -- This is supported by BigQuery, but not by other databases. We can remove + -- it in the transpiler if absolutely necessary, but it's probably better to + -- make it an error. + -- + -- COALESCE(1) AS coalesce_one, + COALESCE(1, 2) AS coalesce_two, + COALESCE(NULL, 2) AS coalesce_two_null, + COALESCE(NULL, 2, 3) AS coalesce_three_null; + +CREATE OR REPLACE TABLE __expected1 ( + coalesce_two INT64, + coalesce_two_null INT64, + coalesce_three_null INT64, +); +INSERT INTO __expected1 VALUES + (1, 2, 2); diff --git a/tests/sql/functions/simple/regexp.sql b/tests/sql/functions/simple/regexp.sql index 518153f..d9f99cb 100644 --- a/tests/sql/functions/simple/regexp.sql +++ b/tests/sql/functions/simple/regexp.sql @@ -1,4 +1,3 @@ --- pending: snowflake REGEX_EXTRACT needs to be wrapped with REGEXP_SUBSTR -- pending: sqlite3 No regex fuctions -- -- REGEXP_REPLACE, REGEXP_EXTRACT diff --git a/tests/sql/functions/simple/sha256.sql b/tests/sql/functions/simple/sha256.sql new file mode 100644 index 0000000..a4b9d78 --- /dev/null +++ b/tests/sql/functions/simple/sha256.sql @@ -0,0 +1,13 @@ +-- pending: sqlite3 No SHA256 function + +CREATE OR REPLACE TABLE __result1 AS +SELECT + to_hex(sha256('hi')) AS hi, + to_hex(sha256(null)) AS null_arg; + +CREATE OR REPLACE TABLE __expected1 ( + hi STRING, + null_arg STRING, +); +INSERT INTO __expected1 VALUES + ('8f434346648f6b96df89dda901c5176b10a6d83961dd3c1ac88b59b2dc327aa4', NULL); diff --git a/tests/sql/functions/simple/strings.sql b/tests/sql/functions/simple/strings.sql new file mode 100644 index 0000000..521cae1 --- /dev/null +++ b/tests/sql/functions/simple/strings.sql @@ -0,0 +1,21 @@ +-- LOWER, UPPER, TRIM, CONCAT +CREATE OR REPLACE TABLE __result1 AS +SELECT + LOWER('FOO') AS lower, + UPPER('foo') AS upper, + TRIM(' foo ') AS trim, + CONCAT('foo', 'bar') AS concat, + CONCAT('x', 1) concat_casted, + CONCAT('x', NULL) concat_null; + +CREATE OR REPLACE TABLE __expected1 ( + lower STRING, + upper STRING, + trim STRING, + concat STRING, + concat_casted STRING, + concat_null STRING, +); +INSERT INTO __expected1 VALUES + ('foo', 'FOO', 'foo', 'foobar', 'x1', NULL); +