From cda3186e9b83ead94a0ab995bee4397ba0978045 Mon Sep 17 00:00:00 2001 From: Eric Kidd Date: Mon, 30 Oct 2023 07:03:24 -0400 Subject: [PATCH] Infer IS TRUE/FALSE and CASE We fix a bug in our IS TRUE/FALSE inference, and rewrite it to use sql_quote. We also decide to implement both versions of CASE. --- src/ast.rs | 38 ++++--------------- src/drivers/snowflake/mod.rs | 1 + src/drivers/trino/mod.rs | 1 + src/infer.rs | 63 ++++++++++++++++++++++--------- src/scope.rs | 8 ---- src/transforms/is_bool_to_case.rs | 41 ++++++++++++++++++++ src/transforms/mod.rs | 2 + src/types.rs | 27 +++++++++++-- src/unification.rs | 15 ++++++++ tests/sql/operators/case.sql | 8 +++- tests/sql/operators/is.sql | 6 +++ 11 files changed, 149 insertions(+), 61 deletions(-) create mode 100644 src/transforms/is_bool_to_case.rs diff --git a/src/ast.rs b/src/ast.rs index 802de51..71290f8 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -823,7 +823,7 @@ impl Emit for CastType { } /// An `IS` expression. -#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)] +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] pub struct IsExpression { pub left: Box, pub is_token: Keyword, @@ -831,35 +831,6 @@ pub struct IsExpression { pub predicate: IsExpressionPredicate, } -impl Emit for IsExpression { - fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> { - match (&self.predicate, t) { - // BigQuery allows anything. - (_, Target::BigQuery) => self.emit_default(t, f), - // `UNKNOWN` will be translated to `NULL` everywhere else. - (IsExpressionPredicate::Null(_), _) | (IsExpressionPredicate::Unknown(_), _) => { - self.emit_default(t, f) - } - // `TRUE` and `FALSE` work on SQLite3. - (IsExpressionPredicate::True(_), Target::SQLite3) - | (IsExpressionPredicate::False(_), Target::SQLite3) => self.emit_default(t, f), - // For everyone else, we need to use CASE. 
- (IsExpressionPredicate::True(keyword), _) - | (IsExpressionPredicate::False(keyword), _) => { - f.write_token_start("CASE")?; - self.left.emit(t, f)?; - f.write_token_start("WHEN")?; - keyword.emit(t, f)?; - f.write_token_start("THEN")?; - f.write_token_start("TRUE")?; - f.write_token_start("ELSE")?; - f.write_token_start("FALSE")?; - f.write_token_start("END") - } - } - } -} - /// An `IS` predicate. #[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)] pub enum IsExpressionPredicate { @@ -959,6 +930,7 @@ pub struct IfExpression { #[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] pub struct CaseExpression { pub case_token: Keyword, + pub case_expr: Option>, pub when_clauses: Vec, pub else_clause: Option, pub end_token: Keyword, @@ -2035,9 +2007,13 @@ peg::parser! { }) } -- - case_token:k("CASE") when_clauses:(case_when_clause()*) else_clause:case_else_clause()? end_token:k("END") { + case_token:k("CASE") + case_expr:expression()? + when_clauses:(case_when_clause()*) + else_clause:case_else_clause()? 
end_token:k("END") { Expression::Case(CaseExpression { case_token, + case_expr: case_expr.map(Box::new), when_clauses, else_clause, end_token, diff --git a/src/drivers/snowflake/mod.rs b/src/drivers/snowflake/mod.rs index b1db0c5..8a8c282 100644 --- a/src/drivers/snowflake/mod.rs +++ b/src/drivers/snowflake/mod.rs @@ -250,6 +250,7 @@ impl Driver for SnowflakeDriver { Box::new(transforms::CountifToCase), Box::new(transforms::IfToCase), Box::new(transforms::IndexFromZero), + Box::new(transforms::IsBoolToCase), Box::new(transforms::RenameFunctions::new( &FUNCTION_NAMES, &UDFS, diff --git a/src/drivers/trino/mod.rs b/src/drivers/trino/mod.rs index 0b74997..b6d7e73 100644 --- a/src/drivers/trino/mod.rs +++ b/src/drivers/trino/mod.rs @@ -175,6 +175,7 @@ impl Driver for TrinoDriver { Box::new(transforms::CountifToCase), Box::new(transforms::IndexFromOne), Box::new(transforms::InUnnestToInSelect), + Box::new(transforms::IsBoolToCase), Box::new(transforms::OrReplaceToDropIfExists), Box::new(transforms::RenameFunctions::new( &FUNCTION_NAMES, diff --git a/src/infer.rs b/src/infer.rs index 4b45965..36036de 100644 --- a/src/infer.rs +++ b/src/infer.rs @@ -1,14 +1,12 @@ //! Our type inference subsystem. -// This is work in progress. -#![allow(dead_code)] - use crate::{ ast, errors::{Error, Result}, scope::{CaseInsensitiveIdent, Scope, ScopeHandle}, tokenizer::{Ident, Literal, LiteralValue, Spanned}, types::{ArgumentType, ColumnType, SimpleType, TableType, Type, ValueType}, + unification::{UnificationTable, Unify}, }; // TODO: Remember this rather scary example. Verify BigQuery supports it @@ -377,12 +375,10 @@ impl InferTypes for ast::Expression { type Type = ArgumentType; fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { - let arg = ArgumentType::Value; - let arg_simple = |ty| arg(ValueType::Simple(ty)); match self { ast::Expression::Literal(Literal { value, .. 
}) => value.infer_types(scope), - ast::Expression::BoolValue(_) => Ok((arg_simple(SimpleType::Bool), scope.clone())), - ast::Expression::Null { .. } => Ok((arg_simple(SimpleType::Null), scope.clone())), + ast::Expression::BoolValue(_) => Ok((ArgumentType::bool(), scope.clone())), + ast::Expression::Null { .. } => Ok((ArgumentType::null(), scope.clone())), ast::Expression::ColumnName(ident) => ident.infer_types(scope), ast::Expression::TableAndColumnName(name) => name.infer_types(scope), ast::Expression::Cast(cast) => cast.infer_types(scope), @@ -486,14 +482,12 @@ impl InferTypes for ast::IsExpressionPredicate { // types for the %IS primitive will use this to verify that the left // argument and predicate are compatible. match self { - ast::IsExpressionPredicate::Null(_) | ast::IsExpressionPredicate::Unknown(_) => Ok(( - ArgumentType::Value(ValueType::Simple(SimpleType::Null)), - scope.clone(), - )), - ast::IsExpressionPredicate::True(_) | ast::IsExpressionPredicate::False(_) => Ok(( - ArgumentType::Value(ValueType::Simple(SimpleType::Bool)), - scope.clone(), - )), + ast::IsExpressionPredicate::Null(_) | ast::IsExpressionPredicate::Unknown(_) => { + Ok((ArgumentType::null(), scope.clone())) + } + ast::IsExpressionPredicate::True(_) | ast::IsExpressionPredicate::False(_) => { + Ok((ArgumentType::bool(), scope.clone())) + } } } } @@ -520,7 +514,7 @@ impl InferTypes for ast::InExpression { impl InferTypes for ast::InValueSet { type Type = ArgumentType; - fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { + fn infer_types(&mut self, _scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { match self { ast::InValueSet::QueryExpression { .. } => Err(nyi(self, "IN subquery")), ast::InValueSet::ExpressionList { .. 
} => Err(nyi(self, "IN expression list")), @@ -587,7 +581,42 @@ impl InferTypes for ast::CaseExpression { type Type = ArgumentType; fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { - Err(nyi(self, "CASE")) + // CASE is basically two different constructs, depending on whether + // there is a CASE expression or not. We handle them separately. + let mut table = UnificationTable::default(); + if let Some(case_expr) = &mut self.case_expr { + let match_tv = table.type_var("M", &self.case_token)?; + match_tv.unify(&case_expr.infer_types(scope)?.0, &mut table, case_expr)?; + let result_tv = table.type_var("R", &self.case_token)?; + for c in &mut self.when_clauses { + match_tv.unify(&c.condition.infer_types(scope)?.0, &mut table, &c.condition)?; + result_tv.unify(&c.result.infer_types(scope)?.0, &mut table, &c.result)?; + } + if let Some(else_clause) = &mut self.else_clause { + let else_expr = &mut else_clause.result; + result_tv.unify(&else_expr.infer_types(scope)?.0, &mut table, else_expr)?; + } else { + result_tv.unify(&ArgumentType::null(), &mut table, self)?; + } + Ok((result_tv.resolve(&table, self)?, scope.clone())) + } else { + let bool_ty = ArgumentType::bool(); + let result_tv = table.type_var("R", &self.case_token)?; + for c in &mut self.when_clauses { + c.condition + .infer_types(scope)? + .0 + .expect_subtype_of(&bool_ty, c)?; + result_tv.unify(&c.result.infer_types(scope)?.0, &mut table, &c.result)?; + } + if let Some(else_clause) = &mut self.else_clause { + let else_expr = &mut else_clause.result; + result_tv.unify(&else_expr.infer_types(scope)?.0, &mut table, else_expr)?; + } else { + result_tv.unify(&ArgumentType::null(), &mut table, self)?; + } + Ok((result_tv.resolve(&table, self)?, scope.clone())) + } } } diff --git a/src/scope.rs b/src/scope.rs index f83628c..c9aed90 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,8 +1,5 @@ //! Namespace for SQL. -// This is work in progress. 
-#![allow(dead_code)] - use std::{collections::BTreeMap, fmt, hash, sync::Arc}; use crate::{ @@ -26,11 +23,6 @@ impl CaseInsensitiveIdent { pub fn new(name: &str, span: Span) -> Self { Ident::new(name, span).into() } - - /// Get the underlying identifier. - pub fn ident(&self) -> &Ident { - &self.ident - } } impl From<Ident> for CaseInsensitiveIdent { diff --git a/src/transforms/is_bool_to_case.rs b/src/transforms/is_bool_to_case.rs new file mode 100644 index 0000000..3453ba2 --- /dev/null +++ b/src/transforms/is_bool_to_case.rs @@ -0,0 +1,41 @@ +use derive_visitor::{DriveMut, VisitorMut}; +use joinery_macros::sql_quote; + +use crate::{ + ast::{self, Expression, IsExpression, IsExpressionPredicate}, + errors::Result, +}; + +use super::{Transform, TransformExtra}; + +/// Transform `expr IS [NOT] (TRUE|FALSE)` into a portable +/// `CASE` expression. +#[derive(VisitorMut)] +#[visitor(Expression(enter))] +pub struct IsBoolToCase; + +impl IsBoolToCase { + fn enter_expression(&mut self, expr: &mut Expression) { + if let Expression::Is(IsExpression { + left, + not_token, + predicate: IsExpressionPredicate::True(pred) | IsExpressionPredicate::False(pred), + .. + }) = expr + { + let replacement = sql_quote! 
{ + CASE #left WHEN #pred THEN #not_token TRUE ELSE #not_token FALSE END + } + .try_into_expression() + .expect("generated SQL should always parse"); + *expr = replacement; + } + } +} + +impl Transform for IsBoolToCase { + fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> { + sql_program.drive_mut(self.as_mut()); + Ok(TransformExtra::default()) + } +} diff --git a/src/transforms/mod.rs b/src/transforms/mod.rs index 429c3e6..7db0ef9 100644 --- a/src/transforms/mod.rs +++ b/src/transforms/mod.rs @@ -15,6 +15,7 @@ pub use self::{ in_unnest_to_in_select::InUnnestToInSelect, index_from_one::IndexFromOne, index_from_zero::IndexFromZero, + is_bool_to_case::IsBoolToCase, or_replace_to_drop_if_exists::OrReplaceToDropIfExists, rename_functions::{RenameFunctions, Udf}, standardize_current_date::StandardizeCurrentDate, @@ -28,6 +29,7 @@ mod if_to_case; mod in_unnest_to_in_select; mod index_from_one; mod index_from_zero; +mod is_bool_to_case; mod or_replace_to_drop_if_exists; mod rename_functions; mod standardize_current_date; diff --git a/src/types.rs b/src/types.rs index aacadca..9de2cef 100644 --- a/src/types.rs +++ b/src/types.rs @@ -12,9 +12,6 @@ //! does not support `ARRAY>`, only `ARRAY>>`, and that //! `STRUCT` fields have optional names. -// Work in progress. -#![allow(dead_code)] - use std::fmt; use peg::{error::ParseError, str::LineCol}; @@ -122,6 +119,7 @@ impl Type { } /// Convert this type into a [`ValueType`], if possible. + #[allow(dead_code)] pub fn try_as_value_type(&self, spanned: &dyn Spanned) -> Result<&ValueType> { match self { Type::Argument(ArgumentType::Value(t)) => Ok(t), @@ -178,6 +176,16 @@ pub enum ArgumentType { } impl ArgumentType { + /// Create a NULL type. + pub fn null() -> Self { + ArgumentType::Value(ValueType::Simple(SimpleType::Null)) + } + + /// Create a BOOL type. + pub fn bool() -> Self { + ArgumentType::Value(ValueType::Simple(SimpleType::Bool)) + } + /// Expect a [`ValueType`]. 
pub fn expect_value_type(&self, spanned: &dyn Spanned) -> Result<&ValueType> { match self { @@ -213,6 +221,18 @@ impl ArgumentType { } } + /// Return an error if we are not a subtype of `other`. + pub fn expect_subtype_of(&self, other: &ArgumentType, spanned: &dyn Spanned) -> Result<()> { + if !self.is_subtype_of(other) { + return Err(Error::annotated( + format!("expected {}, found {}", other, self), + spanned.span(), + "type mismatch", + )); + } + Ok(()) + } + /// Find a common supertype of two types. Returns `None` if the only common /// super type would be top (⊤), which isn't part of our type system. pub fn common_supertype<'a>(&'a self, other: &'a ArgumentType) -> Option> { @@ -319,6 +339,7 @@ impl ValueType { } /// Return an error if we are not a subtype of `other`. + #[allow(dead_code)] pub fn expect_subtype_of(&self, other: &ValueType, spanned: &dyn Spanned) -> Result<()> { if !self.is_subtype_of(other) { return Err(Error::annotated( diff --git a/src/unification.rs b/src/unification.rs index f87d631..dd605b0 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -61,6 +61,21 @@ impl UnificationTable { Ok(()) } + /// Create a new type variable, declare it, and return an `ArgumentType`. + /// + /// This is handy when we're forced to implement custom unification logic. + pub fn type_var( + &mut self, + name: impl Into, + spanned: &dyn Spanned, + ) -> Result> { + let var = TypeVar::new(name)?; + self.declare(var.clone(), spanned)?; + Ok(ArgumentType::Value(ValueType::Simple( + SimpleType::Parameter(var), + ))) + } + /// Update a type variable to a new type. 
pub fn update( &mut self, diff --git a/tests/sql/operators/case.sql b/tests/sql/operators/case.sql index 78b9ad1..cffb51c 100644 --- a/tests/sql/operators/case.sql +++ b/tests/sql/operators/case.sql @@ -5,7 +5,9 @@ SELECT CASE WHEN FALSE THEN 1 ELSE 2 END AS case_when_false, CASE WHEN NULL THEN 1 ELSE 2 END AS case_when_null, CASE WHEN FALSE THEN 1 END AS case_when_false_no_else, - CASE WHEN FALSE THEN 1 WHEN TRUE THEN 2 ELSE 3 END AS case_when_false_true_else; + CASE WHEN FALSE THEN 1 WHEN TRUE THEN 2 ELSE 3 END AS case_when_false_true_else, + CASE 'a' WHEN 'a' THEN 1 ELSE 2 END AS case_string_when, + CASE 'a' WHEN 'b' THEN 1 END AS case_string_when_no_else; CREATE OR REPLACE TABLE __expected1 ( case_when_true INT64, @@ -13,6 +15,8 @@ CREATE OR REPLACE TABLE __expected1 ( case_when_null INT64, case_when_false_no_else INT64, case_when_false_true_else INT64, + case_string_when INT64, + case_string_when_no_else INT64, ); INSERT INTO __expected1 VALUES - (1, 2, 2, NULL, 2); + (1, 2, 2, NULL, 2, 1, NULL); diff --git a/tests/sql/operators/is.sql b/tests/sql/operators/is.sql index bbbc5d9..c1b5613 100644 --- a/tests/sql/operators/is.sql +++ b/tests/sql/operators/is.sql @@ -6,8 +6,10 @@ SELECT NULL IS NULL AS null_is_null, NULL IS NOT NULL AS null_is_not_null, TRUE IS TRUE AS true_is_true, + FALSE IS NOT TRUE AS false_is_not_true, NULL IS TRUE AS null_is_true, FALSE IS FALSE AS false_is_false, + TRUE IS NOT FALSE AS true_is_not_false, NULL IS FALSE AS null_is_false, NULL IS UNKNOWN AS null_is_unknown, NULL IS NOT UNKNOWN AS null_is_not_unknown, @@ -19,8 +21,10 @@ CREATE OR REPLACE TABLE __expected1 ( null_is_null BOOL, null_is_not_null BOOL, true_is_true BOOL, + false_is_not_true BOOL, null_is_true BOOL, false_is_false BOOL, + true_is_not_false BOOL, null_is_false BOOL, null_is_unknown BOOL, null_is_not_unknown BOOL, @@ -32,8 +36,10 @@ INSERT INTO __expected1 VALUES ( TRUE, -- null_is_null FALSE, -- null_is_not_null TRUE, -- true_is_true + TRUE, -- false_is_not_true 
FALSE, -- null_is_true TRUE, -- false_is_false + TRUE, -- true_is_not_false FALSE, -- null_is_false TRUE, -- null_is_unknown FALSE, -- null_is_not_unknown