Skip to content

Commit

Permalink
Infer IS TRUE/FALSE and CASE
Browse files Browse the repository at this point in the history
We fix a bug in our IS TRUE/FALSE inference, and rewrite it to use
sql_quote. We also decide to implement both versions of CASE.
  • Loading branch information
emk committed Oct 30, 2023
1 parent d1fe5cd commit cda3186
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 61 deletions.
38 changes: 7 additions & 31 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,43 +823,14 @@ impl Emit for CastType {
}

/// An `IS` expression.
#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)]
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct IsExpression {
pub left: Box<Expression>,
pub is_token: Keyword,
pub not_token: Option<Keyword>,
pub predicate: IsExpressionPredicate,
}

impl Emit for IsExpression {
fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> {
match (&self.predicate, t) {
// BigQuery allows anything.
(_, Target::BigQuery) => self.emit_default(t, f),
// `UNKNOWN` will be translated to `NULL` everywhere else.
(IsExpressionPredicate::Null(_), _) | (IsExpressionPredicate::Unknown(_), _) => {
self.emit_default(t, f)
}
// `TRUE` and `FALSE` work on SQLite3.
(IsExpressionPredicate::True(_), Target::SQLite3)
| (IsExpressionPredicate::False(_), Target::SQLite3) => self.emit_default(t, f),
// For everyone else, we need to use CASE.
(IsExpressionPredicate::True(keyword), _)
| (IsExpressionPredicate::False(keyword), _) => {
f.write_token_start("CASE")?;
self.left.emit(t, f)?;
f.write_token_start("WHEN")?;
keyword.emit(t, f)?;
f.write_token_start("THEN")?;
f.write_token_start("TRUE")?;
f.write_token_start("ELSE")?;
f.write_token_start("FALSE")?;
f.write_token_start("END")
}
}
}
}

