Skip to content

Commit

Permalink
Continue testing string functions
Browse files Browse the repository at this point in the history
We implement several string-related functions. This required adding
another pass to the transpiler, so we could rewrite the AST.
  • Loading branch information
emk committed Oct 16, 2023
1 parent d1fb002 commit f637139
Show file tree
Hide file tree
Showing 10 changed files with 344 additions and 99 deletions.
179 changes: 101 additions & 78 deletions src/ast.rs

Large diffs are not rendered by default.

36 changes: 30 additions & 6 deletions src/drivers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! Database drivers.
use std::{collections::VecDeque, fmt, str::FromStr};
use std::{borrow::Cow, collections::VecDeque, fmt, str::FromStr};

use async_trait::async_trait;

use crate::{
ast::{Emit, Target},
ast::{self, Emit, Target},
errors::{format_err, Error, Result},
};

Expand Down Expand Up @@ -67,12 +67,26 @@ pub trait Driver: Send + Sync + 'static {

/// Execute a query represented as an AST. This can execute multiple
/// statements.
async fn execute_ast(&mut self, ast: &crate::ast::SqlProgram) -> Result<()> {
for statement in &ast.statements {
let sql = statement.emit_to_string(self.target());
async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> {
let rewritten = self.rewrite_ast(ast)?;
for sql in rewritten.extra_native_sql {
self.execute_native_sql_statement(&sql).await?;
}
Ok(())
let sql = rewritten.ast.emit_to_string(self.target());
self.execute_native_sql_statement(&sql).await
}

/// Rewrite an AST to convert function names, etc., into versions that can
/// be passed to [`Emitted::emit_to_string`] for this database. This allows
/// us to do less database-specific work in [`Emit::emit`], and more in the
/// database drivers themselves. This can't change lexical syntax, but it
/// can change the structure of the AST.
fn rewrite_ast<'ast>(&self, ast: &'ast ast::SqlProgram) -> Result<RewrittenAst<'ast>> {
// Default implementation does nothing.
Ok(RewrittenAst {
extra_native_sql: vec![],
ast: Cow::Borrowed(ast),
})
}

/// Drop a table if it exists.
Expand All @@ -85,6 +99,16 @@ pub trait Driver: Send + Sync + 'static {
async fn compare_tables(&mut self, result_table: &str, expected_table: &str) -> Result<()>;
}

/// The output of [`Driver::rewrite_ast`].
pub struct RewrittenAst<'a> {
/// Extra native SQL statements to execute before the AST. Probably
/// temporary UDFs and things like that.
pub extra_native_sql: Vec<String>,

/// The new AST.
pub ast: Cow<'a, ast::SqlProgram>,
}

/// Extensions to [`Driver`] that are not ["object safe"][safe].
///
/// [safe]: https://doc.rust-lang.org/reference/items/traits.html#object-safety
Expand Down
46 changes: 42 additions & 4 deletions src/drivers/snowflake/mod.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
//! A Snowflake driver.
use std::{env, fmt, str::FromStr};
use std::{borrow::Cow, env, fmt, str::FromStr};

use arrow_json::writer::record_batches_to_json_rows;
use async_trait::async_trait;
use derive_visitor::DriveMut;
use once_cell::sync::Lazy;
use regex::Regex;
use serde_json::Value;
use snowflake_api::{QueryResult, SnowflakeApi};
use tracing::instrument;
use tracing::{debug, instrument};

use crate::{
ast::Target,
ast::{self, Emit, Target},
errors::{format_err, Context, Error, Result},
};

use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator};
use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator, RewrittenAst};

mod rename_functions;

/// Locator prefix for Snowflake.
pub const SNOWFLAKE_LOCATOR_PREFIX: &str = "snowflake:";
Expand Down Expand Up @@ -190,13 +193,48 @@ impl Driver for SnowflakeDriver {

#[instrument(skip(self, sql), err)]
async fn execute_native_sql_statement(&mut self, sql: &str) -> Result<()> {
debug!(%sql, "executing SQL");
self.connection
.exec(sql)
.await
.with_context(|| format!("error running SQL: {}", sql))?;
Ok(())
}

async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> {
let rewritten = self.rewrite_ast(ast)?;
for sql in rewritten.extra_native_sql {
self.execute_native_sql_statement(&sql).await?;
}

// We can only execute one statement at a time.
for statement in &rewritten.ast.statements {
let sql = statement.emit_to_string(self.target());
self.execute_native_sql_statement(&sql).await?;
}

// Reset session to drop `TEMP` tables and UDFs.
self.connection
.close_session()
.await
.context("could not end Snowflake session")
}

fn rewrite_ast<'ast>(&self, ast: &'ast ast::SqlProgram) -> Result<RewrittenAst<'ast>> {
let mut ast = ast.clone();
let mut renamer = rename_functions::RenameFunctions::default();
ast.drive_mut(&mut renamer);
let extra_native_sql = renamer
.udfs
.values()
.map(|udf| udf.to_sql())
.collect::<Vec<_>>();
Ok(RewrittenAst {
extra_native_sql,
ast: Cow::Owned(ast),
})
}

