diff --git a/src/types.rs b/src/types.rs index 1f54e41..366d4d1 100644 --- a/src/types.rs +++ b/src/types.rs @@ -171,8 +171,8 @@ impl fmt::Display for Type { pub enum ArgumentType { /// A value type. Value(ValueType), - /// An aggregating value type. - Aggregating(ValueType), + /// An aggregating value type. Note that we can nest aggregating types. + Aggregating(Box>), } impl ArgumentType { @@ -270,7 +270,7 @@ impl ArgumentType { } (ArgumentType::Aggregating(a), ArgumentType::Aggregating(b)) => { - Some(ArgumentType::Aggregating(a.common_supertype(b)?)) + Some(ArgumentType::Aggregating(Box::new(a.common_supertype(b)?))) } _ => None, } @@ -290,9 +290,9 @@ impl Unify for ArgumentType { (ArgumentType::Value(a), ArgumentType::Value(b)) => { Ok(ArgumentType::Value(a.unify(b, table, spanned)?)) } - (ArgumentType::Aggregating(a), ArgumentType::Aggregating(b)) => { - Ok(ArgumentType::Aggregating(a.unify(b, table, spanned)?)) - } + (ArgumentType::Aggregating(a), ArgumentType::Aggregating(b)) => Ok( + ArgumentType::Aggregating(Box::new(a.unify(b, table, spanned)?)), + ), _ => Err(Error::annotated( format!("cannot unify {} and {}", self, other), spanned.span(), @@ -304,9 +304,9 @@ impl Unify for ArgumentType { fn resolve(&self, table: &UnificationTable, spanned: &dyn Spanned) -> Result { match self { ArgumentType::Value(t) => Ok(ArgumentType::Value(t.resolve(table, spanned)?)), - ArgumentType::Aggregating(t) => { - Ok(ArgumentType::Aggregating(t.resolve(table, spanned)?)) - } + ArgumentType::Aggregating(t) => Ok(ArgumentType::Aggregating(Box::new( + t.resolve(table, spanned)?, + ))), } } } @@ -1172,7 +1172,9 @@ peg::parser! { / t:function_type() { Type::Function(t) } rule argument_type() -> ArgumentType - = "Agg" _? "<" _? t:value_type() _? ">" { ArgumentType::Aggregating(t) } + = "Agg" _? "<" _? t:argument_type() _? ">" { + ArgumentType::Aggregating(Box::new(t)) + } / t:value_type() { ArgumentType::Value(t) } rule value_type() -> ValueType @@ -1341,4 +1343,16 @@ pub mod tests { Type::Argument(ArgumentType::Value(ValueType::Simple(SimpleType::Bool))) ); } + + #[test] + fn parse_nested_agg() { + assert_eq!( + ty("Agg>"), + Type::Argument(ArgumentType::Aggregating(Box::new( + ArgumentType::Aggregating(Box::new(ArgumentType::Value(ValueType::Simple( + SimpleType::Int64 + )))) + ))) + ); + } } diff --git a/tests/sql/functions/aggregate/nested_sum.sql b/tests/sql/functions/aggregate/nested_sum.sql new file mode 100644 index 0000000..63d34fa --- /dev/null +++ b/tests/sql/functions/aggregate/nested_sum.sql @@ -0,0 +1,29 @@ +-- SUM(SUM(x)) is a thing. This affects the design of the type system. + +create temp table t1 ( + g1 STRING, + g2 STRING, + x INT64 +); + +insert into t1 values + ('a', 'x', 1), + ('a', 'y', 2), + ('b', 'x', 3), + ('b', 'y', 4); + +CREATE OR REPLACE TABLE __result1 AS +SELECT g1, g2, SUM(SUM(x)) OVER (PARTITION BY g2) AS `sum` +FROM t1 +GROUP BY g1, g2; + +CREATE OR REPLACE TABLE __expected1 ( + g1 STRING, + g2 STRING, + `sum` INT64 +); +INSERT INTO __expected1 VALUES + ('a', 'x', 4), + ('a', 'y', 6), + ('b', 'x', 4), + ('b', 'y', 6);