Skip to content

Commit

Permalink
switch from prusto to dbcrossbar_trino client
Browse files Browse the repository at this point in the history
This required a bunch of changes including:
- Using BIGINT instead of INT in our UDFs
- Adding handling for empty arrays to infer their types as `ARRAY<INT64>`
- Emitting empty array types as CAST
- Got rid of the SQLite driver and related transformations

Co-Authored-By: Eric Kidd <[email protected]>
  • Loading branch information
hanakslr and emk committed Dec 4, 2024
1 parent ac8831a commit 2600412
Show file tree
Hide file tree
Showing 16 changed files with 828 additions and 1,231 deletions.
1,012 changes: 594 additions & 418 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@ async-trait = "0.1.73"
clap = { version = "4.4.6", features = ["derive", "wrap_help"] }
codespan-reporting = "0.11.1"
csv = "1.2.2"
dbcrossbar_trino = { version = "0.2.2", features = ["macros", "values", "client", "rustls-tls"] }
derive-visitor = "0.4.0"
glob = "0.3.1"
joinery_macros = { path = "joinery_macros" }
once_cell = "1.18.0"
owo-colors = "4.0.0-rc.1"
peg = "0.8.2"
phf = { version = "0.11.2", features = ["macros"] }
# Waiting on https://github.com/nooberfsh/prusto/issues/33
prusto = { git = "https://github.com/nooberfsh/prusto.git" }
rand = "0.8.4"
regex = "1.10.0"
rusqlite = { version = "0.29.0", features = ["bundled", "functions", "vtab"] }
Expand Down
4 changes: 2 additions & 2 deletions sql/trino_compat.sql
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ RETURN (
);

