diff --git a/src/types.rs b/src/types.rs index 1e37ba3..ded46e6 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1042,8 +1042,20 @@ impl FunctionSignature { for tv in &self.type_vars { table.declare(tv.clone(), spanned)?; } + let mut lift_to_aggregate = false; for (i, param_ty) in self.params.iter().enumerate() { + // Try to unify normally. if param_ty.unify(&arg_types[i], &mut table, spanned).is_err() { + // We failed, but let's see if we can lift a scalar function to + // an aggregate function by adjusting the return type. + if let ArgumentType::Aggregating(arg_ty) = &arg_types[i] { + if param_ty.unify(arg_ty.as_ref(), &mut table, spanned).is_ok() { + lift_to_aggregate = true; + continue; + } + } + + // We can't match this parameter, so fail. return Ok(None); } } @@ -1057,9 +1069,12 @@ impl FunctionSignature { } else if self.params.len() < arg_types.len() { return Ok(None); } - self.return_type - .resolve(&table, spanned) - .map(|ty| Some(ArgumentType::Value(ty))) + let return_ty = ArgumentType::Value(self.return_type.resolve(&table, spanned)?); + if lift_to_aggregate { + Ok(Some(ArgumentType::Aggregating(Box::new(return_ty)))) + } else { + Ok(Some(return_ty)) + } } }