Skip to content

Commit

Permalink
Rename Sqlite3 escapers to Ansi
Browse files Browse the repository at this point in the history
This was how we were already using them.
  • Loading branch information
emk committed Oct 19, 2023
1 parent 1e433ff commit 2393905
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 71 deletions.
24 changes: 8 additions & 16 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,9 @@ use derive_visitor::{Drive, DriveMut};
use joinery_macros::{Emit, EmitDefault};

use crate::{
drivers::{
bigquery::BigQueryName,
snowflake::SnowflakeString,
sqlite3::{SQLite3Ident, SQLite3String},
trino::TrinoString,
},
drivers::{bigquery::BigQueryName, snowflake::SnowflakeString, trino::TrinoString},
errors::{Result, SourceError},
util::is_c_ident,
util::{is_c_ident, AnsiIdent, AnsiString},
};

/// None of these keywords should ever be matched as a bare identifier. We use
Expand Down Expand Up @@ -552,12 +547,9 @@ impl Emit for Identifier {
// Snowflake and SQLite3 use double quoted identifiers and
// escape quotes by doubling them. Neither allows backslash
// escapes here, though Snowflake does in strings.
Target::Snowflake | Target::SQLite3 | Target::Trino => write!(
f,
"{}{}",
SQLite3Ident(&self.text),
t.f(&self.token.ws_only())
),
Target::Snowflake | Target::SQLite3 | Target::Trino => {
write!(f, "{}{}", AnsiIdent(&self.text), t.f(&self.token.ws_only()))
}
}
} else {
write!(f, "{}{}", self.text, t.f(&self.token.ws_only()))
Expand Down Expand Up @@ -613,7 +605,7 @@ impl Emit for TableName {
| TableName::DatasetTable { table, .. }
| TableName::Table { table } => table.token.ws_only(),
};
write!(f, "{}{}", SQLite3Ident(&name), t.f(&ws))
write!(f, "{}{}", AnsiIdent(&name), t.f(&ws))
}
_ => self.emit_default(t, f),
}
Expand Down Expand Up @@ -992,7 +984,7 @@ impl Emit for Expression {
token,
value: LiteralValue::String(s),
} if t == Target::SQLite3 => {
SQLite3String(s).fmt(f)?;
AnsiString(s).fmt(f)?;
token.ws_only().emit(t, f)
}
// SQLite3 quotes strings differently.
Expand Down Expand Up @@ -1396,7 +1388,7 @@ impl Emit for FunctionName {
Target::SQLite3 => {
let name = self.unescaped_bigquery();
let ws = self.function_identifier().token.ws_only();
write!(f, "{}{}", SQLite3Ident(&name), t.f(&ws))
write!(f, "{}{}", AnsiIdent(&name), t.f(&ws))
}
_ => self.emit_default(t, f),
}
Expand Down
7 changes: 4 additions & 3 deletions src/drivers/snowflake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ use crate::{
ast::Target,
errors::{format_err, Context, Error, Result},
transforms::{self, Transform, Udf},
util::AnsiIdent,
};

use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator};
use super::{Column, Driver, DriverImpl, Locator};

/// Locator prefix for Snowflake.
pub const SNOWFLAKE_LOCATOR_PREFIX: &str = "snowflake:";
Expand Down Expand Up @@ -239,7 +240,7 @@ impl Driver for SnowflakeDriver {
async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> {
self.execute_native_sql_statement(&format!(
"DROP TABLE IF EXISTS {}",
SQLite3Ident(table_name)
AnsiIdent(table_name)
))
.await
}
Expand Down Expand Up @@ -291,7 +292,7 @@ impl DriverImpl for SnowflakeDriver {
) -> Result<Self::Rows> {
let column_list = columns
.iter()
.map(|c| SQLite3Ident(&c.name).to_string())
.map(|c| AnsiIdent(&c.name).to_string())
.collect::<Vec<_>>()
.join(", ");
// TODO: Again, quoting the table name fails.
Expand Down
47 changes: 6 additions & 41 deletions src/drivers/sqlite3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use rusqlite::{
use crate::{
ast::Target,
errors::{format_err, Context, Error, Result},
util::{AnsiIdent, AnsiString},
};

use self::unnest::register_unnest;
Expand Down Expand Up @@ -157,7 +158,7 @@ impl Driver for SQLite3Driver {
}

async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> {
let sql = format!("DROP TABLE IF EXISTS {}", SQLite3Ident(table_name));
let sql = format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name));
self.execute_native_sql_statement(&sql).await
}

Expand Down Expand Up @@ -204,7 +205,7 @@ impl DriverImpl for SQLite3Driver {
type Rows = Box<dyn Iterator<Item = Result<Vec<Self::Value>>> + Send + Sync>;

async fn table_columns(&mut self, table_name: &str) -> Result<Vec<Column<Self::Type>>> {
let sql = format!("PRAGMA table_info({})", SQLite3Ident(table_name));
let sql = format!("PRAGMA table_info({})", AnsiIdent(table_name));
self.conn
.call(move |conn| {
let mut stmt = conn.prepare(&sql).context("failed to prepare SQL")?;
Expand All @@ -229,13 +230,13 @@ impl DriverImpl for SQLite3Driver {
) -> Result<Self::Rows> {
let column_list = columns
.iter()
.map(|c| sqlite3_quote_ident(&c.name))
.map(|c| AnsiIdent(&c.name).to_string())
.collect::<Vec<_>>()
.join(", ");
let sql = format!(
"SELECT {} FROM {} ORDER BY {}",
column_list,
SQLite3Ident(table_name),
AnsiIdent(table_name),
column_list
);
let columns = columns.to_vec();
Expand Down Expand Up @@ -275,44 +276,8 @@ impl std::fmt::Display for Value {
rusqlite::types::Value::Null => write!(f, "NULL"),
rusqlite::types::Value::Integer(i) => write!(f, "{}", i),
rusqlite::types::Value::Real(r) => write!(f, "{}", r),
rusqlite::types::Value::Text(s) => write!(f, "{}", SQLite3String(s)),
rusqlite::types::Value::Text(s) => write!(f, "{}", AnsiString(s)),
rusqlite::types::Value::Blob(b) => write!(f, "{:?}", b),
}
}
}

/// Escape an identifier for use in a SQLite3 query.
fn sqlite3_quote_ident(s: &str) -> String {
format!("{}", SQLite3Ident(s))
}

/// Format a single- or double-quoted string for use in a SQLite3 query. SQLite3
/// does not support backslash escapes.
fn sqlite3_quote_fmt(s: &str, quote_char: char, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", quote_char)?;
for c in s.chars() {
if c == quote_char {
write!(f, "{}", quote_char)?;
}
write!(f, "{}", c)?;
}
write!(f, "{}", quote_char)
}

/// Formatting wrapper for single-quoted strings.
pub struct SQLite3String<'a>(pub &'a str);

impl fmt::Display for SQLite3String<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
sqlite3_quote_fmt(self.0, '\'', f)
}
}

/// Formatting wrapper for double-quoted identifiers.
pub struct SQLite3Ident<'a>(pub &'a str);