-- Handle odd BigQuery behaviour around the `pos` argument.
CREATE OR REPLACE FUNCTION memory.joinery_compat.SUBSTR_COMPAT(input VARCHAR, pos INT)
CREATE OR REPLACE FUNCTION memory.joinery_compat.SUBSTR_COMPAT(input VARCHAR, pos BIGINT)
RETURNS VARCHAR
RETURNS NULL ON NULL INPUT
RETURN (
Expand All @@ -50,7 +50,7 @@ RETURN (

-- As in the two-argument case, but also treat a length greater than the string
-- as "to the end of the string".
CREATE OR REPLACE FUNCTION memory.joinery_compat.SUBSTR_COMPAT(input VARCHAR, pos INT, len INT)
CREATE OR REPLACE FUNCTION memory.joinery_compat.SUBSTR_COMPAT(input VARCHAR, pos BIGINT, len BIGINT)
RETURNS VARCHAR
RETURNS NULL ON NULL INPUT
RETURN (
Expand Down
119 changes: 46 additions & 73 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use joinery_macros::{Emit, EmitDefault, Spanned, ToTokens};
use crate::{
drivers::{
bigquery::{BigQueryName, BigQueryString},
sqlite3::KEYWORDS as SQLITE3_KEYWORDS,
trino::{TrinoString, KEYWORDS as TRINO_KEYWORDS},
},
errors::{format_err, Error, Result},
Expand All @@ -39,8 +38,8 @@ use crate::{
tokenize_sql, EmptyFile, Ident, Keyword, Literal, LiteralValue, PseudoKeyword, Punct,
RawToken, Span, Spanned, ToTokens, Token, TokenStream, TokenWriter,
},
types::{StructType, TableType},
util::{is_c_ident, AnsiIdent, AnsiString},
types::{StructType, TableType, ValueType},
util::{is_c_ident, AnsiIdent},
};

/// None of these keywords should ever be matched as a bare Ident. We use
Expand All @@ -67,7 +66,6 @@ static KEYWORDS: phf::Set<&'static str> = phf::phf_set! {
#[allow(dead_code)]
pub enum Target {
BigQuery,
SQLite3,
Trino,
}

Expand All @@ -76,7 +74,6 @@ impl Target {
pub fn is_keyword(self, s: &str) -> bool {
let keywords = match self {
Target::BigQuery => &KEYWORDS,
Target::SQLite3 => &SQLITE3_KEYWORDS,
Target::Trino => &TRINO_KEYWORDS,
};
keywords.contains(s.to_ascii_uppercase().as_str())
Expand All @@ -87,7 +84,6 @@ impl fmt::Display for Target {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Target::BigQuery => write!(f, "bigquery"),
Target::SQLite3 => write!(f, "sqlite3"),
Target::Trino => write!(f, "trino"),
}
}
Expand Down Expand Up @@ -191,7 +187,7 @@ impl Emit for Ident {
if t.is_keyword(&self.name) || !is_c_ident(&self.name) {
match t {
Target::BigQuery => write!(f, "{}", BigQueryName(&self.name))?,
Target::SQLite3 | Target::Trino => {
Target::Trino => {
write!(f, "{}", AnsiIdent(&self.name))?;
}
}
Expand Down Expand Up @@ -236,7 +232,6 @@ impl Emit for LiteralValue {
LiteralValue::Float64(fl) => write!(f, "{}", fl),
LiteralValue::String(s) => match t {
Target::BigQuery => write!(f, "{}", BigQueryString(s)),
Target::SQLite3 => write!(f, "{}", AnsiString(s)),
Target::Trino => write!(f, "{}", TrinoString(s)),
},
}
Expand Down Expand Up @@ -319,6 +314,11 @@ impl<T: Node> NodeVec<T> {
}
}

/// Returns `true` when this [`NodeVec`] holds no real items (any
/// separator/trailing tokens are not counted — presumably only `items`
/// matters here; confirm against `NodeVec`'s full definition).
pub fn is_empty(&self) -> bool {
    self.items.is_empty()
}

/// Take the elements from this [`NodeVec`], leaving it empty, and return a new
/// [`NodeVec`] containing the taken elements.
pub fn take(&mut self) -> NodeVec<T> {
Expand Down Expand Up @@ -681,7 +681,7 @@ pub struct CommonTableExpression {
}

/// Set operators.
#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)]
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub enum SetOperator {
UnionAll {
union_token: Keyword,
Expand All @@ -701,32 +701,6 @@ pub enum SetOperator {
},
}

impl Emit for SetOperator {
    // Emit this set operator in the dialect of the target database.
    fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> {
        match t {
            // SQLite3: `UNION ALL` is emitted as-is (both tokens), but the
            // `DISTINCT` variants emit only their base keyword, silently
            // dropping the `DISTINCT` qualifier. Bare `UNION` / `INTERSECT` /
            // `EXCEPT` are distinct-by-default in SQLite3, so the semantics
            // should be preserved. Whitespace attached to the emitted tokens
            // is kept.
            Target::SQLite3 => match self {
                SetOperator::UnionAll {
                    union_token,
                    all_token,
                } => {
                    union_token.emit(t, f)?;
                    all_token.emit(t, f)
                }
                SetOperator::UnionDistinct { union_token, .. } => union_token.emit(t, f),
                SetOperator::IntersectDistinct {
                    intersect_token, ..
                } => intersect_token.emit(t, f),
                SetOperator::ExceptDistinct { except_token, .. } => except_token.emit(t, f),
            },
            // Other targets support the full syntax, so use the derived
            // `EmitDefault` implementation unchanged.
            _ => self.emit_default(t, f),
        }
    }
}

/// A `SELECT` expression.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct SelectExpression {
Expand Down Expand Up @@ -943,11 +917,6 @@ impl Emit for CastType {
CastType::SafeCast { safe_cast_token } if t == Target::Trino => {
safe_cast_token.ident.token.with_str("TRY_CAST").emit(t, f)
}
// TODO: This isn't strictly right, but it's as close as I know how to
// get with SQLite3.
CastType::SafeCast { safe_cast_token } if t == Target::SQLite3 => {
safe_cast_token.ident.token.with_str("CAST").emit(t, f)
}
_ => self.emit_default(t, f),
}
}
Expand Down Expand Up @@ -1088,6 +1057,12 @@ pub struct BinopExpression {
/// a missing `ARRAY` and a `delim1` of `(`. We'll let the parser handle that.
#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)]
pub struct ArrayExpression {
/// Type information added later by inference.
#[emit(skip)]
#[to_tokens(skip)]
#[drive(skip)]
pub ty: Option<ValueType>,

pub array_token: Option<Keyword>,
pub element_type: Option<ArrayElementType>,
pub delim1: Punct,
Expand Down Expand Up @@ -1155,17 +1130,11 @@ impl Emit for ArrayExpression {
last_token.with_ws_only().emit(t, f)?;
}
_ => match t {
Target::SQLite3 => {
if let Some(array_token) = &self.array_token {
array_token.emit(t, f)?;
} else {
f.write_token_start("ARRAY")?;
}
self.delim1.token.with_str("(").emit(t, f)?;
self.definition.emit(t, f)?;
self.delim2.token.with_str(")").emit(t, f)?;
}
Target::Trino => {
let needs_cast = self.definition.has_zero_element_expressions();
if needs_cast {
f.write_token_start("CAST(")?;
}
if let Some(array_token) = &self.array_token {
array_token.emit(t, f)?;
} else {
Expand All @@ -1174,6 +1143,18 @@ impl Emit for ArrayExpression {
self.delim1.token.with_str("[").emit(t, f)?;
self.definition.emit(t, f)?;
self.delim2.token.with_str("]").emit(t, f)?;

if needs_cast {
f.write_token_start(" AS ")?;
let ty = self
.ty
.as_ref()
.expect("type should have been added by type checker");
let data_type = DataType::try_from(ty.clone())
.expect("should be able to print data type of ARRAY");
data_type.emit(t, f)?;
f.write_token_start(")")?;
}
}
_ => self.emit_default(t, f)?,
},
Expand Down Expand Up @@ -1208,6 +1189,16 @@ pub enum ArrayDefinition {
Elements(NodeVec<Expression>),
}

impl ArrayDefinition {
    /// Does this array definition contain zero element expressions?
    ///
    /// An `ARRAY(SELECT ...)` definition never counts as empty, because the
    /// number of elements it yields is only known at query time.
    pub fn has_zero_element_expressions(&self) -> bool {
        match self {
            ArrayDefinition::Elements(elements) => elements.items.is_empty(),
            ArrayDefinition::Query { .. } => false,
        }
    }
}

/// A very restricted version of [`SelectExpression`] that we allow in an
/// `ARRAY(SELECT ...)`, because we handle this as a special case to avoid
/// hitting correlated subquery restrictions in other databases.
Expand Down Expand Up @@ -1258,7 +1249,7 @@ impl Emit for StructExpression {
}
self.paren2.emit(t, f)?;
f.write_token_start("AS")?;
let ty = self
let ty: &StructType = self
.ty
.as_ref()
.expect("type should have been added by type checker");
Expand Down Expand Up @@ -1546,27 +1537,6 @@ pub enum DataType {
impl Emit for DataType {
fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> {
match t {
Target::SQLite3 => match self {
DataType::Bool(token) | DataType::Int64(token) => {
token.ident.token.with_str("INTEGER").emit(t, f)
}
// NUMERIC is used when people want accurate math, so we want
// either BLOB or TEXT, whatever makes math easier.
DataType::Bytes(token) | DataType::Numeric(token) => {
token.ident.token.with_str("BLOB").emit(t, f)
}
DataType::Float64(token) => token.ident.token.with_str("REAL").emit(t, f),
DataType::String(token)
| DataType::Date(token) // All date types should be strings
| DataType::Datetime(token)
| DataType::Geography(token) // Use GeoJSON
| DataType::Time(token)
| DataType::Timestamp(token) =>
token.ident.token.with_str("TEXT").emit(t, f),
DataType::Array { array_token: token, .. } | DataType::Struct { struct_token: token, .. } => {
token.ident.token.with_str("/*JSON*/TEXT").emit(t, f)
}
},
Target::Trino => match self {
DataType::Bool(token) => token.ident.token.with_str("BOOLEAN").emit(t, f),
DataType::Bytes(token) => token.ident.token.with_str("VARBINARY").emit(t, f),
Expand Down Expand Up @@ -2416,6 +2386,7 @@ peg::parser! {
rule array_expression() -> ArrayExpression
= delim1:p("[") definition:array_definition() delim2:p("]") {
ArrayExpression {
ty: None,
array_token: None,
element_type: None,
delim1,
Expand All @@ -2426,6 +2397,7 @@ peg::parser! {
/ array_token:k("ARRAY") element_type:array_element_type()?
delim1:p("[") definition:array_definition() delim2:p("]") {
ArrayExpression {
ty: None,
array_token: Some(array_token),
element_type,
delim1,
Expand All @@ -2436,6 +2408,7 @@ peg::parser! {
/ array_token:k("ARRAY") element_type:array_element_type()?
delim1:p("(") definition:array_definition() delim2:p(")") {
ArrayExpression {
ty: None,
array_token: Some(array_token),
element_type,
delim1,
Expand Down Expand Up @@ -3068,7 +3041,7 @@ mod tests {
use super::*;

#[tokio::test]
async fn test_parser_and_run_with_sqlite3() {
async fn test_parser() {
let sql_examples = &[
// Basic test cases of gradually increasing complexity.
(r#"SELECT * FROM t"#, None),
Expand Down
22 changes: 14 additions & 8 deletions src/drivers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::{borrow::Cow, collections::VecDeque, fmt, str::FromStr};

use async_trait::async_trait;
use dbcrossbar_trino::values::IsCloseEnoughTo;
use tracing::{debug, trace};

use crate::{
Expand All @@ -13,13 +14,9 @@ use crate::{
transforms::{Transform, TransformExtra},
};

use self::{
sqlite3::{SQLite3Locator, SQLITE3_LOCATOR_PREFIX},
trino::{TrinoLocator, TRINO_LOCATOR_PREFIX},
};
use self::trino::{TrinoLocator, TRINO_LOCATOR_PREFIX};

pub mod bigquery;
pub mod sqlite3;
pub mod trino;

/// A URL-like locator for a database.
Expand All @@ -41,7 +38,6 @@ impl FromStr for Box<dyn Locator> {
.ok_or_else(|| format_err!("could not find scheme for locator: {}", s))?;
let prefix = &s[..colon_pos + 1];
match prefix {
SQLITE3_LOCATOR_PREFIX => Ok(Box::new(s.parse::<SQLite3Locator>()?)),
TRINO_LOCATOR_PREFIX => Ok(Box::new(s.parse::<TrinoLocator>()?)),
_ => Err(format_err!("unsupported database type: {}", s)),
}
Expand All @@ -53,6 +49,14 @@ impl FromStr for Box<dyn Locator> {
pub trait Comparable: fmt::Debug + fmt::Display + PartialEq + Send + Sync {}
impl<T: fmt::Debug + fmt::Display + PartialEq + Send + Sync> Comparable for T {}

/// Like [`Comparable`], but using approximate equality via
/// [`IsCloseEnoughTo`] instead of [`PartialEq`]. Used for comparing test
/// results — typically values written to databases, which may come back with
/// small float differences or reduced timestamp precision.
#[allow(dead_code)]
pub trait ApproxComparable: fmt::Debug + fmt::Display + IsCloseEnoughTo + Send + Sync {}
impl<T: fmt::Debug + fmt::Display + IsCloseEnoughTo + Send + Sync> ApproxComparable for T {}

/// A database driver. This is mostly used for running SQL tests.
///
/// This trait is ["object safe"][safe], so it can be used as `Box<dyn Driver>`
Expand Down Expand Up @@ -204,7 +208,7 @@ pub trait DriverImpl {
type Type: Comparable;

/// A native value for this database.
type Value: Comparable;
type Value: ApproxComparable;

/// An iterator over the rows of a table.
type Rows: Iterator<Item = Result<Vec<Self::Value>>> + Send + Sync;
Expand Down Expand Up @@ -259,14 +263,16 @@ pub trait DriverImpl {
loop {
match (result_rows.next(), expected_rows.next()) {
(Some(Ok(result_row)), Some(Ok(expected_row))) => {
if result_row != expected_row {
if !result_row.is_close_enough_to(&expected_row) {
// Build a short diff of recent rows.
let mut diff = String::new();
for row in result_history.iter() {
diff.push_str(&format!(" {}\n", vec_to_string(row)));
}
diff.push_str(&format!("- {}\n", vec_to_string(&expected_row)));
diff.push_str(&format!("+ {}\n", vec_to_string(&result_row)));
debug!("expected row: {:?}", expected_row);
debug!("result row: {:?}", result_row);

return Err(Error::tables_not_equal(format!(
"row from {} does not match row from {}:\n{}",
Expand Down
Loading

0 comments on commit 2600412

Please sign in to comment.