Skip to content

Commit

Permalink
Infer IN
Browse files Browse the repository at this point in the history
  • Loading branch information
emk committed Oct 30, 2023
1 parent cda3186 commit 3bde535
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 14 deletions.
53 changes: 43 additions & 10 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,23 +502,56 @@ impl InferTypes for ast::InExpression {
let func_ty = scope
.get_or_err(func_name)?
.try_as_function_type(func_name)?;
let arg_types = [
self.left.infer_types(scope)?.0,
self.value_set.infer_types(scope)?.0,
];
let ret_ty = func_ty.return_type_for(&arg_types, func_name)?;
let left_ty = self.left.infer_types(scope)?.0;
let value_set_ty = self.value_set.infer_types(scope)?.0;
if value_set_ty.columns.len() != 1 {
return Err(Error::annotated(
format!(
"expected value set to have one column, found {}",
value_set_ty.columns.len()
),
self.value_set.span(),
"wrong number of columns",
));
}
let elem_ty = value_set_ty.columns[0].ty.clone();
let ret_ty = func_ty.return_type_for(&[left_ty, elem_ty], func_name)?;
Ok((ret_ty, scope.clone()))
}
}

// NOTE(review): this span comes from a rendered diff, so removed and added
// lines appear side by side (two `type Type` items, two `fn infer_types`
// headers, and the `nyi` arms next to their replacements). Confirm the final
// text against the repository before compiling.
//
// Type inference for the value set on the right-hand side of an `IN`
// expression.
impl InferTypes for ast::InValueSet {
type Type = ArgumentType;
type Type = TableType;

/// Infer the type of this value set.
///
/// In the `TableType`-returning arms below, an expression list yields a
/// single unnamed column, an `UNNEST` yields whatever table the array type
/// unnests to, and a subquery yields whatever the query itself infers.
fn infer_types(&mut self, _scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
match self {
ast::InValueSet::QueryExpression { .. } => Err(nyi(self, "IN subquery")),
ast::InValueSet::ExpressionList { .. } => Err(nyi(self, "IN expression list")),
ast::InValueSet::Unnest { .. } => Err(nyi(self, "IN unnest")),
// Subquery: delegate entirely to the query's own inference, including
// the scope it returns.
ast::InValueSet::QueryExpression { query, .. } => query.infer_types(scope),
ast::InValueSet::ExpressionList {
paren1,
expressions,
..
} => {
// Create a 1-column table type.
//
// A fresh type variable `T` stands for the column type; it is unified
// with the inferred type of every listed expression in turn, so an
// incompatible element surfaces as an error from `unify`.
let mut table = UnificationTable::default();
let col_ty = table.type_var("T", paren1)?;
for e in expressions.node_iter_mut() {
col_ty.unify(&e.infer_types(scope)?.0, &mut table, e)?;
}
let table_type = TableType {
columns: vec![ColumnType {
// The single column is anonymous and is not flagged `not_null`.
name: None,
ty: col_ty.resolve(&table, self)?,
not_null: false,
}],
};
Ok((table_type, scope.clone()))
}
ast::InValueSet::Unnest { expression, .. } => {
// Infer the operand, require it to be an array, then convert that
// array type into a table type via `unnest`.
let array_ty = expression.infer_types(scope)?.0;
let array_ty = array_ty.expect_array_type(expression)?;
let table_ty = array_ty.unnest(expression)?;
Ok((table_ty, scope.clone()))
}
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ static BUILT_IN_FUNCTIONS: &str = "
%AND = Fn(BOOL, BOOL) -> BOOL;
%BETWEEN = Fn<?T>(?T, ?T, ?T) -> BOOL;
%IF = Fn<?T>(BOOL, ?T, ?T) -> ?T;
%IN = Fn<?T>(?T, ARRAY<?T>) -> BOOL;
-- Second argument to IN is actually TABLE<?T>, but we just do the lookup using
-- the column type.
%IN = Fn<?T>(?T, ?T) -> BOOL;
%IS = Fn<?T>(?T, NULL) -> BOOL | Fn(BOOL, BOOL) -> BOOL;
%NOT = Fn(BOOL) -> BOOL;
%OR = Fn(BOOL, BOOL) -> BOOL;
Expand Down
49 changes: 49 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,18 @@ impl<TV: TypeVarSupport> ArgumentType<TV> {
}
}

/// Expect an [`ArrayType`].
pub fn expect_array_type(&self, spanned: &dyn Spanned) -> Result<&ValueType<TV>> {
match self {
ArgumentType::Value(t @ ValueType::Array(_)) => Ok(t),
_ => Err(Error::annotated(
format!("expected array type, found {}", self),
spanned.span(),
"type mismatch",
)),
}
}

/// Is this a subtype of `other`?
pub fn is_subtype_of(&self, other: &ArgumentType<TV>) -> bool {
// Value types can't be subtypes of aggregating types or vice versa,
Expand Down Expand Up @@ -380,6 +392,43 @@ impl<TV: TypeVarSupport> ValueType<TV> {
}
}

impl ValueType<ResolvedTypeVarsOnly> {
/// Unnest an array type into a table type, according to [Google's rules][unnest].
///
/// [unnest]: https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#unnest_operator
pub fn unnest(&self, spanned: &dyn Spanned) -> Result<TableType> {
match self {
// Structs unnest to tables with the same columns.
//
// TODO: JSON, too, if we ever support it.
ValueType::Array(SimpleType::Struct(s)) => Ok(TableType {
columns: s
.fields
.iter()
.map(|field| ColumnType {
name: field.name.clone(),
ty: ArgumentType::Value(field.ty.clone()),
not_null: false,
})
.collect(),
}),
// Other types unnest to tables with a single anonymous column.
ValueType::Array(elem_ty) => Ok(TableType {
columns: vec![ColumnType {
name: None,
ty: ArgumentType::Value(ValueType::Simple(elem_ty.clone())),
not_null: false,
}],
}),
_ => Err(Error::annotated(
"cannot unnest a non-array",
spanned.span(),
"type mismatch",
)),
}
}
}

impl Unify for ValueType<TypeVar> {
type Resolved = ValueType<ResolvedTypeVarsOnly>;

Expand Down
17 changes: 14 additions & 3 deletions src/unification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,23 @@ impl UnificationTable {
}
}

/// Interface for values supporting unification.
/// Interface for types supporting unification.
pub trait Unify: Sized {
/// The type after unification.
/// The type after unification. This is the concrete type (with no type
/// variables) corresponding to `Self`.
type Resolved;

/// Unify two values.
/// Unify two types, updating any type variables in `self` to be a type
/// consistent with `other`. This may involve binding a pattern variable
/// `?T` to a concrete type, or "loosening" an existing binding like `?T:
/// INT64` to `?T: FLOAT64` so that it can hold all the types we've seen.
///
/// If a type variable is already bound to a type like `?T: INT64`, and
/// we're asked to unify it with an incompatible type `STRING`, we'll return
/// an error.
///
/// This is what allows us to deduce that `ARRAY[1, 2.0, NULL]` is an
/// `ARRAY<FLOAT64>`.
fn unify(
&self,
other: &Self::Resolved,
Expand Down
15 changes: 15 additions & 0 deletions tests/sql/queries/from/unnest_struct.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- pending: snowflake Shouldn't be hard
-- pending: sqlite3 No array support
-- pending: trino Needs investigation
--
-- FROM UNNEST

-- Materialize the rows produced by unnesting a one-element array of structs.
CREATE OR REPLACE TABLE __result1 AS
SELECT * FROM UNNEST([STRUCT<a INT64, b INT64>(1, 2)]);

-- Expected output: one column per struct field (`a`, `b`) and a single row.
CREATE OR REPLACE TABLE __expected1 (
a INT64,
b INT64,
);
INSERT INTO __expected1 VALUES
(1, 2);

0 comments on commit 3bde535

Please sign in to comment.