Skip to content

Commit

Permalink
trino: Work around TEMP, fix many other things
Browse files Browse the repository at this point in the history
  • Loading branch information
emk committed Oct 19, 2023
1 parent 79e418b commit 0e92a3a
Show file tree
Hide file tree
Showing 13 changed files with 213 additions and 56 deletions.
21 changes: 20 additions & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use crate::{
bigquery::BigQueryName,
snowflake::SnowflakeString,
sqlite3::{SQLite3Ident, SQLite3String},
trino::TrinoString,
},
errors::{Result, SourceError},
util::is_c_ident,
Expand Down Expand Up @@ -422,6 +423,14 @@ impl<T: Node> NodeVec<T> {
})
}

/// Iterate over just the nodes in this [`NodeVec`], mutably.
pub fn node_iter_mut(&mut self) -> impl Iterator<Item = &mut T> {
self.items.iter_mut().filter_map(|item| match item {
NodeOrSep::Node(node) => Some(node),
NodeOrSep::Sep(_) => None,
})
}

/// Iterate over nodes and separators separately. Used internally for
/// parsing dotted names.
fn into_node_and_sep_iters(self) -> (impl Iterator<Item = T>, impl Iterator<Item = Token>) {
Expand Down Expand Up @@ -986,6 +995,14 @@ impl Emit for Expression {
SQLite3String(s).fmt(f)?;
token.ws_only().emit(t, f)
}
// SQLite3 quotes strings differently.
Expression::Literal {
token,
value: LiteralValue::String(s),
} if t == Target::Trino => {
TrinoString(s).fmt(f)?;
token.ws_only().emit(t, f)
}
Expression::If {
if_token,
condition,
Expand Down Expand Up @@ -1097,7 +1114,9 @@ pub enum CastType {
impl Emit for CastType {
fn emit(&self, t: Target, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CastType::SafeCast { safe_cast_token } if t == Target::Snowflake => {
CastType::SafeCast { safe_cast_token }
if t == Target::Snowflake || t == Target::Trino =>
{
safe_cast_token.with_token_str("TRY_CAST").emit(t, f)
}
// TODO: This isn't strictly right, but it's as close as I know how to
Expand Down
77 changes: 65 additions & 12 deletions src/drivers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
use std::{borrow::Cow, collections::VecDeque, fmt, str::FromStr};

use async_trait::async_trait;
use tracing::debug;

use crate::{
ast::{self, Emit, Target},
errors::{format_err, Error, Result},
transforms::Transform,
transforms::{Transform, TransformExtra},
};

use self::{
Expand Down Expand Up @@ -66,20 +67,72 @@ pub trait Driver: Send + Sync + 'static {
fn target(&self) -> Target;

/// Execute a single SQL statement, using native SQL for this database. This
/// is only guaranteed to work if passed a single statement, although some
/// drivers may support multiple statements. Resources created using `CREATE
/// TEMP TABLE`, etc., may not persist across calls.
/// is only guaranteed to work if passed a single statement, unless
/// [`Driver::supports_multiple_statements`] returns `true`. Resources
/// created using `CREATE TEMP TABLE`, etc., may not persist across calls.
async fn execute_native_sql_statement(&mut self, sql: &str) -> Result<()>;

/// Does this driver support multiple statements in a single call to
/// [`execute_native_sql_statement`]?
fn supports_multiple_statements(&self) -> bool {
false
}

/// Execute a query represented as an AST. This can execute multiple
/// statements.
async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> {
let rewritten = self.rewrite_ast(ast)?;
for sql in rewritten.extra_native_sql {
self.execute_native_sql_statement(&sql).await?;
self.execute_setup_sql(&rewritten).await?;
let result = if self.supports_multiple_statements() {
self.execute_ast_together(&rewritten).await
} else {
self.execute_ast_separately(&rewritten).await
};
self.execute_teardown_sql(&rewritten, result.is_ok())
.await?;
result
}

/// Execute the setup SQL for this AST.
async fn execute_setup_sql(&mut self, rewritten: &RewrittenAst) -> Result<()> {
for sql in &rewritten.extra.native_setup_sql {
self.execute_native_sql_statement(sql).await?;
}
Ok(())
}

/// Execute the AST as a single SQL string.
async fn execute_ast_together(&mut self, rewritten: &RewrittenAst) -> Result<()> {
let sql = rewritten.ast.emit_to_string(self.target());
self.execute_native_sql_statement(&sql).await
self.execute_native_sql_statement(&sql).await?;
Ok(())
}

/// Execute the AST as individual SQL statements.
async fn execute_ast_separately(&mut self, rewritten: &RewrittenAst) -> Result<()> {
for statement in rewritten.ast.statements.node_iter() {
let sql = statement.emit_to_string(self.target());
self.execute_native_sql_statement(&sql).await?;
}
Ok(())
}

/// Execute the teardown SQL for this AST.
async fn execute_teardown_sql(
&mut self,
rewritten: &RewrittenAst,
fail_on_err: bool,
) -> Result<()> {
for sql in &rewritten.extra.native_teardown_sql {
if let Err(err) = self.execute_native_sql_statement(sql).await {
if fail_on_err {
return Err(err);
} else {
debug!(%sql, %err, "Ignoring error from teardown SQL");
}
}
}
Ok(())
}

/// Get a list of transformations that should be applied to the AST before
Expand All @@ -97,17 +150,17 @@ pub trait Driver: Send + Sync + 'static {
let transforms = self.transforms();
if transforms.is_empty() {
return Ok(RewrittenAst {
extra_native_sql: vec![],
extra: TransformExtra::default(),
ast: Cow::Borrowed(ast),
});
} else {
let mut rewritten = ast.clone();
let mut extra_native_sql = vec![];
let mut extra = TransformExtra::default();
for transform in transforms {
extra_native_sql.extend(transform.transform(&mut rewritten)?);
extra.extend(transform.transform(&mut rewritten)?);
}
Ok(RewrittenAst {
extra_native_sql,
extra,
ast: Cow::Owned(rewritten),
})
}
Expand All @@ -127,7 +180,7 @@ pub trait Driver: Send + Sync + 'static {
pub struct RewrittenAst<'a> {
/// Extra native SQL statements to execute before the AST. Probably
/// temporary UDFs and things like that.
pub extra_native_sql: Vec<String>,
pub extra: TransformExtra,

/// The new AST.
pub ast: Cow<'a, ast::SqlProgram>,
Expand Down
21 changes: 3 additions & 18 deletions src/drivers/snowflake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use snowflake_api::{QueryResult, SnowflakeApi};
use tracing::{debug, instrument};

use crate::{
ast::{self, Emit, Target},
ast::Target,
errors::{format_err, Context, Error, Result},
transforms::{self, Transform, Udf},
};
Expand Down Expand Up @@ -223,23 +223,8 @@ impl Driver for SnowflakeDriver {
Ok(())
}

async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> {
let rewritten = self.rewrite_ast(ast)?;
for sql in rewritten.extra_native_sql {
self.execute_native_sql_statement(&sql).await?;
}

// We can only execute one statement at a time.
for statement in rewritten.ast.statements.node_iter() {
let sql = statement.emit_to_string(self.target());
self.execute_native_sql_statement(&sql).await?;
}

// Reset session to drop `TEMP` tables and UDFs.
self.connection
.close_session()
.await
.context("could not end Snowflake session")
fn supports_multiple_statements(&self) -> bool {
false
}

fn transforms(&self) -> Vec<Box<dyn Transform>> {
Expand Down
59 changes: 45 additions & 14 deletions src/drivers/trino/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use regex::Regex;
use tracing::debug;

use crate::{
ast::{self, Emit, Target},
ast::Target,
drivers::sqlite3::SQLite3String,
errors::{format_err, Context, Error, Result},
transforms::{self, Transform, Udf},
Expand All @@ -24,6 +24,7 @@ pub const TRINO_LOCATOR_PREFIX: &str = "trino:";
// this for simple renaming.
static FUNCTION_NAMES: phf::Map<&'static str, &'static str> = phf::phf_map! {
"ARRAY_LENGTH" => "CARDINALITY",
"GENERATE_UUID" => "UUID",
};

/// A `phf_map!` of BigQuery function names to UDFs.
Expand Down Expand Up @@ -146,19 +147,8 @@ impl Driver for TrinoDriver {
Ok(())
}

#[tracing::instrument(skip_all)]
async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> {
let rewritten = self.rewrite_ast(ast)?;
for sql in rewritten.extra_native_sql {
self.execute_native_sql_statement(&sql).await?;
}

// We can only execute one statement at a time.
for statement in rewritten.ast.statements.node_iter() {
let sql = statement.emit_to_string(self.target());
self.execute_native_sql_statement(&sql).await?;
}
Ok(())
fn supports_multiple_statements(&self) -> bool {
false
}

fn transforms(&self) -> Vec<Box<dyn Transform>> {
Expand All @@ -169,6 +159,11 @@ impl Driver for TrinoDriver {
&UDFS,
&format_udf,
)),
Box::new(transforms::CleanUpTempManually {
format_name: &|table_name| {
SQLite3Ident(&table_name.unescaped_bigquery()).to_string()
},
}),
]
}

Expand Down Expand Up @@ -254,6 +249,42 @@ impl DriverImpl for TrinoDriver {
}
}

/// Quote `s` for Trino, surrounding it with `'` and escaping special
/// characters as needed.
fn trino_quote_fmt(s: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if s.chars().all(|c| c.is_ascii_graphic() || c == ' ') {
write!(f, "'")?;
for c in s.chars() {
match c {
'\'' => write!(f, "''")?,
_ => write!(f, "{}", c)?,
}
}
write!(f, "'")
} else {
write!(f, "U&'")?;
for c in s.chars() {
match c {
'\'' => write!(f, "''")?,
'\\' => write!(f, "\\\\")?,
_ if c.is_ascii_graphic() || c == ' ' => write!(f, "{}", c)?,
_ if c as u32 <= 0xFFFF => write!(f, "\\{:04x}", c as u32)?,
_ => write!(f, "\\+{:06x}", c as u32)?,
}
}
write!(f, "'")
}
}

/// Formatting wrapper for strings quoted with single quotes.
pub struct TrinoString<'a>(pub &'a str);

impl fmt::Display for TrinoString<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
trino_quote_fmt(self.0, f)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
38 changes: 38 additions & 0 deletions src/transforms/clean_up_temp_manually.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use crate::{ast, errors::Result};

use super::{Transform, TransformExtra};

/// Transform `OR REPLACE` to the equivalent `DROP IF EXISTS`.
pub struct CleanUpTempManually {
/// Format a table or view name.
pub format_name: &'static dyn Fn(&ast::TableName) -> String,
}

impl Transform for CleanUpTempManually {
fn transform(self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
let mut native_teardown_sql = vec![];

#[allow(clippy::single_match)]
for statement in sql_program.statements.node_iter_mut() {
match statement {
ast::Statement::CreateTable(ast::CreateTableStatement {
temporary: temporary @ Some(_),
table_name,
..
}) => {
*temporary = None;
native_teardown_sql.push(format!(
"DROP TABLE IF EXISTS {}",
(self.format_name)(table_name)
));
}
_ => {}
}
}

Ok(TransformExtra {
native_setup_sql: vec![],
native_teardown_sql,
})
}
}
23 changes: 22 additions & 1 deletion src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
use crate::{ast, errors::Result};

pub use self::{
clean_up_temp_manually::CleanUpTempManually,
or_replace_to_drop_if_exists::OrReplaceToDropIfExists,
rename_functions::{RenameFunctions, Udf},
};

mod clean_up_temp_manually;
mod or_replace_to_drop_if_exists;
mod rename_functions;

Expand All @@ -27,5 +29,24 @@ pub trait Transform {
/// A transform should only be used once, as it may modify itself in the
/// process of transforming the AST. To enforce this, the transform takes
/// `self: Box<Self>` rather than `&mut self`.
fn transform(self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<Vec<String>>;
fn transform(self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra>;
}

/// Extra SQL returned by a [`Transform`].
#[derive(Debug, Default)]
pub struct TransformExtra {
/// Individual statements that should be run before the transformed program.
pub native_setup_sql: Vec<String>,

/// Individual statements that should be run after the transformed program,
/// even if it fails. These may individually fail.
pub native_teardown_sql: Vec<String>,
}

impl TransformExtra {
/// Merge in another `TransformExtra`.
pub fn extend(&mut self, other: TransformExtra) {
self.native_setup_sql.extend(other.native_setup_sql);
self.native_teardown_sql.extend(other.native_teardown_sql);
}
}
Loading

0 comments on commit 0e92a3a

Please sign in to comment.