Skip to content

Commit

Permalink
Infer basic window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
emk committed Nov 6, 2023
1 parent 48ead2c commit fa5d5d2
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 18 deletions.
80 changes: 65 additions & 15 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ impl InferTypes for ast::IsExpression {
self.left.infer_types(scope)?,
self.predicate.infer_types(scope)?,
];
func_ty.return_type_for(&arg_types, func_name)
func_ty.return_type_for(&arg_types, false, func_name)
}
}

Expand Down Expand Up @@ -631,7 +631,7 @@ impl InferTypes for ast::InExpression {
let left_ty = self.left.infer_types(scope)?;
let value_set_ty = self.value_set.infer_types(scope)?;
let elem_ty = value_set_ty.expect_one_column(&self.value_set)?.ty.clone();
func_ty.return_type_for(&[left_ty, elem_ty], func_name)
func_ty.return_type_for(&[left_ty, elem_ty], false, func_name)
}
}

Expand Down Expand Up @@ -685,7 +685,7 @@ impl InferTypes for ast::BetweenExpression {
self.middle.as_mut(),
self.right.as_mut(),
];
infer_call(func_name, args, scope)
infer_call(func_name, args, false, scope)
}
}

Expand All @@ -699,7 +699,7 @@ impl InferTypes for ast::KeywordBinopExpression {
self.op_keyword.span(),
);
let args = [self.left.as_mut(), self.right.as_mut()];
infer_call(func_name, args, scope)
infer_call(func_name, args, false, scope)
}
}

Expand All @@ -710,7 +710,7 @@ impl InferTypes for ast::NotExpression {
fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
let func_name = &Name::new("%NOT", self.not_token.span());
let args = [self.expression.as_mut()];
infer_call(func_name, args, scope)
infer_call(func_name, args, false, scope)
}
}

Expand All @@ -724,7 +724,7 @@ impl InferTypes for ast::IfExpression {
self.then_expression.as_mut(),
self.else_expression.as_mut(),
];
infer_call(&Name::new("%IF", self.if_token.span()), args, scope)
infer_call(&Name::new("%IF", self.if_token.span()), args, false, scope)
}
}

Expand Down Expand Up @@ -780,7 +780,12 @@ impl InferTypes for ast::BinopExpression {
&format!("%{}", self.op_token.token.as_str()),
self.op_token.span(),
);
infer_call(&prim_name, [self.left.as_mut(), self.right.as_mut()], scope)
infer_call(
&prim_name,
[self.left.as_mut(), self.right.as_mut()],
false,
scope,
)
}
}

Expand Down Expand Up @@ -819,7 +824,7 @@ impl InferTypes for ast::ArrayDefinition {
// We can use infer_call if we're careful.
let span = exprs.items.span();
let func_name = &Name::new("%ARRAY", span);
let elem_ty = infer_call(func_name, exprs.node_iter_mut(), scope)?;
let elem_ty = infer_call(func_name, exprs.node_iter_mut(), false, scope)?;
let elem_ty = elem_ty.expect_array_type_returning_elem_type(self)?;
let elem_ty = ArgumentType::Value(ValueType::Simple(elem_ty.clone()));
Ok(elem_ty)
Expand All @@ -842,9 +847,10 @@ impl InferTypes for ast::CountExpression {
expression,
..
} => {
// TODO: COUNT OVER
let func_name = &Name::new("COUNT", count_token.span());
let args = [expression.as_mut()];
infer_call(func_name, args, scope)
infer_call(func_name, args, false, scope)
}
}
}
Expand All @@ -855,13 +861,14 @@ impl InferTypes for ast::ArrayAggExpression {
type Output = ArgumentType;

fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
// TODO: ARRAY_AGG OVER
if let Some(order_by) = &mut self.order_by {
// TODO: Should this always return an aggregate type?
order_by.infer_types(scope)?;
}
let func_name = &Name::new("ARRAY_AGG", self.array_agg_token.span());
let args = [self.expression.as_mut()];
infer_call(func_name, args, scope)
infer_call(func_name, args, false, scope)
}
}

