diff --git a/src/ast.rs b/src/ast.rs index 68584f1..8d8eefb 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -837,7 +837,8 @@ pub enum Expression { Literal(Literal), BoolValue(Keyword), Null(Keyword), - Name(Name), + Name(NameExpression), + Store(StoreExpression), Cast(Cast), Is(IsExpression), In(InExpression), @@ -867,8 +868,6 @@ pub enum Expression { FunctionCall(FunctionCall), Index(IndexExpression), FieldAccess(FieldAccessExpression), - Load(LoadExpression), - Store(StoreExpression), } impl Expression { @@ -913,6 +912,111 @@ impl DatePart { } } +/// A "load" expression, which transforms an SQL value from a "storage" type (eg +/// "VARCHAR") to a "memory" type (eg "UUID"). Used for databases like Trino, +/// where the storage types for a given connector may be more limited than the +/// standard Trino memory types. +/// +/// These are not found in the original parsed AST, but are added while +/// transforming the AST. +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)] +pub struct NameExpression { + /// **If** we need to do a load conversion, this will be the inferred memory + /// type. + #[emit(skip)] + #[to_tokens(skip)] + #[drive(skip)] + pub load_to_memory_type: Option, + + /// Our underlying expression. + pub name: Name, +} + +impl Emit for NameExpression { + fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> ::std::io::Result<()> { + match t { + // Target::BigQuery => { + // f.write_token_start("%LOAD(")?; + // self.name.emit(t, f)?; + // f.write_token_start(")") + // } + Target::Trino(connector_type) if self.load_to_memory_type.is_some() => { + let bq_memory_type = self + .load_to_memory_type + .as_ref() + .expect("memory_type should have been filled in by type inference"); + let trino_memory_type = + TrinoDataType::try_from(bq_memory_type).map_err(io::Error::other)?; + let transform = connector_type.storage_transform_for(&trino_memory_type); + let (prefix, suffix) = transform.load_prefix_and_suffix(); + + // Wrapping the expression in our prefix and suffix. + // If the expression was col_name containing '[1,2]' in Trino, + // BQ memory type -> JSON, Trino memory type -> JSON, Trino storage type -> VARCHAR + // The Trino storage type is dependent on what the connector can support. + // In this case, the wrapped version would be JSON_PARSE(col_name) + f.write_token_start(&prefix)?; + self.name.emit(t, f)?; + f.write_token_start(&suffix) + } + _ => self.name.emit(t, f), + } + } +} + +/// A "store" expression, which transforms an SQL value from a "memory" type +/// (eg "UUID") to a "storage" type (eg "VARCHAR"). Used for databases like +/// Trino, where the storage types for a given connector may be more limited +/// than the standard Trino memory types. +/// +/// These are not found in the original parsed AST, but are added while +/// transforming the AST. +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)] +pub struct StoreExpression { + /// Inferred memory type. + #[emit(skip)] + #[to_tokens(skip)] + #[drive(skip)] + pub memory_type: Option, + + /// Our underlying expression. + pub expression: Box, +} + +impl Emit for StoreExpression { + fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> ::std::io::Result<()> { + match t { + Target::BigQuery => { + f.write_token_start("%STORE(")?; + self.expression.emit(t, f)?; + f.write_token_start(")") + } + Target::Trino(connector_type) => { + let bq_memory_type = self + .memory_type + .as_ref() + .expect("memory_type should have been filled in by type inference"); + + // If our bq_memory_type is NULL, we don't need to do any transforms because + // NULL is NULL in both storage and memory types and dbcrossbar_trino doesn't + // support NULL as a memory type. + if let ValueType::Simple(SimpleType::Null) = bq_memory_type { + self.expression.emit(t, f) + } else { + let trino_memory_type = + TrinoDataType::try_from(bq_memory_type).map_err(io::Error::other)?; + let transform = connector_type.storage_transform_for(&trino_memory_type); + let (prefix, suffix) = transform.store_prefix_and_suffix(); + + f.write_token_start(&prefix)?; + self.expression.emit(t, f)?; + f.write_token_start(&suffix) + } + } + } + } +} + /// A cast expression. #[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] pub struct Cast { @@ -1645,109 +1749,6 @@ pub struct FieldAccessExpression { pub field_name: Ident, } -/// A "load" expression, which transforms an SQL value from a "storage" type (eg -/// "VARCHAR") to a "memory" type (eg "UUID"). Used for databases like Trino, -/// where the storage types for a given connector may be more limited than the -/// standard Trino memory types. -/// -/// These are not found in the original parsed AST, but are added while -/// transforming the AST. -#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)] -pub struct LoadExpression { - /// Inferred memory type. - #[emit(skip)] - #[to_tokens(skip)] - #[drive(skip)] - pub memory_type: Option, - - /// Our underlying expression. - pub expression: Box, -} - -impl Emit for LoadExpression { - fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> ::std::io::Result<()> { - match t { - Target::BigQuery => { - f.write_token_start("%LOAD(")?; - self.expression.emit(t, f)?; - f.write_token_start(")") - } - Target::Trino(connector_type) => { - let bq_memory_type = self - .memory_type - .as_ref() - .expect("memory_type should have been filled in by type inference"); - let trino_memory_type = - TrinoDataType::try_from(bq_memory_type).map_err(io::Error::other)?; - let transform = connector_type.storage_transform_for(&trino_memory_type); - let (prefix, suffix) = transform.load_prefix_and_suffix(); - - // Wrapping the expression in our prefix and suffix. - // If the expression was col_name containing '[1,2]' in Trino, - // BQ memory type -> JSON, Trino memory type -> JSON, Trino storage type -> VARCHAR - // The Trino storage type is dependent on what the connector can support. - // In this case, the wrapped version would be JSON_PARSE(col_name) - f.write_token_start(&prefix)?; - self.expression.emit(t, f)?; - f.write_token_start(&suffix) - } - } - } -} - -/// A "store" expression, which transforms an SQL value from a "memory" type -/// (eg "UUID") to a "storage" type (eg "VARCHAR"). Used for databases like -/// Trino, where the storage types for a given connector may be more limited -/// than the standard Trino memory types. -/// -/// These are not found in the original parsed AST, but are added while -/// transforming the AST. -#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)] -pub struct StoreExpression { - /// Inferred memory type. - #[emit(skip)] - #[to_tokens(skip)] - #[drive(skip)] - pub memory_type: Option, - - /// Our underlying expression. - pub expression: Box, -} - -impl Emit for StoreExpression { - fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> ::std::io::Result<()> { - match t { - Target::BigQuery => { - f.write_token_start("%STORE(")?; - self.expression.emit(t, f)?; - f.write_token_start(")") - } - Target::Trino(connector_type) => { - let bq_memory_type = self - .memory_type - .as_ref() - .expect("memory_type should have been filled in by type inference"); - - // If our bq_memory_type is NULL, we don't need to do any transforms because - // NULL is NULL in both storage and memory types and dbcrossbar_trino doesn't - // support NULL as a memory type. - if let ValueType::Simple(SimpleType::Null) = bq_memory_type { - self.expression.emit(t, f) - } else { - let trino_memory_type = - TrinoDataType::try_from(bq_memory_type).map_err(io::Error::other)?; - let transform = connector_type.storage_transform_for(&trino_memory_type); - let (prefix, suffix) = transform.store_prefix_and_suffix(); - - f.write_token_start(&prefix)?; - self.expression.emit(t, f)?; - f.write_token_start(&suffix) - } - } - } - } -} - /// An `AS` alias. #[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] pub struct Alias { @@ -2446,7 +2447,7 @@ peg::parser! { // Things from here down might start with arbitrary identifiers, so // we need to be careful about the order. function_call:function_call() { Expression::FunctionCall(function_call) } - column_name:name() { Expression::Name(column_name) } + column_name:name() { Expression::Name(NameExpression { load_to_memory_type: None, name: column_name }) } } rule interval_expression() -> IntervalExpression diff --git a/src/infer/contains_aggregate.rs b/src/infer/contains_aggregate.rs index 65076eb..230c544 100644 --- a/src/infer/contains_aggregate.rs +++ b/src/infer/contains_aggregate.rs @@ -65,6 +65,7 @@ impl ContainsAggregate for ast::Expression { ast::Expression::BoolValue(_) => false, ast::Expression::Null(_) => false, ast::Expression::Name(_) => false, + ast::Expression::Store(store_expr) => store_expr.contains_aggregate(scope), ast::Expression::Cast(cast) => cast.contains_aggregate(scope), ast::Expression::Is(is) => is.contains_aggregate(scope), ast::Expression::In(in_expr) => in_expr.contains_aggregate(scope), @@ -91,8 +92,6 @@ impl ContainsAggregate for ast::Expression { // Putting an aggregate here would be very weird. Do not allow it // until forced to do so. ast::Expression::FieldAccess(_) => false, - ast::Expression::Load(load_expr) => load_expr.contains_aggregate(scope), - ast::Expression::Store(store_expr) => store_expr.contains_aggregate(scope), } } } @@ -269,12 +268,6 @@ impl ContainsAggregate for ast::IndexOffset { } } -impl ContainsAggregate for ast::LoadExpression { - fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { - self.expression.contains_aggregate(scope) - } -} - impl ContainsAggregate for ast::StoreExpression { fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { self.expression.contains_aggregate(scope) diff --git a/src/infer/mod.rs b/src/infer/mod.rs index 7cc9aba..cbd0d15 100644 --- a/src/infer/mod.rs +++ b/src/infer/mod.rs @@ -676,7 +676,7 @@ impl InferTypes for ast::GroupBy { for expr in self.expressions.node_iter_mut() { let _ty = expr.infer_types(scope)?; match expr { - Expression::Name(name) => { + Expression::Name(ast::NameExpression { name, .. }) => { group_by_names.push(name.clone()); } _ => { @@ -719,7 +719,8 @@ impl InferTypes for ast::Expression { ast::Expression::Literal(Literal { value, .. }) => value.infer_types(&()), ast::Expression::BoolValue(_) => Ok(ArgumentType::bool()), ast::Expression::Null { .. } => Ok(ArgumentType::null()), - ast::Expression::Name(name) => name.infer_types(scope), + ast::Expression::Name(name_expr) => name_expr.infer_types(scope), + ast::Expression::Store(store_expr) => store_expr.infer_types(scope), ast::Expression::Cast(cast) => cast.infer_types(scope), ast::Expression::Is(is) => is.infer_types(scope), ast::Expression::In(in_expr) => in_expr.infer_types(scope), @@ -744,8 +745,6 @@ impl InferTypes for ast::Expression { ast::Expression::FunctionCall(fcall) => fcall.infer_types(scope), ast::Expression::Index(index) => index.infer_types(scope), ast::Expression::FieldAccess(field_access) => field_access.infer_types(scope), - ast::Expression::Load(load_expr) => load_expr.infer_types(scope), - ast::Expression::Store(store_expr) => store_expr.infer_types(scope), } } } @@ -774,6 +773,44 @@ impl InferTypes for Ident { } } +impl InferTypes for ast::NameExpression { + type Scope = ColumnSetScope; + type Output = ArgumentType; + + /// `self.name` may be a bare type `?T`, an `Agg` type (possibly + /// nested), or type involving `Stored<..>`, such as `Stored`, + /// `Agg>` or even `Agg>>` to any depth. + /// + /// We have two jobs: + /// + /// 1. Record what type we need to load from, if any. In the examples above, + /// this would be `?T`. + /// 2. Infer the type after the load, which removed `Stored<..>` but keeps + /// `Agg<..>` if present. + /// + /// | Name type | Load to memory type | Inferrred type | + /// |------------------------|---------------------|----------------| + /// | `?T` | `None` | `?T` | + /// | `Agg` | `None` | `Agg` | + /// | `Agg>` | `None` | `Agg>` | + /// | `Stored` | `?T` | `?T` | + /// | `Agg>` | `?T` | `Agg` | + /// | `Agg>>` | `?T` | `Agg>` | + fn infer_types(&mut self, scope: &Self::Scope) -> Result { + let inferred_type = self.name.infer_types(scope)?; + // Do we need to perform a load operation? + if let Some(load_to_memory_type) = inferred_type.load_to_memory_type() { + // Record this for `emit` to use. + self.load_to_memory_type = Some(load_to_memory_type.to_owned()); + // Remove `Stored<..>` from our type. + inferred_type.type_after_load() + } else { + self.load_to_memory_type = None; + Ok(inferred_type) + } + } +} + impl InferTypes for ast::Name { type Scope = ColumnSetScope; type Output = ArgumentType; @@ -823,6 +860,25 @@ impl InferTypes for ast::Name { } } +impl InferTypes for ast::StoreExpression { + type Scope = ColumnSetScope; + type Output = ArgumentType; + + /// `self.expression` should have type `?T`, and we return `Stored`. + /// + /// For example, `?T` might map to UUID, and `Stored` might map to + /// VARCHAR, but that's someone else's problem. We only deal with this in + /// the abstract. + fn infer_types(&mut self, scope: &Self::Scope) -> Result { + let inferred_type = self.expression.infer_types(scope)?; + let value_type = inferred_type.expect_value_type(&self.expression)?; + // Record this for `emit` to use if needed. + self.memory_type = Some(value_type.to_owned()); + // Return the Stored for our original ?T. + Ok(ArgumentType::Stored(value_type.to_owned())) + } +} + impl InferTypes for ast::Cast { type Scope = ColumnSetScope; type Output = ArgumentType; @@ -1346,7 +1402,7 @@ impl InferTypes for ast::PartitionBy { let mut partition_by_names = vec![]; for expr in self.expressions.node_iter_mut() { match expr { - ast::Expression::Name(name) => { + ast::Expression::Name(ast::NameExpression { name, .. }) => { scope.get_argument_type(name)?; partition_by_names.push(name.clone()); } @@ -1387,47 +1443,6 @@ impl InferTypes for ast::FieldAccessExpression { } } -impl InferTypes for ast::LoadExpression { - type Scope = ColumnSetScope; - type Output = ArgumentType; - - /// `self.expression` should have type `Stored`, and we return `?T`. - /// - /// For example, `?T` might map to UUID, and `Stored` might map to - /// VARCHAR, but that's someone else's problem. We only deal with this in - /// the abstract. - fn infer_types(&mut self, scope: &Self::Scope) -> Result { - let inferred_type = self.expression.infer_types(scope)?; - // Nobody should ever call us on any argument that doesn't have type - // `Stored`, because we're the load operation. - let value_type = - inferred_type.expect_stored_type_and_return_value_type(&self.expression)?; - // Record this for `emit` to use if needed. - self.memory_type = Some(value_type.to_owned()); - // Return the `?T` from our original `Stored`. - Ok(ArgumentType::Value(value_type.to_owned())) - } -} - -impl InferTypes for ast::StoreExpression { - type Scope = ColumnSetScope; - type Output = ArgumentType; - - /// `self.expression` should have type `?T`, and we return `Stored`. - /// - /// For example, `?T` might map to UUID, and `Stored` might map to - /// VARCHAR, but that's someone else's problem. We only deal with this in - /// the abstract. - fn infer_types(&mut self, scope: &Self::Scope) -> Result { - let inferred_type = self.expression.infer_types(scope)?; - let value_type = inferred_type.expect_value_type(&self.expression)?; - // Record this for `emit` to use if needed. - self.memory_type = Some(value_type.to_owned()); - // Return the Stored for our original ?T. - Ok(ArgumentType::Stored(value_type.to_owned())) - } -} - /// Figure out whether an expression defines an implicit column name. pub trait InferColumnName { /// Infer the column name, if any. @@ -1446,15 +1461,19 @@ impl InferColumnName for Option { impl InferColumnName for ast::Expression { fn infer_column_name(&mut self) -> Option { match self { - ast::Expression::Name(name) => { - let (_table, col) = name.split_table_and_column(); - Some(col) - } + ast::Expression::Name(name) => name.infer_column_name(), _ => None, } } } +impl InferColumnName for ast::NameExpression { + fn infer_column_name(&mut self) -> Option { + let (_table, col) = self.name.split_table_and_column(); + Some(col) + } +} + impl InferColumnName for ast::Alias { fn infer_column_name(&mut self) -> Option { Some(self.ident.clone()) diff --git a/src/types.rs b/src/types.rs index 1cd90d4..3d1d66e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -220,26 +220,6 @@ impl ArgumentType { } } - /// Expect a [`ArgumentType::Stored`] and return the `ValueType` it contains. - pub fn expect_stored_type_and_return_value_type( - &self, - spanned: &dyn Spanned, - ) -> Result<&ValueType> { - match self { - ArgumentType::Value(_) => Err(Error::annotated( - format!("expected stored type, found in-memory value type {}", self), - spanned.span(), - "type mismatch", - )), - ArgumentType::Stored(t) => Ok(t), - ArgumentType::Aggregating(_) => Err(Error::annotated( - format!("expected stored type, found aggregate type {}", self), - spanned.span(), - "type mismatch", - )), - } - } - /// Expect a [`SimpleType`]. pub fn expect_simple_type(&self, spanned: &dyn Spanned) -> Result<&SimpleType> { match self { @@ -334,6 +314,42 @@ impl ArgumentType { } } +// Methods which only work after we've resolved type variables. +impl ArgumentType { + /// Get the type we should load as, if we are stored. + /// + /// Conceptually, we load before we aggregate, so if you pass + /// `Agg>>`, this will return `Some(T)`. We need to know this + /// type to generate appropriate loading code (which does not care about + /// aggregation). + /// + /// This will normally be followed up by calling [`Self::type_after_load`]. + pub fn load_to_memory_type(&self) -> Option { + match self { + ArgumentType::Value(_) => None, + ArgumentType::Stored(value_type) => Some(value_type.clone()), + ArgumentType::Aggregating(argument_type) => argument_type.load_to_memory_type(), + } + } + + /// The type of this argument after we've loaded it into memory. This will included any + /// aggregations. + /// + /// For example, if we have `Agg>>`, this will return + /// `Agg>`. + /// + /// This is normally called after [`Self::load_to_memory_type`]. + pub fn type_after_load(&self) -> Result { + match self { + ArgumentType::Value(_) => Err(format_err!("cannot load a value type")), + ArgumentType::Stored(value_type) => Ok(ArgumentType::Value(value_type.clone())), + ArgumentType::Aggregating(argument_type) => Ok(ArgumentType::Aggregating(Box::new( + argument_type.type_after_load()?, + ))), + } + } +} + impl Unify for ArgumentType { type Resolved = ArgumentType;