Skip to content

Commit

Permalink
Add Load and Store scaffolding
Browse files Browse the repository at this point in the history
This code is incomplete, but it should offer basic scaffolding for
Trino storage transformations. We add:

- `Expression::Load` and `Expression::Store`, which don't do anything
   yet.
- Code for converting `ValueType` to `TrinoDataType`.
  • Loading branch information
emk committed Dec 6, 2024
1 parent 625dcbe commit a6e1be5
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 6 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async-trait = "0.1.73"
clap = { version = "4.4.6", features = ["derive", "wrap_help"] }
codespan-reporting = "0.11.1"
csv = "1.2.2"
dbcrossbar_trino = { version = "0.2.3", features = [
dbcrossbar_trino = { version = "0.2.4", features = [
"macros",
"values",
"client",
Expand Down
40 changes: 40 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,8 @@ pub enum Expression {
FunctionCall(FunctionCall),
Index(IndexExpression),
FieldAccess(FieldAccessExpression),
Load(LoadExpression),
Store(StoreExpression),
}

impl Expression {
Expand Down Expand Up @@ -1631,6 +1633,44 @@ pub struct FieldAccessExpression {
pub field_name: Ident,
}

/// A "load" expression, which transforms an SQL value from a "storage" type
/// (e.g. "VARCHAR") to a "memory" type (e.g. "UUID"). Used for databases like
/// Trino, where the storage types for a given connector may be more limited
/// than the standard Trino memory types.
///
/// These are not found in the original parsed AST, but are added while
/// transforming the AST.
///
/// This is the counterpart of `StoreExpression`, which converts in the
/// opposite direction (memory type to storage type).
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct LoadExpression {
/// Inferred memory type.
///
/// This field is private and is marked `skip` for the `Emit`, `ToTokens`
/// and `Drive`/`DriveMut` derives, so only `expression` takes part in those.
///
/// NOTE(review): presumably `None` until type inference records a type;
/// nothing in this definition sets it. Confirm against the inference code.
#[emit(skip)]
#[to_tokens(skip)]
#[drive(skip)]
memory_type: Option<ValueType>,

/// Our underlying expression, i.e. the value being loaded. Boxed because
/// `Expression` itself has a `Load` variant wrapping this struct.
pub expression: Box<Expression>,
}

/// A "store" expression, which transforms an SQL value from a "memory" type
/// (e.g. "UUID") to a "storage" type (e.g. "VARCHAR"). Used for databases like
/// Trino, where the storage types for a given connector may be more limited
/// than the standard Trino memory types.
///
/// These are not found in the original parsed AST, but are added while
/// transforming the AST.
///
/// This is the counterpart of `LoadExpression`, which converts in the
/// opposite direction (storage type to memory type).
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct StoreExpression {
/// Inferred memory type.
///
/// This field is private and is marked `skip` for the `Emit`, `ToTokens`
/// and `Drive`/`DriveMut` derives, so only `expression` takes part in those.
///
/// NOTE(review): presumably `None` until type inference records a type;
/// nothing in this definition sets it. Confirm against the inference code.
#[emit(skip)]
#[to_tokens(skip)]
#[drive(skip)]
memory_type: Option<ValueType>,

/// Our underlying expression, i.e. the value being stored. Boxed because
/// `Expression` itself has a `Store` variant wrapping this struct.
pub expression: Box<Expression>,
}

/// An `AS` alias.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct Alias {
Expand Down
78 changes: 75 additions & 3 deletions src/drivers/trino/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use std::{fmt, str::FromStr, time::Duration};
use async_trait::async_trait;
use dbcrossbar_trino::{
client::{Client, ClientBuilder, ClientError, QueryError},
ConnectorType as TrinoConnectorType, DataType, Ident as TrinoIdent, Value,
ConnectorType as TrinoConnectorType, DataType as TrinoDataType, Field as TrinoField,
Ident as TrinoIdent, Value,
};
use joinery_macros::sql_quote;
use once_cell::sync::Lazy;
Expand All @@ -18,6 +19,7 @@ use crate::{
errors::{format_err, Context, Error, Result},
tokenizer::TokenStream,
transforms::{self, RenameFunctionsBuilder, Transform},
types::{SimpleType, StructElementType, ValueType},
util::AnsiIdent,
};

Expand Down Expand Up @@ -269,7 +271,7 @@ impl Driver for TrinoDriver {

#[async_trait]
impl DriverImpl for TrinoDriver {
type Type = DataType;
type Type = TrinoDataType;
type Value = Value;
type Rows = Box<dyn Iterator<Item = Result<Vec<Self::Value>>> + Send + Sync>;

Expand All @@ -288,7 +290,7 @@ impl DriverImpl for TrinoDriver {
.map(|c| {
Ok(Column {
name: c.column_name.as_unquoted_str().to_owned(),
ty: DataType::try_from(c.data_type).map_err(Error::other)?,
ty: TrinoDataType::try_from(c.data_type).map_err(Error::other)?,
})
})
.collect::<Result<Vec<_>>>()
Expand Down Expand Up @@ -424,6 +426,76 @@ fn should_retry(e: &ClientError) -> bool {
// format_err!("Trino error: {}", msg)
// }

/// Convert a [`ValueType`] into the corresponding [`TrinoDataType`].
///
/// Simple types are converted directly; array types are converted by
/// converting the element type and wrapping it in `TrinoDataType::Array`.
///
/// # Errors
///
/// Propagates any error from converting the underlying simple type.
impl TryFrom<&'_ ValueType> for TrinoDataType {
    type Error = Error;

    fn try_from(value_type: &ValueType) -> std::result::Result<Self, Self::Error> {
        match value_type {
            ValueType::Simple(scalar) => Self::try_from(scalar),
            ValueType::Array(element) => {
                // Convert the element type first so a failure surfaces
                // before we build the array wrapper.
                let element_type = Self::try_from(element)?;
                Ok(Self::Array(Box::new(element_type)))
            }
        }
    }
}

/// Convert a [`SimpleType`] into the corresponding [`TrinoDataType`].
///
/// Each concrete simple type maps to one Trino type. Struct types are
/// converted field by field into a `TrinoDataType::Row`.
///
/// # Errors
///
/// Returns an error for `Bottom`, `Datepart`, `Interval` and `Null`, which
/// have no concrete Trino representation, and propagates any error from
/// converting a struct field.
///
/// # Panics
///
/// Panics on `SimpleType::Parameter`, which should already have been
/// resolved before this conversion is attempted.
impl TryFrom<&'_ SimpleType> for TrinoDataType {
    type Error = Error;

    fn try_from(simple_type: &SimpleType) -> std::result::Result<Self, Self::Error> {
        match simple_type {
            SimpleType::Bool => Ok(TrinoDataType::Boolean),
            SimpleType::Bytes => Ok(TrinoDataType::Varbinary),
            SimpleType::Date => Ok(TrinoDataType::Date),
            SimpleType::Datetime => Ok(TrinoDataType::timestamp()),
            SimpleType::Float64 => Ok(TrinoDataType::Double),
            SimpleType::Geography => Ok(TrinoDataType::SphericalGeography),
            SimpleType::Int64 => Ok(TrinoDataType::BigInt),
            SimpleType::Numeric => Ok(TrinoDataType::bigquery_sized_decimal()),
            SimpleType::String => Ok(TrinoDataType::varchar()),
            SimpleType::Time => Ok(TrinoDataType::time()),
            SimpleType::Timestamp => Ok(TrinoDataType::timestamp_with_time_zone()),
            SimpleType::Struct(struct_type) => {
                // Collecting into `Result` stops at the first field that
                // cannot be converted.
                let fields = struct_type
                    .fields
                    .iter()
                    .map(TrinoField::try_from)
                    .collect::<Result<Vec<_>>>()?;
                Ok(TrinoDataType::Row(fields))
            }

            // These shouldn't make it this far, either because they're not a
            // concrete type, or because they're types we only allow as
            // constant arguments to functions that should have been
            // transformed away by now.
            SimpleType::Bottom
            | SimpleType::Datepart
            | SimpleType::Interval
            | SimpleType::Null => Err(format_err!(
                "Cannot represent {} as concrete Trino type",
                simple_type
            )),

            // This shouldn't be a constructible value for
            // `SimpleType<ResolvedTypeVarsOnly>`.
            SimpleType::Parameter(_) => unreachable!("parameter types should be resolved"),
        }
    }
}

impl TryFrom<&'_ StructElementType> for TrinoField {
type Error = Error;

fn try_from(field: &StructElementType) -> std::result::Result<Self, Self::Error> {
let name = match &field.name {
Some(name) => Some(TrinoIdent::new(&name.name).map_err(Error::other)?),
None => None,
};
Ok(TrinoField {
name,
data_type: TrinoDataType::try_from(&field.ty)?,
})
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
14 changes: 14 additions & 0 deletions src/infer/contains_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ impl ContainsAggregate for ast::Expression {
// Putting an aggregate here would be very weird. Do not allow it
// until forced to do so.
ast::Expression::FieldAccess(_) => false,
ast::Expression::Load(load_expr) => load_expr.contains_aggregate(scope),
ast::Expression::Store(store_expr) => store_expr.contains_aggregate(scope),
}
}
}
Expand Down Expand Up @@ -266,3 +268,15 @@ impl ContainsAggregate for ast::IndexOffset {
}
}
}

/// A load expression contains an aggregate exactly when the expression it
/// wraps does.
impl ContainsAggregate for ast::LoadExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let ast::LoadExpression { expression, .. } = self;
        expression.contains_aggregate(scope)
    }
}

/// A store expression contains an aggregate exactly when the expression it
/// wraps does.
impl ContainsAggregate for ast::StoreExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let ast::StoreExpression { expression, .. } = self;
        expression.contains_aggregate(scope)
    }
}
22 changes: 22 additions & 0 deletions src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,8 @@ impl InferTypes for ast::Expression {
ast::Expression::FunctionCall(fcall) => fcall.infer_types(scope),
ast::Expression::Index(index) => index.infer_types(scope),
ast::Expression::FieldAccess(field_access) => field_access.infer_types(scope),
ast::Expression::Load(load_expr) => load_expr.infer_types(scope),
ast::Expression::Store(store_expr) => store_expr.infer_types(scope),
}
}
}
Expand Down Expand Up @@ -1380,6 +1382,26 @@ impl InferTypes for ast::FieldAccessExpression {
}
}

/// Type inference for load expressions.
///
/// Currently a pass-through: the result is whatever the wrapped expression
/// infers, and no memory type is recorded on the load expression here.
impl InferTypes for ast::LoadExpression {
    type Scope = ColumnSetScope;
    type Output = ArgumentType;

    fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
        // TODO: More here.
        let ast::LoadExpression { expression, .. } = self;
        expression.infer_types(scope)
    }
}

/// Type inference for store expressions.
///
/// Currently a pass-through: the result is whatever the wrapped expression
/// infers, and no memory type is recorded on the store expression here.
impl InferTypes for ast::StoreExpression {
    type Scope = ColumnSetScope;
    type Output = ArgumentType;

    fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
        // TODO: More here.
        let ast::StoreExpression { expression, .. } = self;
        expression.infer_types(scope)
    }
}

/// Figure out whether an expression defines an implicit column name.
pub trait InferColumnName {
/// Infer the column name, if any.
Expand Down

0 comments on commit a6e1be5

Please sign in to comment.