#[instrument(skip(self))]
async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> {
self.execute_native_sql_statement(&format!(
Expand Down
65 changes: 65 additions & 0 deletions src/drivers/snowflake/rename_functions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//! A simple tree-walker that renames functions to their Snowflake equivalents.
use std::collections::HashMap;

use derive_visitor::VisitorMut;

use crate::ast::{FunctionName, Identifier};

// A `phf_map!` of BigQuery function names to Snowflake function names. Use
// this for simple renaming.
static FUNCTION_NAMES: phf::Map<&'static str, &'static str> = phf::phf_map! {
"REGEXP_EXTRACT" => "REGEXP_SUBSTR",
"SHA256" => "SHA2_BINARY", // Second argument defaults to SHA256.
};

/// A Snowflake UDF (user-defined function).
pub struct Udf {
pub decl: &'static str,
pub sql: &'static str,
}

impl Udf {
/// Generate the SQL to create this UDF.
pub fn to_sql(&self) -> String {
format!(
"CREATE OR REPLACE TEMP FUNCTION {} AS $$\n{}\n$$\n",
self.decl, self.sql
)
}
}

/// A `phf_map!` of BigQuery UDF names to Snowflake UDFs. Use this when we
/// actually need to create a UDF as a helper function.
static UDFS: phf::Map<&'static str, &'static Udf> = phf::phf_map! {
"TO_HEX" => &Udf { decl: "TO_HEX(b BINARY) RETURNS STRING", sql: "HEX_ENCODE(b, 0)" },
};

#[derive(Default, VisitorMut)]
#[visitor(FunctionName(enter))]
pub struct RenameFunctions {
// UDFs that we need to create, if we haven't already.
pub udfs: HashMap<String, &'static Udf>,
}

impl RenameFunctions {
fn enter_function_name(&mut self, function_name: &mut FunctionName) {
if let FunctionName::Function { function } = function_name {
let name = function.unescaped_bigquery().to_ascii_uppercase();
if let Some(snowflake_name) = FUNCTION_NAMES.get(&name) {
// Rename the function.
let orig_ident = function_name.function_identifier();
*function_name = FunctionName::Function {
function: Identifier {
token: orig_ident.token.with_token_str(snowflake_name),
text: snowflake_name.to_string(),
},
};
} else if let Some(udf) = UDFS.get(&name) {
// We'll need a UDF, so add it to our list it if isn't already
// there.
self.udfs.insert(name, udf);
}
}
}
}
40 changes: 39 additions & 1 deletion src/drivers/sqlite3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use std::{fmt, str::FromStr, vec};

use async_rusqlite::Connection;
use async_trait::async_trait;
use rusqlite::{functions::FunctionFlags, types};
use rusqlite::{
functions::{self, FunctionFlags},
types,
};

use crate::{
ast::Target,
Expand Down Expand Up @@ -102,6 +105,10 @@ impl SQLite3Driver {
// yet, but it allows us to parse and execute queries that use `UNNEST`.
register_unnest(conn).expect("failed to register UNNEST");

// Install real functions, where we have them.
conn.create_scalar_function("concat", -1, FunctionFlags::SQLITE_UTF8, func_concat)
.expect("failed to create concat function");

// Install some dummy functions that always return NULL.
let dummy_fns = &[
("array", -1),
Expand Down Expand Up @@ -159,6 +166,37 @@ impl Driver for SQLite3Driver {
}
}

/// Concatenate a list of values into a string. Mimics BigQuery's `CONCAT`.
fn func_concat(ctx: &functions::Context<'_>) -> rusqlite::Result<types::Value> {
let mut result = String::new();
for idx in 0..ctx.len() {
match ctx.get_raw(idx) {
types::ValueRef::Null => return Ok(types::Value::Null),
types::ValueRef::Integer(i) => {
result.push_str(&i.to_string());
}
types::ValueRef::Real(r) => {
result.push_str(&r.to_string());
}
types::ValueRef::Text(s) => {
let s = String::from_utf8(s.to_owned()).map_err(|_| {
rusqlite::Error::InvalidFunctionParameterType(idx, types::Type::Text)
})?;
result.push_str(&s);
}
types::ValueRef::Blob(_) => {
// `CONCAT` should also support being called with _all_ `BINARY`
// values, but you can't mix them with other types.
return Err(rusqlite::Error::InvalidFunctionParameterType(
idx,
types::Type::Blob,
));
}
}
}
Ok(types::Value::Text(result))
}

#[async_trait]
impl DriverImpl for SQLite3Driver {
type Type = String;
Expand Down
18 changes: 9 additions & 9 deletions tests/sql/functions/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,22 @@ Here is a list of functions that are high priorities to implement. You can
generate your own version of this list by running `joinery parse
--count-function-calls queries.csv`.

- [ ] REGEXP_REPLACE(_,_,_)
- [ ] REGEXP_EXTRACT(_,_)
- [ ] COALESCE(*)
- [ ] LOWER(_)
- [ ] TO_HEX(_)
- [ ] SHA256(_)
- [x] REGEXP_REPLACE(_,_,_)
- [x] REGEXP_EXTRACT(_,_)
- [x] COALESCE(*)
- [x] LOWER(_)
- [x] TO_HEX(_)
- [x] SHA256(_)
- [ ] LENGTH(_)
- [ ] CONCAT(*)
- [ ] TRIM(_)
- [x] CONCAT(*)
- [x] TRIM(_)
- [ ] ARRAY_TO_STRING(_,_)
- [ ] SUM(_)
- [ ] FARM_FINGERPRINT(_)
- [ ] ANY_VALUE(_)
- [ ] ROW_NUMBER() OVER(..)
- [ ] COUNTIF(_)
- [ ] UPPER(_)
- [x] UPPER(_)
- [ ] ARRAY_AGG(_)
- [ ] DATE_TRUNC(_,_) (special)
- [ ] MIN(_)
Expand Down
24 changes: 24 additions & 0 deletions tests/sql/functions/simple/coalesce.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
-- COALESCE

CREATE OR REPLACE TABLE __result1 AS
SELECT
-- This is not supported on BigQuery.
--
-- COALESCE() AS coalesce_empty,
--
-- This is supported by BigQuery, but not by other databases. We can remove
-- it in the transpiler if absolutely necessary, but it's probably better to
-- make it an error.
--
-- COALESCE(1) AS coalesce_one,
COALESCE(1, 2) AS coalesce_two,
COALESCE(NULL, 2) AS coalesce_two_null,
COALESCE(NULL, 2, 3) AS coalesce_three_null;

CREATE OR REPLACE TABLE __expected1 (
coalesce_two INT64,
coalesce_two_null INT64,
coalesce_three_null INT64,
);
INSERT INTO __expected1 VALUES
(1, 2, 2);
1 change: 0 additions & 1 deletion tests/sql/functions/simple/regexp.sql
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
-- pending: snowflake REGEX_EXTRACT needs to be wrapped with REGEXP_SUBSTR
-- pending: sqlite3 No regex fuctions
--
-- REGEXP_REPLACE, REGEXP_EXTRACT
Expand Down
13 changes: 13 additions & 0 deletions tests/sql/functions/simple/sha256.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- pending: sqlite3 No SHA256 function

CREATE OR REPLACE TABLE __result1 AS
SELECT
to_hex(sha256('hi')) AS hi,
to_hex(sha256(null)) AS null_arg;

CREATE OR REPLACE TABLE __expected1 (
hi STRING,
null_arg STRING,
);
INSERT INTO __expected1 VALUES
('8f434346648f6b96df89dda901c5176b10a6d83961dd3c1ac88b59b2dc327aa4', NULL);
21 changes: 21 additions & 0 deletions tests/sql/functions/simple/strings.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
-- LOWER, UPPER, TRIM, CONCAT
CREATE OR REPLACE TABLE __result1 AS
SELECT
LOWER('FOO') AS lower,
UPPER('foo') AS upper,
TRIM(' foo ') AS trim,
CONCAT('foo', 'bar') AS concat,
CONCAT('x', 1) concat_casted,
CONCAT('x', NULL) concat_null;

CREATE OR REPLACE TABLE __expected1 (
lower STRING,
upper STRING,
trim STRING,
concat STRING,
concat_casted STRING,
concat_null STRING,
);
INSERT INTO __expected1 VALUES
('foo', 'FOO', 'foo', 'foobar', 'x1', NULL);

0 comments on commit f637139

Please sign in to comment.