Expand All @@ -882,10 +889,52 @@ impl InferTypes for ast::FunctionCall {
type Output = ArgumentType;

fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
if self.over_clause.is_some() {
return Err(nyi(&self.over_clause, "over clause"));
if let Some(over_clause) = self.over_clause.as_mut() {
let new_scope = over_clause.infer_types(scope)?;
infer_call(&self.name, self.args.node_iter_mut(), true, &new_scope)
} else {
infer_call(&self.name, self.args.node_iter_mut(), false, scope)
}
}
}

impl InferTypes for ast::OverClause {
type Scope = ColumnSetScope;
type Output = ColumnSetScope;

fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
if let Some(order_by) = &mut self.order_by {
order_by.infer_types(scope)?;
}
let partition_by_names = if let Some(partition_by) = &mut self.partition_by {
partition_by.infer_types(scope)?
} else {
vec![]
};
scope
.clone()
.try_transform(|column_set| column_set.group_by(&partition_by_names))
}
}

impl InferTypes for ast::PartitionBy {
type Scope = ColumnSetScope;
type Output = Vec<Name>;

fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
let mut partition_by_names = vec![];
for expr in self.expressions.node_iter_mut() {
match expr {
ast::Expression::ColumnName(name) => {
scope.get_argument_type(name)?;
partition_by_names.push(name.clone());
}
_ => {
return Err(nyi(expr, "partition by expression"));
}
}
}
infer_call(&self.name, self.args.node_iter_mut(), scope)
Ok(partition_by_names)
}
}

Expand All @@ -901,7 +950,7 @@ impl InferTypes for ast::IndexExpression {
| ast::IndexOffset::Ordinal { expression, .. } => expression,
};
let args = [self.expression.as_mut(), index_expr];
infer_call(func_name, args, scope)
infer_call(func_name, args, false, scope)
}
}

Expand Down Expand Up @@ -953,6 +1002,7 @@ fn except_set(except: &Option<ast::Except>) -> HashSet<Name> {
fn infer_call<'args, ArgExprs>(
func_name: &Name,
args: ArgExprs,
is_window: bool,
scope: &ColumnSetScope,
) -> Result<ArgumentType>
where
Expand All @@ -965,7 +1015,7 @@ where
arg_types.push(arg.infer_types(scope)?);
}
trace!(name = &func_name.unescaped_bigquery(), args = ?arg_types, "inferring function call");
let ret_ty = func_ty.return_type_for(&arg_types, func_name)?;
let ret_ty = func_ty.return_type_for(&arg_types, is_window, func_name)?;
Ok(ret_ty)
}

Expand Down
10 changes: 9 additions & 1 deletion src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -940,10 +940,11 @@ impl FunctionType {
pub fn return_type_for(
&self,
arg_types: &[ArgumentType],
is_window: bool,
spanned: &dyn Spanned,
) -> Result<ArgumentType> {
for sig in &self.signatures {
if let Some(return_type) = sig.return_type_for(arg_types, spanned)? {
if let Some(return_type) = sig.return_type_for(arg_types, is_window, spanned)? {
return Ok(return_type);
}
}
Expand Down Expand Up @@ -1025,11 +1026,18 @@ impl FunctionSignature {
pub fn return_type_for(
&self,
arg_types: &[ArgumentType],
is_window: bool,
spanned: &dyn Spanned,
) -> Result<Option<ArgumentType>> {
if self.params.len() > arg_types.len() {
return Ok(None);
}

// Window functions can only be matched in a window context.
if is_window != (self.sig_type == FunctionSignatureType::Window) {
return Ok(None);
}

let mut table = UnificationTable::default();
for tv in &self.type_vars {
table.declare(tv.clone(), spanned)?;
Expand Down
4 changes: 2 additions & 2 deletions tests/sql/functions/windows/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

TODO:

- [ ] What would `SUM(x) + y OVER ()` do to the type of `y`? This isn't
valid BigQuery SQL, so maybe we can ignore it.
- [ ] `COUNT(..) OVER (..)`
- [ ] `ARRAY_AGG(..) OVER (..)`

0 comments on commit fa5d5d2

Please sign in to comment.