impl fmt::Display for SQLite3Ident<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
sqlite3_quote_fmt(self.0, '"', f)
}
}
20 changes: 9 additions & 11 deletions src/drivers/trino/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ use tracing::debug;

use crate::{
ast::Target,
drivers::sqlite3::SQLite3String,
errors::{format_err, Context, Error, Result},
transforms::{self, Transform, Udf},
util::AnsiIdent,
};

use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator};
use super::{Column, Driver, DriverImpl, Locator};

/// Our locator prefix.
pub const TRINO_LOCATOR_PREFIX: &str = "trino:";
Expand Down Expand Up @@ -162,17 +162,15 @@ impl Driver for TrinoDriver {
&format_udf,
)),
Box::new(transforms::CleanUpTempManually {
format_name: &|table_name| {
SQLite3Ident(&table_name.unescaped_bigquery()).to_string()
},
format_name: &|table_name| AnsiIdent(&table_name.unescaped_bigquery()).to_string(),
}),
]
}

#[tracing::instrument(skip(self))]
async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> {
self.client
.execute(format!("DROP TABLE IF EXISTS {}", SQLite3Ident(table_name)))
.execute(format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name)))
.await
.map_err(abbreviate_trino_error)
.with_context(|| format!("Failed to drop table: {}", table_name))?;
Expand Down Expand Up @@ -205,9 +203,9 @@ impl DriverImpl for TrinoDriver {
FROM information_schema.columns
WHERE table_catalog = {} AND table_schema = {} AND table_name = {}",
// TODO: Replace with real string escapes.
SQLite3String(&self.catalog),
SQLite3String(&self.schema),
SQLite3String(table_name)
TrinoString(&self.catalog),
TrinoString(&self.schema),
TrinoString(table_name)
);
Ok(self
.client
Expand All @@ -232,13 +230,13 @@ impl DriverImpl for TrinoDriver {
) -> Result<Self::Rows> {
let cols_sql = columns
.iter()
.map(|c| SQLite3Ident(&c.name).to_string())
.map(|c| AnsiIdent(&c.name).to_string())
.collect::<Vec<_>>()
.join(", ");
let sql = format!(
"SELECT {} FROM {} ORDER BY {}",
cols_sql,
SQLite3Ident(table_name),
AnsiIdent(table_name),
cols_sql
);
let rows = self
Expand Down
35 changes: 35 additions & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! The inevitable grab-bag of utility functions.
use std::fmt;

/// Is `s` a valid C identifier?
pub fn is_c_ident(s: &str) -> bool {
let mut chars = s.chars();
Expand All @@ -9,3 +11,36 @@ pub fn is_c_ident(s: &str) -> bool {
_ => chars.all(|c| c.is_ascii_alphanumeric() || c == '_'),
}
}

/// Format a single- or double-quoted string for use in a SQLite3 query. SQLite3
/// does not support backslash escapes.
fn ansi_quote_fmt(s: &str, quote_char: char, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", quote_char)?;
for c in s.chars() {
if c == quote_char {
write!(f, "{}", quote_char)?;
}
write!(f, "{}", c)?;
}
write!(f, "{}", quote_char)
}

/// Formatting wrapper for single-quoted strings. It's actually pretty rare for
/// databases to support just this format with no backslash escapes, so please
/// double-check things like `'\\'`` and `'\''` before using this.
pub struct AnsiString<'a>(pub &'a str);

impl fmt::Display for AnsiString<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
ansi_quote_fmt(self.0, '\'', f)
}
}

/// Formatting wrapper for double-quoted identifiers.
pub struct AnsiIdent<'a>(pub &'a str);

impl fmt::Display for AnsiIdent<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
ansi_quote_fmt(self.0, '"', f)
}
}

0 comments on commit 2393905

Please sign in to comment.