From 0e92a3a408da794402eabd20d473142ea3fce5e3 Mon Sep 17 00:00:00 2001 From: Eric Kidd Date: Wed, 18 Oct 2023 22:16:24 -0400 Subject: [PATCH] trino: Work around TEMP, fix many other things --- src/ast.rs | 21 ++++- src/drivers/mod.rs | 77 ++++++++++++++++--- src/drivers/snowflake/mod.rs | 21 +---- src/drivers/trino/mod.rs | 59 ++++++++++---- src/transforms/clean_up_temp_manually.rs | 38 +++++++++ src/transforms/mod.rs | 23 +++++- .../or_replace_to_drop_if_exists.rs | 6 +- src/transforms/rename_functions.rs | 9 ++- tests/sql/data_types/literal_scalars.sql | 9 ++- .../functions/aggregate/approx_quantiles.sql | 1 + .../sql/functions/simple/farm_fingerprint.sql | 1 + tests/sql/pending/date_functions.sql | 3 +- tests/sql/pending/structs.sql | 1 + 13 files changed, 213 insertions(+), 56 deletions(-) create mode 100644 src/transforms/clean_up_temp_manually.rs diff --git a/src/ast.rs b/src/ast.rs index 056d84b..e667f03 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -36,6 +36,7 @@ use crate::{ bigquery::BigQueryName, snowflake::SnowflakeString, sqlite3::{SQLite3Ident, SQLite3String}, + trino::TrinoString, }, errors::{Result, SourceError}, util::is_c_ident, @@ -422,6 +423,14 @@ impl NodeVec { }) } + /// Iterate over just the nodes in this [`NodeVec`], mutably. + pub fn node_iter_mut(&mut self) -> impl Iterator { + self.items.iter_mut().filter_map(|item| match item { + NodeOrSep::Node(node) => Some(node), + NodeOrSep::Sep(_) => None, + }) + } + /// Iterate over nodes and separators separately. Used internally for /// parsing dotted names. fn into_node_and_sep_iters(self) -> (impl Iterator, impl Iterator) { @@ -986,6 +995,14 @@ impl Emit for Expression { SQLite3String(s).fmt(f)?; token.ws_only().emit(t, f) } + // Trino quotes strings differently. 
+ Expression::Literal { + token, + value: LiteralValue::String(s), + } if t == Target::Trino => { + TrinoString(s).fmt(f)?; + token.ws_only().emit(t, f) + } Expression::If { if_token, condition, @@ -1097,7 +1114,9 @@ pub enum CastType { impl Emit for CastType { fn emit(&self, t: Target, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - CastType::SafeCast { safe_cast_token } if t == Target::Snowflake => { + CastType::SafeCast { safe_cast_token } + if t == Target::Snowflake || t == Target::Trino => + { safe_cast_token.with_token_str("TRY_CAST").emit(t, f) } // TODO: This isn't strictly right, but it's as close as I know how to diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs index 36fbc6a..60b5b3e 100644 --- a/src/drivers/mod.rs +++ b/src/drivers/mod.rs @@ -3,11 +3,12 @@ use std::{borrow::Cow, collections::VecDeque, fmt, str::FromStr}; use async_trait::async_trait; +use tracing::debug; use crate::{ ast::{self, Emit, Target}, errors::{format_err, Error, Result}, - transforms::Transform, + transforms::{Transform, TransformExtra}, }; use self::{ @@ -66,20 +67,72 @@ pub trait Driver: Send + Sync + 'static { fn target(&self) -> Target; /// Execute a single SQL statement, using native SQL for this database. This - /// is only guaranteed to work if passed a single statement, although some - /// drivers may support multiple statements. Resources created using `CREATE - /// TEMP TABLE`, etc., may not persist across calls. + /// is only guaranteed to work if passed a single statement, unless + /// [`Driver::supports_multiple_statements`] returns `true`. Resources + /// created using `CREATE TEMP TABLE`, etc., may not persist across calls. async fn execute_native_sql_statement(&mut self, sql: &str) -> Result<()>; + /// Does this driver support multiple statements in a single call to + /// [`execute_native_sql_statement`]? + fn supports_multiple_statements(&self) -> bool { + false + } + /// Execute a query represented as an AST. 
This can execute multiple /// statements. async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> { let rewritten = self.rewrite_ast(ast)?; - for sql in rewritten.extra_native_sql { - self.execute_native_sql_statement(&sql).await?; + self.execute_setup_sql(&rewritten).await?; + let result = if self.supports_multiple_statements() { + self.execute_ast_together(&rewritten).await + } else { + self.execute_ast_separately(&rewritten).await + }; + self.execute_teardown_sql(&rewritten, result.is_ok()) + .await?; + result + } + + /// Execute the setup SQL for this AST. + async fn execute_setup_sql(&mut self, rewritten: &RewrittenAst) -> Result<()> { + for sql in &rewritten.extra.native_setup_sql { + self.execute_native_sql_statement(sql).await?; } + Ok(()) + } + + /// Execute the AST as a single SQL string. + async fn execute_ast_together(&mut self, rewritten: &RewrittenAst) -> Result<()> { let sql = rewritten.ast.emit_to_string(self.target()); - self.execute_native_sql_statement(&sql).await + self.execute_native_sql_statement(&sql).await?; + Ok(()) + } + + /// Execute the AST as individual SQL statements. + async fn execute_ast_separately(&mut self, rewritten: &RewrittenAst) -> Result<()> { + for statement in rewritten.ast.statements.node_iter() { + let sql = statement.emit_to_string(self.target()); + self.execute_native_sql_statement(&sql).await?; + } + Ok(()) + } + + /// Execute the teardown SQL for this AST. 
+ async fn execute_teardown_sql( + &mut self, + rewritten: &RewrittenAst, + fail_on_err: bool, + ) -> Result<()> { + for sql in &rewritten.extra.native_teardown_sql { + if let Err(err) = self.execute_native_sql_statement(sql).await { + if fail_on_err { + return Err(err); + } else { + debug!(%sql, %err, "Ignoring error from teardown SQL"); + } + } + } + Ok(()) } /// Get a list of transformations that should be applied to the AST before @@ -97,17 +150,17 @@ pub trait Driver: Send + Sync + 'static { let transforms = self.transforms(); if transforms.is_empty() { return Ok(RewrittenAst { - extra_native_sql: vec![], + extra: TransformExtra::default(), ast: Cow::Borrowed(ast), }); } else { let mut rewritten = ast.clone(); - let mut extra_native_sql = vec![]; + let mut extra = TransformExtra::default(); for transform in transforms { - extra_native_sql.extend(transform.transform(&mut rewritten)?); + extra.extend(transform.transform(&mut rewritten)?); } Ok(RewrittenAst { - extra_native_sql, + extra, ast: Cow::Owned(rewritten), }) } @@ -127,7 +180,7 @@ pub trait Driver: Send + Sync + 'static { pub struct RewrittenAst<'a> { /// Extra native SQL statements to execute before the AST. Probably /// temporary UDFs and things like that. - pub extra_native_sql: Vec, + pub extra: TransformExtra, /// The new AST. 
pub ast: Cow<'a, ast::SqlProgram>, diff --git a/src/drivers/snowflake/mod.rs b/src/drivers/snowflake/mod.rs index b2ebf9e..a32cf1f 100644 --- a/src/drivers/snowflake/mod.rs +++ b/src/drivers/snowflake/mod.rs @@ -11,7 +11,7 @@ use snowflake_api::{QueryResult, SnowflakeApi}; use tracing::{debug, instrument}; use crate::{ - ast::{self, Emit, Target}, + ast::Target, errors::{format_err, Context, Error, Result}, transforms::{self, Transform, Udf}, }; @@ -223,23 +223,8 @@ impl Driver for SnowflakeDriver { Ok(()) } - async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> { - let rewritten = self.rewrite_ast(ast)?; - for sql in rewritten.extra_native_sql { - self.execute_native_sql_statement(&sql).await?; - } - - // We can only execute one statement at a time. - for statement in rewritten.ast.statements.node_iter() { - let sql = statement.emit_to_string(self.target()); - self.execute_native_sql_statement(&sql).await?; - } - - // Reset session to drop `TEMP` tables and UDFs. - self.connection - .close_session() - .await - .context("could not end Snowflake session") + fn supports_multiple_statements(&self) -> bool { + false } fn transforms(&self) -> Vec> { diff --git a/src/drivers/trino/mod.rs b/src/drivers/trino/mod.rs index d6f609b..f0a76d6 100644 --- a/src/drivers/trino/mod.rs +++ b/src/drivers/trino/mod.rs @@ -9,7 +9,7 @@ use regex::Regex; use tracing::debug; use crate::{ - ast::{self, Emit, Target}, + ast::Target, drivers::sqlite3::SQLite3String, errors::{format_err, Context, Error, Result}, transforms::{self, Transform, Udf}, @@ -24,6 +24,7 @@ pub const TRINO_LOCATOR_PREFIX: &str = "trino:"; // this for simple renaming. static FUNCTION_NAMES: phf::Map<&'static str, &'static str> = phf::phf_map! { "ARRAY_LENGTH" => "CARDINALITY", + "GENERATE_UUID" => "UUID", }; /// A `phf_map!` of BigQuery function names to UDFs. 
@@ -146,19 +147,8 @@ impl Driver for TrinoDriver { Ok(()) } - #[tracing::instrument(skip_all)] - async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> { - let rewritten = self.rewrite_ast(ast)?; - for sql in rewritten.extra_native_sql { - self.execute_native_sql_statement(&sql).await?; - } - - // We can only execute one statement at a time. - for statement in rewritten.ast.statements.node_iter() { - let sql = statement.emit_to_string(self.target()); - self.execute_native_sql_statement(&sql).await?; - } - Ok(()) + fn supports_multiple_statements(&self) -> bool { + false } fn transforms(&self) -> Vec> { @@ -169,6 +159,11 @@ impl Driver for TrinoDriver { &UDFS, &format_udf, )), + Box::new(transforms::CleanUpTempManually { + format_name: &|table_name| { + SQLite3Ident(&table_name.unescaped_bigquery()).to_string() + }, + }), ] } @@ -254,6 +249,42 @@ impl DriverImpl for TrinoDriver { } } +/// Quote `s` for Trino, surrounding it with `'` and escaping special +/// characters as needed. +fn trino_quote_fmt(s: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if s.chars().all(|c| c.is_ascii_graphic() || c == ' ') { + write!(f, "'")?; + for c in s.chars() { + match c { + '\'' => write!(f, "''")?, + _ => write!(f, "{}", c)?, + } + } + write!(f, "'") + } else { + write!(f, "U&'")?; + for c in s.chars() { + match c { + '\'' => write!(f, "''")?, + '\\' => write!(f, "\\\\")?, + _ if c.is_ascii_graphic() || c == ' ' => write!(f, "{}", c)?, + _ if c as u32 <= 0xFFFF => write!(f, "\\{:04x}", c as u32)?, + _ => write!(f, "\\+{:06x}", c as u32)?, + } + } + write!(f, "'") + } +} + +/// Formatting wrapper for strings quoted with single quotes. 
+pub struct TrinoString<'a>(pub &'a str); + +impl fmt::Display for TrinoString<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + trino_quote_fmt(self.0, f) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/transforms/clean_up_temp_manually.rs b/src/transforms/clean_up_temp_manually.rs new file mode 100644 index 0000000..1e65bc4 --- /dev/null +++ b/src/transforms/clean_up_temp_manually.rs @@ -0,0 +1,38 @@ +use crate::{ast, errors::Result}; + +use super::{Transform, TransformExtra}; + +/// Strip `TEMP` from `CREATE TABLE` statements and queue a teardown `DROP TABLE IF EXISTS` for each. +pub struct CleanUpTempManually { + /// Format a table or view name. + pub format_name: &'static dyn Fn(&ast::TableName) -> String, +} + +impl Transform for CleanUpTempManually { + fn transform(self: Box, sql_program: &mut ast::SqlProgram) -> Result { + let mut native_teardown_sql = vec![]; + + #[allow(clippy::single_match)] + for statement in sql_program.statements.node_iter_mut() { + match statement { + ast::Statement::CreateTable(ast::CreateTableStatement { + temporary: temporary @ Some(_), + table_name, + .. + }) => { + *temporary = None; + native_teardown_sql.push(format!( + "DROP TABLE IF EXISTS {}", + (self.format_name)(table_name) + )); + } + _ => {} + } + } + + Ok(TransformExtra { + native_setup_sql: vec![], + native_teardown_sql, + }) + } +} diff --git a/src/transforms/mod.rs b/src/transforms/mod.rs index fadc932..6f65385 100644 --- a/src/transforms/mod.rs +++ b/src/transforms/mod.rs @@ -8,10 +8,12 @@ use crate::{ast, errors::Result}; pub use self::{ + clean_up_temp_manually::CleanUpTempManually, or_replace_to_drop_if_exists::OrReplaceToDropIfExists, rename_functions::{RenameFunctions, Udf}, }; +mod clean_up_temp_manually; mod or_replace_to_drop_if_exists; mod rename_functions; @@ -27,5 +29,24 @@ pub trait Transform { /// A transform should only be used once, as it may modify itself in the /// process of transforming the AST. 
To enforce this, the transform takes /// `self: Box` rather than `&mut self`. - fn transform(self: Box, sql_program: &mut ast::SqlProgram) -> Result>; + fn transform(self: Box, sql_program: &mut ast::SqlProgram) -> Result; +} + +/// Extra SQL returned by a [`Transform`]. +#[derive(Debug, Default)] +pub struct TransformExtra { + /// Individual statements that should be run before the transformed program. + pub native_setup_sql: Vec, + + /// Individual statements that should be run after the transformed program, + /// even if it fails. These may individually fail. + pub native_teardown_sql: Vec, +} + +impl TransformExtra { + /// Merge in another `TransformExtra`. + pub fn extend(&mut self, other: TransformExtra) { + self.native_setup_sql.extend(other.native_setup_sql); + self.native_teardown_sql.extend(other.native_teardown_sql); + } } diff --git a/src/transforms/or_replace_to_drop_if_exists.rs b/src/transforms/or_replace_to_drop_if_exists.rs index 49532de..14b6d57 100644 --- a/src/transforms/or_replace_to_drop_if_exists.rs +++ b/src/transforms/or_replace_to_drop_if_exists.rs @@ -5,13 +5,13 @@ use crate::{ errors::Result, }; -use super::Transform; +use super::{Transform, TransformExtra}; /// Transform `OR REPLACE` to the equivalent `DROP IF EXISTS`. 
pub struct OrReplaceToDropIfExists; impl Transform for OrReplaceToDropIfExists { - fn transform(self: Box, sql_program: &mut ast::SqlProgram) -> Result> { + fn transform(self: Box, sql_program: &mut ast::SqlProgram) -> Result { let old_statements = sql_program.statements.take(); for mut node_or_sep in old_statements { match &mut node_or_sep { @@ -63,7 +63,7 @@ impl Transform for OrReplaceToDropIfExists { } sql_program.statements.push_node_or_sep(node_or_sep); } - Ok(vec![]) + Ok(TransformExtra::default()) } } diff --git a/src/transforms/rename_functions.rs b/src/transforms/rename_functions.rs index a2c65df..923b51c 100644 --- a/src/transforms/rename_functions.rs +++ b/src/transforms/rename_functions.rs @@ -9,7 +9,7 @@ use crate::{ errors::Result, }; -use super::Transform; +use super::{Transform, TransformExtra}; /// A Snowflake UDF (user-defined function). pub struct Udf { @@ -70,7 +70,7 @@ impl RenameFunctions { } impl Transform for RenameFunctions { - fn transform(mut self: Box, sql_program: &mut ast::SqlProgram) -> Result> { + fn transform(mut self: Box, sql_program: &mut ast::SqlProgram) -> Result { // Walk the AST, renaming functions and collecting UDFs. sql_program.drive_mut(self.as_mut()); @@ -79,6 +79,9 @@ impl Transform for RenameFunctions { for udf in self.udfs.values() { extra_sql.push((self.format_udf)(udf)); } - Ok(extra_sql) + Ok(TransformExtra { + native_setup_sql: extra_sql, + native_teardown_sql: vec![], + }) } } diff --git a/tests/sql/data_types/literal_scalars.sql b/tests/sql/data_types/literal_scalars.sql index 16105b3..a5bd9ea 100644 --- a/tests/sql/data_types/literal_scalars.sql +++ b/tests/sql/data_types/literal_scalars.sql @@ -1,10 +1,13 @@ --- pending: snowflake Test harness Arrow library reads 1.5 as 15 CREATE OR REPLACE TABLE __result1 AS SELECT - NULL AS n, + -- Trino doesn't like untyped NULLs. 
+ CAST(NULL AS INT64) AS n, 1 AS i, - 1.5 AS f, + -- TODO: Both our Trino and our Snowflake drivers are unable to correctly + -- read columns with types inferred from float literals. We may want to add + -- a transformation that automatically casts float literals to FLOAT64. + CAST(1.5 AS FLOAT64) AS f, 'Hello, world!' AS s, '\a\b\f\n\r\t\v\\\?\'\"\`\101\x41\X41\u0041\U00000041' AS escapes, r'\a' AS raw1, diff --git a/tests/sql/functions/aggregate/approx_quantiles.sql b/tests/sql/functions/aggregate/approx_quantiles.sql index 377abea..bb99752 100644 --- a/tests/sql/functions/aggregate/approx_quantiles.sql +++ b/tests/sql/functions/aggregate/approx_quantiles.sql @@ -1,5 +1,6 @@ -- pending: snowflake Use APPROX_PERCENTILE instead of APPROX_QUANTILES (complicated) -- pending: sqlite3 No APPROX_QUANTILES function +-- pending: trino Use APPROX_PERCENTILE instead of APPROX_QUANTILES (complicated) CREATE TEMP TABLE quantile_data (x INT64); INSERT INTO quantile_data VALUES (1), (2), (3), (4), (5); diff --git a/tests/sql/functions/simple/farm_fingerprint.sql b/tests/sql/functions/simple/farm_fingerprint.sql index c3d195a..12bf39c 100644 --- a/tests/sql/functions/simple/farm_fingerprint.sql +++ b/tests/sql/functions/simple/farm_fingerprint.sql @@ -1,5 +1,6 @@ -- pending: snowflake FARM_FINGERPRINT only exists on BigQuery -- pending: sqlite3 FARM_FINGERPRINT only exists on BigQuery +-- pending: trino FARM_FINGERPRINT only exists on BigQuery CREATE OR REPLACE TABLE __result1 AS SELECT FARM_FINGERPRINT('foo') AS str_farm, diff --git a/tests/sql/pending/date_functions.sql b/tests/sql/pending/date_functions.sql index 09b76ac..d55956e 100644 --- a/tests/sql/pending/date_functions.sql +++ b/tests/sql/pending/date_functions.sql @@ -1,2 +1,3 @@ -- pending: snowflake Lots of work but low risk. --- pending: sqlite3 Lots of work but low risk. \ No newline at end of file +-- pending: sqlite3 Lots of work but low risk. +-- pending: trino Lots of work but probably low risk. 
\ No newline at end of file diff --git a/tests/sql/pending/structs.sql b/tests/sql/pending/structs.sql index bb3ef48..9b7e10e 100644 --- a/tests/sql/pending/structs.sql +++ b/tests/sql/pending/structs.sql @@ -1,2 +1,3 @@ -- pending: snowflake Need to emulate using OBJECT. May be challenging. -- pending: sqlite3 Need to build structs from scratch. +-- pending: trino Can probably use ROW \ No newline at end of file