Skip to content

Commit

Permalink
Infer more parts of SELECT
Browse files Browse the repository at this point in the history
  • Loading branch information
emk committed Oct 30, 2023
1 parent d0d6fdf commit ac2337b
Showing 1 changed file with 66 additions and 11 deletions.
77 changes: 66 additions & 11 deletions src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Our type inference subsystem.
use std::collections::HashSet;

use crate::{
ast,
errors::{Error, Result},
Expand Down Expand Up @@ -297,11 +299,34 @@ impl InferTypes for ast::SelectExpression {
..
} = self;

// See if we have a FROM clause.
let mut from_type = None;
let mut scope = scope.to_owned();
if let Some(from_clause) = from_clause {
((), scope) = from_clause.infer_types(&scope)?;
let (new_from_type, new_scope) = from_clause.infer_types(&scope)?;
from_type = Some(new_from_type);
scope = new_scope;
}

// Helper function to add columns from a table type to a list of columns.
let add_table_cols =
|cols: &mut Vec<_>, table_type: &TableType, except: &Option<ast::Except>| {
let except = except_set(except);
for column in &table_type.columns {
if let Some(column_name) = &column.name {
let column_name = CaseInsensitiveIdent::from(column_name.clone());
if !except.contains(&column_name) {
cols.push(ColumnType {
name: column.name.clone(),
ty: column.ty.to_owned(),
not_null: false,
});
}
}
}
};

// Iterate over the select list, adding columns to the scope.
let mut cols = vec![];
for item in select_list.node_iter_mut() {
match item {
Expand All @@ -320,7 +345,24 @@ impl InferTypes for ast::SelectExpression {
not_null: false,
});
}
_ => return Err(nyi(item, "select list item")),
ast::SelectListItem::Wildcard { star, except } => {
if let Some(from_type) = &from_type {
add_table_cols(&mut cols, from_type, except);
} else {
return Err(Error::annotated(
"cannot use * in SELECT without a FROM clause",
star.span(),
"no FROM clause",
));
}
}
ast::SelectListItem::TableNameWildcard {
table_name, except, ..
} => {
let table = ident_from_table_name(table_name)?;
let table_type = scope.get_or_err(&table)?.try_as_table_type(&table)?;
add_table_cols(&mut cols, table_type, except);
}
}
}
let table_type = TableType { columns: cols };
Expand All @@ -329,36 +371,38 @@ impl InferTypes for ast::SelectExpression {
}

impl InferTypes for ast::FromClause {
type Type = ();
type Type = TableType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let ast::FromClause {
from_item,
join_operations,
..
} = self;
let ((), scope) = from_item.infer_types(scope)?;
let (table_type, scope) = from_item.infer_types(scope)?;
if !join_operations.is_empty() {
return Err(nyi(self, "join operations"));
}
Ok(((), scope))
Ok((table_type, scope))
}
}

impl InferTypes for ast::FromItem {
type Type = ();
/// We return a table type for use by `SELECT *`.
type Type = TableType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
match self {
ast::FromItem::TableName { table_name, alias } => {
let table = ident_from_table_name(table_name)?;
let table_type = scope.get_or_err(&table)?.try_as_table_type(&table)?;

if alias.is_some() {
return Err(nyi(alias, "from with alias"));
}
let name = match alias {
Some(alias) => CaseInsensitiveIdent::from(alias.ident.clone()),
None => table,
};

let mut scope = Scope::new(scope);
scope.add(name, Type::Table(table_type.clone()))?;
for column in &table_type.columns {
if let Some(column_name) = &column.name {
scope.add(
Expand All @@ -367,7 +411,7 @@ impl InferTypes for ast::FromItem {
)?;
}
}
Ok(((), scope.into_handle()))
Ok((table_type.clone(), scope.into_handle()))
}
ast::FromItem::Subquery { .. } => Err(nyi(self, "from subquery")),
ast::FromItem::Unnest { .. } => Err(nyi(self, "from unnest")),
Expand Down Expand Up @@ -782,6 +826,17 @@ fn ident_from_function_name(function_name: &ast::FunctionName) -> Result<CaseIns
}
}

/// Build a set from an optional [`ast::Except`] clause.
fn except_set(except: &Option<ast::Except>) -> HashSet<CaseInsensitiveIdent> {
let mut set = HashSet::new();
if let Some(except) = except {
for ident in except.columns.node_iter() {
set.insert(ident.clone().into());
}
}
set
}

/// Infer types a function-like expression (including primitives).
fn infer_call<'args, ArgExprs>(
func_name: &CaseInsensitiveIdent,
Expand Down

0 comments on commit ac2337b

Please sign in to comment.