/// An `IS` predicate.
#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)]
pub enum IsExpressionPredicate {
Expand Down Expand Up @@ -959,6 +930,7 @@ pub struct IfExpression {
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct CaseExpression {
pub case_token: Keyword,
pub case_expr: Option<Box<Expression>>,
pub when_clauses: Vec<CaseWhenClause>,
pub else_clause: Option<CaseElseClause>,
pub end_token: Keyword,
Expand Down Expand Up @@ -2035,9 +2007,13 @@ peg::parser! {
})
}
--
case_token:k("CASE") when_clauses:(case_when_clause()*) else_clause:case_else_clause()? end_token:k("END") {
case_token:k("CASE")
case_expr:expression()?
when_clauses:(case_when_clause()*)
else_clause:case_else_clause()? end_token:k("END") {
Expression::Case(CaseExpression {
case_token,
case_expr: case_expr.map(Box::new),
when_clauses,
else_clause,
end_token,
Expand Down
1 change: 1 addition & 0 deletions src/drivers/snowflake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ impl Driver for SnowflakeDriver {
Box::new(transforms::CountifToCase),
Box::new(transforms::IfToCase),
Box::new(transforms::IndexFromZero),
Box::new(transforms::IsBoolToCase),
Box::new(transforms::RenameFunctions::new(
&FUNCTION_NAMES,
&UDFS,
Expand Down
1 change: 1 addition & 0 deletions src/drivers/trino/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ impl Driver for TrinoDriver {
Box::new(transforms::CountifToCase),
Box::new(transforms::IndexFromOne),
Box::new(transforms::InUnnestToInSelect),
Box::new(transforms::IsBoolToCase),
Box::new(transforms::OrReplaceToDropIfExists),
Box::new(transforms::RenameFunctions::new(
&FUNCTION_NAMES,
Expand Down
63 changes: 46 additions & 17 deletions src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
//! Our type inference subsystem.
// This is work in progress.
#![allow(dead_code)]

use crate::{
ast,
errors::{Error, Result},
scope::{CaseInsensitiveIdent, Scope, ScopeHandle},
tokenizer::{Ident, Literal, LiteralValue, Spanned},
types::{ArgumentType, ColumnType, SimpleType, TableType, Type, ValueType},
unification::{UnificationTable, Unify},
};

// TODO: Remember this rather scary example. Verify BigQuery supports it
Expand Down Expand Up @@ -377,12 +375,10 @@ impl InferTypes for ast::Expression {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let arg = ArgumentType::Value;
let arg_simple = |ty| arg(ValueType::Simple(ty));
match self {
ast::Expression::Literal(Literal { value, .. }) => value.infer_types(scope),
ast::Expression::BoolValue(_) => Ok((arg_simple(SimpleType::Bool), scope.clone())),
ast::Expression::Null { .. } => Ok((arg_simple(SimpleType::Null), scope.clone())),
ast::Expression::BoolValue(_) => Ok((ArgumentType::bool(), scope.clone())),
ast::Expression::Null { .. } => Ok((ArgumentType::null(), scope.clone())),
ast::Expression::ColumnName(ident) => ident.infer_types(scope),
ast::Expression::TableAndColumnName(name) => name.infer_types(scope),
ast::Expression::Cast(cast) => cast.infer_types(scope),
Expand Down Expand Up @@ -486,14 +482,12 @@ impl InferTypes for ast::IsExpressionPredicate {
// types for the %IS primitive will use this to verify that the left
// argument and predicate are compatible.
match self {
ast::IsExpressionPredicate::Null(_) | ast::IsExpressionPredicate::Unknown(_) => Ok((
ArgumentType::Value(ValueType::Simple(SimpleType::Null)),
scope.clone(),
)),
ast::IsExpressionPredicate::True(_) | ast::IsExpressionPredicate::False(_) => Ok((
ArgumentType::Value(ValueType::Simple(SimpleType::Bool)),
scope.clone(),
)),
ast::IsExpressionPredicate::Null(_) | ast::IsExpressionPredicate::Unknown(_) => {
Ok((ArgumentType::null(), scope.clone()))
}
ast::IsExpressionPredicate::True(_) | ast::IsExpressionPredicate::False(_) => {
Ok((ArgumentType::bool(), scope.clone()))
}
}
}
}
Expand All @@ -520,7 +514,7 @@ impl InferTypes for ast::InExpression {
impl InferTypes for ast::InValueSet {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
fn infer_types(&mut self, _scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
match self {
ast::InValueSet::QueryExpression { .. } => Err(nyi(self, "IN subquery")),
ast::InValueSet::ExpressionList { .. } => Err(nyi(self, "IN expression list")),
Expand Down Expand Up @@ -587,7 +581,42 @@ impl InferTypes for ast::CaseExpression {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
Err(nyi(self, "CASE"))
// CASE is basically two different constructs, depending on whether
// there is a CASE expression or not. We handle them separately.
let mut table = UnificationTable::default();
if let Some(case_expr) = &mut self.case_expr {
let match_tv = table.type_var("M", &self.case_token)?;
match_tv.unify(&case_expr.infer_types(scope)?.0, &mut table, case_expr)?;
let result_tv = table.type_var("R", &self.case_token)?;
for c in &mut self.when_clauses {
match_tv.unify(&c.condition.infer_types(scope)?.0, &mut table, &c.condition)?;
result_tv.unify(&c.result.infer_types(scope)?.0, &mut table, &c.result)?;
}
if let Some(else_clause) = &mut self.else_clause {
let else_expr = &mut else_clause.result;
result_tv.unify(&else_expr.infer_types(scope)?.0, &mut table, else_expr)?;
} else {
result_tv.unify(&ArgumentType::null(), &mut table, self)?;
}
Ok((match_tv.resolve(&table, self)?, scope.clone()))
} else {
let bool_ty = ArgumentType::bool();
let result_tv = table.type_var("R", &self.case_token)?;
for c in &mut self.when_clauses {
c.condition
.infer_types(scope)?
.0
.expect_subtype_of(&bool_ty, c)?;
result_tv.unify(&c.result.infer_types(scope)?.0, &mut table, &c.result)?;
}
if let Some(else_clause) = &mut self.else_clause {
let else_expr = &mut else_clause.result;
result_tv.unify(&else_expr.infer_types(scope)?.0, &mut table, else_expr)?;
} else {
result_tv.unify(&ArgumentType::null(), &mut table, self)?;
}
Ok((result_tv.resolve(&table, self)?, scope.clone()))
}
}
}

Expand Down
8 changes: 0 additions & 8 deletions src/scope.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
//! Namespace for SQL.
// This is work in progress.
#![allow(dead_code)]

use std::{collections::BTreeMap, fmt, hash, sync::Arc};

use crate::{
Expand All @@ -26,11 +23,6 @@ impl CaseInsensitiveIdent {
pub fn new(name: &str, span: Span) -> Self {
Ident::new(name, span).into()
}

/// Get the underlying identifier.
pub fn ident(&self) -> &Ident {
&self.ident
}
}

impl From<Ident> for CaseInsensitiveIdent {
Expand Down
41 changes: 41 additions & 0 deletions src/transforms/is_bool_to_case.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use derive_visitor::{DriveMut, VisitorMut};
use joinery_macros::sql_quote;

use crate::{
ast::{self, Expression, IsExpression, IsExpressionPredicate},
errors::Result,
};

use super::{Transform, TransformExtra};

/// Transform `expr IS [NOT] (TRUE|FALSE)` into a portable
/// `CASE` expression.
#[derive(VisitorMut)]
#[visitor(Expression(enter))]
pub struct IsBoolToCase;

impl IsBoolToCase {
fn enter_expression(&mut self, expr: &mut Expression) {
if let Expression::Is(IsExpression {
left,
not_token,
predicate: IsExpressionPredicate::True(pred) | IsExpressionPredicate::False(pred),
..
}) = expr
{
let replacement = sql_quote! {
CASE #not_token #left WHEN #pred THEN TRUE ELSE FALSE END
}
.try_into_expression()
.expect("generated SQL should always parse");
*expr = replacement;
}
}
}

impl Transform for IsBoolToCase {
fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
}
}
2 changes: 2 additions & 0 deletions src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub use self::{
in_unnest_to_in_select::InUnnestToInSelect,
index_from_one::IndexFromOne,
index_from_zero::IndexFromZero,
is_bool_to_case::IsBoolToCase,
or_replace_to_drop_if_exists::OrReplaceToDropIfExists,
rename_functions::{RenameFunctions, Udf},
standardize_current_date::StandardizeCurrentDate,
Expand All @@ -28,6 +29,7 @@ mod if_to_case;
mod in_unnest_to_in_select;
mod index_from_one;
mod index_from_zero;
mod is_bool_to_case;
mod or_replace_to_drop_if_exists;
mod rename_functions;
mod standardize_current_date;
Expand Down
27 changes: 24 additions & 3 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
//! does not support `ARRAY<ARRAY<T>>`, only `ARRAY<STRUCT<ARRAY<T>>>`, and that
//! `STRUCT` fields have optional names.
// Work in progress.
#![allow(dead_code)]

use std::fmt;

use peg::{error::ParseError, str::LineCol};
Expand Down Expand Up @@ -122,6 +119,7 @@ impl<TV: TypeVarSupport> Type<TV> {
}

/// Convert this type into a [`ValueType`], if possible.
#[allow(dead_code)]
pub fn try_as_value_type(&self, spanned: &dyn Spanned) -> Result<&ValueType<TV>> {
match self {
Type::Argument(ArgumentType::Value(t)) => Ok(t),
Expand Down Expand Up @@ -178,6 +176,16 @@ pub enum ArgumentType<TV: TypeVarSupport = ResolvedTypeVarsOnly> {
}

impl<TV: TypeVarSupport> ArgumentType<TV> {
/// Create a NULL type.
pub fn null() -> Self {
ArgumentType::Value(ValueType::Simple(SimpleType::Null))
}

/// Create a BOOL type.
pub fn bool() -> Self {
ArgumentType::Value(ValueType::Simple(SimpleType::Bool))
}

/// Expect a [`ValueType`].
pub fn expect_value_type(&self, spanned: &dyn Spanned) -> Result<&ValueType<TV>> {
match self {
Expand Down Expand Up @@ -213,6 +221,18 @@ impl<TV: TypeVarSupport> ArgumentType<TV> {
}
}

/// Return an error if we are not a subtype of `other`.
pub fn expect_subtype_of(&self, other: &ArgumentType<TV>, spanned: &dyn Spanned) -> Result<()> {
if !self.is_subtype_of(other) {
return Err(Error::annotated(
format!("expected {}, found {}", other, self),
spanned.span(),
"type mismatch",
));
}
Ok(())
}

/// Find a common supertype of two types. Returns `None` if the only common
/// super type would be top (⊤), which isn't part of our type system.
pub fn common_supertype<'a>(&'a self, other: &'a ArgumentType<TV>) -> Option<ArgumentType<TV>> {
Expand Down Expand Up @@ -319,6 +339,7 @@ impl<TV: TypeVarSupport> ValueType<TV> {
}

/// Return an error if we are not a subtype of `other`.
#[allow(dead_code)]
pub fn expect_subtype_of(&self, other: &ValueType<TV>, spanned: &dyn Spanned) -> Result<()> {
if !self.is_subtype_of(other) {
return Err(Error::annotated(
Expand Down
15 changes: 15 additions & 0 deletions src/unification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ impl UnificationTable {
Ok(())
}

/// Create a new type variable, declare it, and return an `ArgumentType`.
///
/// This is handy when we're forced to implement custom unification logic.
pub fn type_var(
&mut self,
name: impl Into<String>,
spanned: &dyn Spanned,
) -> Result<ArgumentType<TypeVar>> {
let var = TypeVar::new(name)?;
self.declare(var.clone(), spanned)?;
Ok(ArgumentType::Value(ValueType::Simple(
SimpleType::Parameter(var),
)))
}

/// Update a type variable to a new type.
pub fn update(
&mut self,
Expand Down
8 changes: 6 additions & 2 deletions tests/sql/operators/case.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@ SELECT
CASE WHEN FALSE THEN 1 ELSE 2 END AS case_when_false,
CASE WHEN NULL THEN 1 ELSE 2 END AS case_when_null,
CASE WHEN FALSE THEN 1 END AS case_when_false_no_else,
CASE WHEN FALSE THEN 1 WHEN TRUE THEN 2 ELSE 3 END AS case_when_false_true_else;
CASE WHEN FALSE THEN 1 WHEN TRUE THEN 2 ELSE 3 END AS case_when_false_true_else,
CASE 'a' WHEN 'a' THEN 1 ELSE 2 END AS case_string_when,
CASE 'a' WHEN 'b' THEN 1 END AS case_string_when_no_else;

CREATE OR REPLACE TABLE __expected1 (
case_when_true INT64,
case_when_false INT64,
case_when_null INT64,
case_when_false_no_else INT64,
case_when_false_true_else INT64,
case_string_when INT64,
case_string_when_no_else INT64,
);
INSERT INTO __expected1 VALUES
(1, 2, 2, NULL, 2);
(1, 2, 2, NULL, 2, 1, NULL);
Loading

0 comments on commit cda3186

Please sign in to comment.