diff --git a/Cargo.lock b/Cargo.lock index 54bc26c..258a466 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -459,9 +459,9 @@ dependencies = [ [[package]] name = "dbcrossbar_trino" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c824a64db21129755b9ae5eb528dbe48c5eae4cf1be90b885e4ed873ddd7ea4" +checksum = "c3dae8d902e5840541698408431acda4e6f1a0cc938d335807b876cc5d529785" dependencies = [ "base64", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 25e6c89..97e727a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ async-trait = "0.1.73" clap = { version = "4.4.6", features = ["derive", "wrap_help"] } codespan-reporting = "0.11.1" csv = "1.2.2" -dbcrossbar_trino = { version = "0.2.3", features = [ +dbcrossbar_trino = { version = "0.2.4", features = [ "macros", "values", "client", diff --git a/src/ast.rs b/src/ast.rs index cd61692..cd6025e 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -855,6 +855,8 @@ pub enum Expression { FunctionCall(FunctionCall), Index(IndexExpression), FieldAccess(FieldAccessExpression), + Load(LoadExpression), + Store(StoreExpression), } impl Expression { @@ -1631,6 +1633,44 @@ pub struct FieldAccessExpression { pub field_name: Ident, } +/// A "load" expression, which transforms an SQL value from a "storage" type (eg +/// "VARCHAR") to a "memory" type (eg "UUID"). Used for databases like Trino, +/// where the storage types for a given connector may be more limited than the +/// standard Trino memory types. +/// +/// These are not found in the original parsed AST, but are added while +/// transforming the AST. +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] +pub struct LoadExpression { + /// Inferred memory type. + #[emit(skip)] + #[to_tokens(skip)] + #[drive(skip)] + memory_type: Option, + + /// Our underlying expression. + pub expression: Box, +} + +/// A "store" expression, which transforms an SQL value from a "memory" type +/// (eg "UUID") to a "storage" type (eg "VARCHAR"). Used for databases like +/// Trino, where the storage types for a given connector may be more limited +/// than the standard Trino memory types. +/// +/// These are not found in the original parsed AST, but are added while +/// transforming the AST. +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] +pub struct StoreExpression { + /// Inferred memory type. + #[emit(skip)] + #[to_tokens(skip)] + #[drive(skip)] + memory_type: Option, + + /// Our underlying expression. + pub expression: Box, +} + /// An `AS` alias. #[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] pub struct Alias { diff --git a/src/drivers/trino/mod.rs b/src/drivers/trino/mod.rs index 57c6953..34132e3 100644 --- a/src/drivers/trino/mod.rs +++ b/src/drivers/trino/mod.rs @@ -5,7 +5,8 @@ use std::{fmt, str::FromStr, time::Duration}; use async_trait::async_trait; use dbcrossbar_trino::{ client::{Client, ClientBuilder, ClientError, QueryError}, - ConnectorType as TrinoConnectorType, DataType, Ident as TrinoIdent, Value, + ConnectorType as TrinoConnectorType, DataType as TrinoDataType, Field as TrinoField, + Ident as TrinoIdent, Value, }; use joinery_macros::sql_quote; use once_cell::sync::Lazy; @@ -18,6 +19,7 @@ use crate::{ errors::{format_err, Context, Error, Result}, tokenizer::TokenStream, transforms::{self, RenameFunctionsBuilder, Transform}, + types::{SimpleType, StructElementType, ValueType}, util::AnsiIdent, }; @@ -269,7 +271,7 @@ impl Driver for TrinoDriver { #[async_trait] impl DriverImpl for TrinoDriver { - type Type = DataType; + type Type = TrinoDataType; type Value = Value; type Rows = Box>> + Send + Sync>; @@ -288,7 +290,7 @@ impl DriverImpl for TrinoDriver { .map(|c| { Ok(Column { name: c.column_name.as_unquoted_str().to_owned(), - ty: DataType::try_from(c.data_type).map_err(Error::other)?, + ty: TrinoDataType::try_from(c.data_type).map_err(Error::other)?, }) }) .collect::>>() @@ -424,6 +426,76 @@ fn should_retry(e: &ClientError) -> bool { // format_err!("Trino error: {}", msg) // } +impl TryFrom<&'_ ValueType> for TrinoDataType { + type Error = Error; + + fn try_from(value_type: &ValueType) -> std::result::Result { + match value_type { + ValueType::Simple(simple_type) => TrinoDataType::try_from(simple_type), + ValueType::Array(simple_type) => Ok(TrinoDataType::Array(Box::new( + TrinoDataType::try_from(simple_type)?, + ))), + } + } +} + +impl TryFrom<&'_ SimpleType> for TrinoDataType { + type Error = Error; + + fn try_from(simple_type: &SimpleType) -> std::result::Result { + match simple_type { + SimpleType::Bool => Ok(TrinoDataType::Boolean), + SimpleType::Bytes => Ok(TrinoDataType::Varbinary), + SimpleType::Date => Ok(TrinoDataType::Date), + SimpleType::Datetime => Ok(TrinoDataType::timestamp()), + SimpleType::Float64 => Ok(TrinoDataType::Double), + SimpleType::Geography => Ok(TrinoDataType::SphericalGeography), + SimpleType::Int64 => Ok(TrinoDataType::BigInt), + SimpleType::Numeric => Ok(TrinoDataType::bigquery_sized_decimal()), + SimpleType::String => Ok(TrinoDataType::varchar()), + SimpleType::Time => Ok(TrinoDataType::time()), + SimpleType::Timestamp => Ok(TrinoDataType::timestamp_with_time_zone()), + SimpleType::Struct(struct_type) => { + let fields = struct_type + .fields + .iter() + .map(|f| TrinoField::try_from(f)) + .collect::>>()?; + Ok(TrinoDataType::Row(fields)) + } + + // These shouldn't make it this far, either because they're not a concrete type, + // or because they're types we only allow as constant arguments to functions that + // should have been transformed away by now. + SimpleType::Bottom | SimpleType::Datepart | SimpleType::Interval | SimpleType::Null => { + Err(format_err!( + "Cannot represent {} as concrete Trino type", + simple_type + )) + } + + // This shouldn't be a constructable value for + // `SimpleType`. + SimpleType::Parameter(_) => unreachable!("parameter types should be resolved"), + } + } +} + +impl TryFrom<&'_ StructElementType> for TrinoField { + type Error = Error; + + fn try_from(field: &StructElementType) -> std::result::Result { + let name = match &field.name { + Some(name) => Some(TrinoIdent::new(&name.name).map_err(Error::other)?), + None => None, + }; + Ok(TrinoField { + name, + data_type: TrinoDataType::try_from(&field.ty)?, + }) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/infer/contains_aggregate.rs b/src/infer/contains_aggregate.rs index ea319e9..65076eb 100644 --- a/src/infer/contains_aggregate.rs +++ b/src/infer/contains_aggregate.rs @@ -91,6 +91,8 @@ impl ContainsAggregate for ast::Expression { // Putting an aggregate here would be very weird. Do not allow it // until forced to do so. ast::Expression::FieldAccess(_) => false, + ast::Expression::Load(load_expr) => load_expr.contains_aggregate(scope), + ast::Expression::Store(store_expr) => store_expr.contains_aggregate(scope), } } } @@ -266,3 +268,15 @@ impl ContainsAggregate for ast::IndexOffset { } } } + +impl ContainsAggregate for ast::LoadExpression { + fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { + self.expression.contains_aggregate(scope) + } +} + +impl ContainsAggregate for ast::StoreExpression { + fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { + self.expression.contains_aggregate(scope) + } +} diff --git a/src/infer/mod.rs b/src/infer/mod.rs index 190dba7..e0a8831 100644 --- a/src/infer/mod.rs +++ b/src/infer/mod.rs @@ -739,6 +739,8 @@ impl InferTypes for ast::Expression { ast::Expression::FunctionCall(fcall) => fcall.infer_types(scope), ast::Expression::Index(index) => index.infer_types(scope), ast::Expression::FieldAccess(field_access) => field_access.infer_types(scope), + ast::Expression::Load(load_expr) => load_expr.infer_types(scope), + ast::Expression::Store(store_expr) => store_expr.infer_types(scope), } } } @@ -1380,6 +1382,26 @@ impl InferTypes for ast::FieldAccessExpression { } } +impl InferTypes for ast::LoadExpression { + type Scope = ColumnSetScope; + type Output = ArgumentType; + + fn infer_types(&mut self, scope: &Self::Scope) -> Result { + // TODO: More here. + self.expression.infer_types(scope) + } +} + +impl InferTypes for ast::StoreExpression { + type Scope = ColumnSetScope; + type Output = ArgumentType; + + fn infer_types(&mut self, scope: &Self::Scope) -> Result { + // TODO: More here. + self.expression.infer_types(scope) + } +} + /// Figure out whether an expression defines an implicit column name. pub trait InferColumnName { /// Infer the column name, if any.