From 239390531eef7bca263818ac19a0f1859f0aa0d1 Mon Sep 17 00:00:00 2001 From: Eric Kidd Date: Thu, 19 Oct 2023 17:55:33 -0400 Subject: [PATCH] Rename Sqlite3 escapers to Ansi This was how we were already using them. --- src/ast.rs | 24 ++++++------------ src/drivers/snowflake/mod.rs | 7 +++--- src/drivers/sqlite3/mod.rs | 47 +++++------------------------------- src/drivers/trino/mod.rs | 20 +++++++-------- src/util.rs | 35 +++++++++++++++++++++++++++ 5 files changed, 62 insertions(+), 71 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index a2413c3..9f6bf48 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -32,14 +32,9 @@ use derive_visitor::{Drive, DriveMut}; use joinery_macros::{Emit, EmitDefault}; use crate::{ - drivers::{ - bigquery::BigQueryName, - snowflake::SnowflakeString, - sqlite3::{SQLite3Ident, SQLite3String}, - trino::TrinoString, - }, + drivers::{bigquery::BigQueryName, snowflake::SnowflakeString, trino::TrinoString}, errors::{Result, SourceError}, - util::is_c_ident, + util::{is_c_ident, AnsiIdent, AnsiString}, }; /// None of these keywords should ever be matched as a bare identifier. We use @@ -552,12 +547,9 @@ impl Emit for Identifier { // Snowflake and SQLite3 use double quoted identifiers and // escape quotes by doubling them. Neither allows backslash // escapes here, though Snowflake does in strings. - Target::Snowflake | Target::SQLite3 | Target::Trino => write!( - f, - "{}{}", - SQLite3Ident(&self.text), - t.f(&self.token.ws_only()) - ), + Target::Snowflake | Target::SQLite3 | Target::Trino => { + write!(f, "{}{}", AnsiIdent(&self.text), t.f(&self.token.ws_only())) + } } } else { write!(f, "{}{}", self.text, t.f(&self.token.ws_only())) @@ -613,7 +605,7 @@ impl Emit for TableName { | TableName::DatasetTable { table, .. } | TableName::Table { table } => table.token.ws_only(), }; - write!(f, "{}{}", SQLite3Ident(&name), t.f(&ws)) + write!(f, "{}{}", AnsiIdent(&name), t.f(&ws)) } _ => self.emit_default(t, f), } @@ -992,7 +984,7 @@ impl Emit for Expression { token, value: LiteralValue::String(s), } if t == Target::SQLite3 => { - SQLite3String(s).fmt(f)?; + AnsiString(s).fmt(f)?; token.ws_only().emit(t, f) } // SQLite3 quotes strings differently. @@ -1396,7 +1388,7 @@ impl Emit for FunctionName { Target::SQLite3 => { let name = self.unescaped_bigquery(); let ws = self.function_identifier().token.ws_only(); - write!(f, "{}{}", SQLite3Ident(&name), t.f(&ws)) + write!(f, "{}{}", AnsiIdent(&name), t.f(&ws)) } _ => self.emit_default(t, f), } diff --git a/src/drivers/snowflake/mod.rs b/src/drivers/snowflake/mod.rs index a32cf1f..826de85 100644 --- a/src/drivers/snowflake/mod.rs +++ b/src/drivers/snowflake/mod.rs @@ -14,9 +14,10 @@ use crate::{ ast::Target, errors::{format_err, Context, Error, Result}, transforms::{self, Transform, Udf}, + util::AnsiIdent, }; -use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator}; +use super::{Column, Driver, DriverImpl, Locator}; /// Locator prefix for Snowflake. pub const SNOWFLAKE_LOCATOR_PREFIX: &str = "snowflake:"; @@ -239,7 +240,7 @@ impl Driver for SnowflakeDriver { async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> { self.execute_native_sql_statement(&format!( "DROP TABLE IF EXISTS {}", - SQLite3Ident(table_name) + AnsiIdent(table_name) )) .await } @@ -291,7 +292,7 @@ impl DriverImpl for SnowflakeDriver { ) -> Result { let column_list = columns .iter() - .map(|c| SQLite3Ident(&c.name).to_string()) + .map(|c| AnsiIdent(&c.name).to_string()) .collect::>() .join(", "); // TODO: Again, quoting the table name fails. diff --git a/src/drivers/sqlite3/mod.rs b/src/drivers/sqlite3/mod.rs index 29fb35c..8f4e4ea 100644 --- a/src/drivers/sqlite3/mod.rs +++ b/src/drivers/sqlite3/mod.rs @@ -12,6 +12,7 @@ use rusqlite::{ use crate::{ ast::Target, errors::{format_err, Context, Error, Result}, + util::{AnsiIdent, AnsiString}, }; use self::unnest::register_unnest; @@ -157,7 +158,7 @@ impl Driver for SQLite3Driver { } async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> { - let sql = format!("DROP TABLE IF EXISTS {}", SQLite3Ident(table_name)); + let sql = format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name)); self.execute_native_sql_statement(&sql).await } @@ -204,7 +205,7 @@ impl DriverImpl for SQLite3Driver { type Rows = Box>> + Send + Sync>; async fn table_columns(&mut self, table_name: &str) -> Result>> { - let sql = format!("PRAGMA table_info({})", SQLite3Ident(table_name)); + let sql = format!("PRAGMA table_info({})", AnsiIdent(table_name)); self.conn .call(move |conn| { let mut stmt = conn.prepare(&sql).context("failed to prepare SQL")?; @@ -229,13 +230,13 @@ impl DriverImpl for SQLite3Driver { ) -> Result { let column_list = columns .iter() - .map(|c| sqlite3_quote_ident(&c.name)) + .map(|c| AnsiIdent(&c.name).to_string()) .collect::>() .join(", "); let sql = format!( "SELECT {} FROM {} ORDER BY {}", column_list, - SQLite3Ident(table_name), + AnsiIdent(table_name), column_list ); let columns = columns.to_vec(); @@ -275,44 +276,8 @@ impl std::fmt::Display for Value { rusqlite::types::Value::Null => write!(f, "NULL"), rusqlite::types::Value::Integer(i) => write!(f, "{}", i), rusqlite::types::Value::Real(r) => write!(f, "{}", r), - rusqlite::types::Value::Text(s) => write!(f, "{}", SQLite3String(s)), + rusqlite::types::Value::Text(s) => write!(f, "{}", AnsiString(s)), rusqlite::types::Value::Blob(b) => write!(f, "{:?}", b), } } } - -/// Escape an identifier for use in a SQLite3 query. -fn sqlite3_quote_ident(s: &str) -> String { - format!("{}", SQLite3Ident(s)) -} - -/// Format a single- or double-quoted string for use in a SQLite3 query. SQLite3 -/// does not support backslash escapes. -fn sqlite3_quote_fmt(s: &str, quote_char: char, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", quote_char)?; - for c in s.chars() { - if c == quote_char { - write!(f, "{}", quote_char)?; - } - write!(f, "{}", c)?; - } - write!(f, "{}", quote_char) -} - -/// Formatting wrapper for single-quoted strings. -pub struct SQLite3String<'a>(pub &'a str); - -impl fmt::Display for SQLite3String<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - sqlite3_quote_fmt(self.0, '\'', f) - } -} - -/// Formatting wrapper for double-quoted identifiers. -pub struct SQLite3Ident<'a>(pub &'a str); - -impl fmt::Display for SQLite3Ident<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - sqlite3_quote_fmt(self.0, '"', f) - } -} diff --git a/src/drivers/trino/mod.rs b/src/drivers/trino/mod.rs index 7851d34..4074d15 100644 --- a/src/drivers/trino/mod.rs +++ b/src/drivers/trino/mod.rs @@ -10,12 +10,12 @@ use tracing::debug; use crate::{ ast::Target, - drivers::sqlite3::SQLite3String, errors::{format_err, Context, Error, Result}, transforms::{self, Transform, Udf}, + util::AnsiIdent, }; -use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator}; +use super::{Column, Driver, DriverImpl, Locator}; /// Our locator prefix. pub const TRINO_LOCATOR_PREFIX: &str = "trino:"; @@ -162,9 +162,7 @@ impl Driver for TrinoDriver { &format_udf, )), Box::new(transforms::CleanUpTempManually { - format_name: &|table_name| { - SQLite3Ident(&table_name.unescaped_bigquery()).to_string() - }, + format_name: &|table_name| AnsiIdent(&table_name.unescaped_bigquery()).to_string(), }), ] } @@ -172,7 +170,7 @@ impl Driver for TrinoDriver { #[tracing::instrument(skip(self))] async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> { self.client - .execute(format!("DROP TABLE IF EXISTS {}", SQLite3Ident(table_name))) + .execute(format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name))) .await .map_err(abbreviate_trino_error) .with_context(|| format!("Failed to drop table: {}", table_name))?; @@ -205,9 +203,9 @@ impl DriverImpl for TrinoDriver { FROM information_schema.columns WHERE table_catalog = {} AND table_schema = {} AND table_name = {}", // TODO: Replace with real string escapes. - SQLite3String(&self.catalog), - SQLite3String(&self.schema), - SQLite3String(table_name) + TrinoString(&self.catalog), + TrinoString(&self.schema), + TrinoString(table_name) ); Ok(self .client @@ -232,13 +230,13 @@ impl DriverImpl for TrinoDriver { ) -> Result { let cols_sql = columns .iter() - .map(|c| SQLite3Ident(&c.name).to_string()) + .map(|c| AnsiIdent(&c.name).to_string()) .collect::>() .join(", "); let sql = format!( "SELECT {} FROM {} ORDER BY {}", cols_sql, - SQLite3Ident(table_name), + AnsiIdent(table_name), cols_sql ); let rows = self diff --git a/src/util.rs b/src/util.rs index 4d5e327..34bd777 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,5 +1,7 @@ //! The inevitable grab-bag of utility functions. +use std::fmt; + /// Is `s` a valid C identifier? pub fn is_c_ident(s: &str) -> bool { let mut chars = s.chars(); @@ -9,3 +11,36 @@ pub fn is_c_ident(s: &str) -> bool { _ => chars.all(|c| c.is_ascii_alphanumeric() || c == '_'), } } + +/// Format a single- or double-quoted string for use in a SQLite3 query. SQLite3 +/// does not support backslash escapes. +fn ansi_quote_fmt(s: &str, quote_char: char, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", quote_char)?; + for c in s.chars() { + if c == quote_char { + write!(f, "{}", quote_char)?; + } + write!(f, "{}", c)?; + } + write!(f, "{}", quote_char) +} + +/// Formatting wrapper for single-quoted strings. It's actually pretty rare for +/// databases to support just this format with no backslash escapes, so please +/// double-check things like `'\\'`` and `'\''` before using this. +pub struct AnsiString<'a>(pub &'a str); + +impl fmt::Display for AnsiString<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ansi_quote_fmt(self.0, '\'', f) + } +} + +/// Formatting wrapper for double-quoted identifiers. +pub struct AnsiIdent<'a>(pub &'a str); + +impl fmt::Display for AnsiIdent<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ansi_quote_fmt(self.0, '"', f) + } +}