From df88ee8a6ff94e0167c7c650452b74dd08f8316c Mon Sep 17 00:00:00 2001
From: Eric Kidd
Date: Thu, 2 Nov 2023 14:34:01 -0400
Subject: [PATCH] Infer JOIN types

Co-authored-by: Dave Shirley
---
NOTE(review): this patch was re-flowed from a capture in which all line
breaks and indentation had been collapsed and all angle-bracketed text had
been stripped. Blank context lines were placed so that every hunk matches
its "@@" line counts and the diffstat below. Please confirm the following
against the upstream commit before applying:
  * Generic arguments were restored as `Result<Self::Output>` and
    `collect::<Vec<_>>()`; the stripped originals are assumed, not known.
  * Indentation (rustfmt defaults for Rust, 4 spaces for SQL) is assumed.
  * Author e-mail addresses were stripped and are NOT restored here, so
    `git am` may reject the header; `git apply` does not need them.
NOTE(review): in the new `JoinOperation::infer_types`, the `ON` expression
is inferred against `scope` (the scope accumulated from the left-hand side)
rather than against the joined column set. It looks like columns from the
right-hand FROM item would not be visible to the `ON` condition - confirm
whether that is intended.

 src/ast.rs                                   |  2 +-
 src/infer.rs                                 | 47 +++++++++++++++-----
 tests/sql/functions/aggregate/nested_sum.sql |  6 +--
 tests/sql/queries/joins/cross.sql            | 14 +++---
 4 files changed, 47 insertions(+), 22 deletions(-)

diff --git a/src/ast.rs b/src/ast.rs
index 4dab572..b7ffc35 100644
--- a/src/ast.rs
+++ b/src/ast.rs
@@ -1564,7 +1564,7 @@ pub enum JoinOperation {
 impl JoinOperation {
     /// The FROM item on the right-hand side of this join.
     #[allow(clippy::wrong_self_convention)]
-    pub fn from_item(&self) -> &FromItem {
+    pub fn from_item_mut(&mut self) -> &mut FromItem {
         match self {
             JoinOperation::ConditionJoin { from_item, .. } => from_item,
             JoinOperation::CrossJoin { from_item, .. } => from_item,
diff --git a/src/infer.rs b/src/infer.rs
index 3d58c49..ef3a964 100644
--- a/src/infer.rs
+++ b/src/infer.rs
@@ -3,7 +3,7 @@
 use std::collections::HashSet;
 
 use crate::{
-    ast::{self, Name},
+    ast::{self, ConditionJoinOperator, Name},
     errors::{Error, Result},
     scope::{ColumnSet, ColumnSetScope, Scope, ScopeGet, ScopeHandle},
     tokenizer::{Ident, Literal, LiteralValue, Spanned},
@@ -33,7 +33,7 @@ use crate::{
 
 /// Types which support inference.
 pub trait InferTypes {
-    type Scope: ScopeGet;
+    type Scope;
 
     /// The result of type inference. For expressions, this will be a type.
     /// For top-level statements, this will be a new scope.
@@ -362,15 +362,10 @@ impl InferTypes for ast::FromClause {
     type Scope = ScopeHandle;
     type Output = ColumnSetScope;
 
-    fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
-        let ast::FromClause {
-            from_item,
-            join_operations,
-            ..
-        } = self;
-        let scope = from_item.infer_types(scope)?;
-        if !join_operations.is_empty() {
-            return Err(nyi(self, "join operations"));
+    fn infer_types(&mut self, outer_scope: &Self::Scope) -> Result<Self::Output> {
+        let mut scope = self.from_item.infer_types(outer_scope)?;
+        for op in &mut self.join_operations {
+            scope = op.infer_types(&(outer_scope.clone(), scope))?;
         }
         Ok(scope)
     }
@@ -398,6 +393,36 @@ impl InferTypes for ast::FromItem {
     }
 }
 
+impl InferTypes for ast::JoinOperation {
+    type Scope = (ScopeHandle, ColumnSetScope);
+    type Output = ColumnSetScope;
+
+    fn infer_types(&mut self, (outer_scope, scope): &Self::Scope) -> Result<Self::Output> {
+        let from_type = self.from_item_mut().infer_types(outer_scope)?;
+        scope.clone().try_transform(|column_set| match self {
+            ast::JoinOperation::ConditionJoin {
+                operator: Some(ConditionJoinOperator::Using { column_names, .. }, ..),
+                ..
+            } => {
+                let column_names = column_names
+                    .node_iter()
+                    .map(|ident| ident.clone().into())
+                    .collect::<Vec<_>>();
+                column_set.join_using(from_type.column_set(), &column_names)
+            }
+            ast::JoinOperation::ConditionJoin {
+                operator: Some(ConditionJoinOperator::On { expression, .. }, ..),
+                ..
+            } => {
+                let expr_ty = expression.infer_types(scope)?;
+                expr_ty.expect_subtype_of(&ArgumentType::bool(), expression)?;
+                Ok(column_set.join(from_type.column_set()))
+            }
+            _ => Ok(column_set.join(from_type.column_set())),
+        })
+    }
+}
+
 impl InferTypes for ast::Expression {
     type Scope = ColumnSetScope;
     type Output = ArgumentType;
diff --git a/tests/sql/functions/aggregate/nested_sum.sql b/tests/sql/functions/aggregate/nested_sum.sql
index 63d34fa..90480a5 100644
--- a/tests/sql/functions/aggregate/nested_sum.sql
+++ b/tests/sql/functions/aggregate/nested_sum.sql
@@ -1,12 +1,12 @@
 -- SUM(SUM(x)) is a thing. This affects the design of the type system.
 
-create temp table t1 (
+create temp table nested_sum_t1 (
     g1 STRING,
     g2 STRING,
     x INT64
 );
 
-insert into t1 values
+insert into nested_sum_t1 values
     ('a', 'x', 1),
     ('a', 'y', 2),
     ('b', 'x', 3),
@@ -14,7 +14,7 @@ insert into t1 values
 
 CREATE OR REPLACE TABLE __result1 AS
 SELECT g1, g2, SUM(SUM(x)) OVER (PARTITION BY g2) AS `sum`
-FROM t1
+FROM nested_sum_t1
 GROUP BY g1, g2;
 
 CREATE OR REPLACE TABLE __expected1 (
diff --git a/tests/sql/queries/joins/cross.sql b/tests/sql/queries/joins/cross.sql
index 47d26e8..4e47607 100644
--- a/tests/sql/queries/joins/cross.sql
+++ b/tests/sql/queries/joins/cross.sql
@@ -1,15 +1,15 @@
 -- CROSS JOIN
 
-CREATE TEMP TABLE t1 (a INT64);
-INSERT INTO t1 VALUES (1), (2);
+CREATE TEMP TABLE cj_t1 (a INT64);
+INSERT INTO cj_t1 VALUES (1), (2);
 
-CREATE TEMP TABLE t2 (b INT64);
-INSERT INTO t2 VALUES (3), (4);
+CREATE TEMP TABLE cj_t2 (b INT64);
+INSERT INTO cj_t2 VALUES (3), (4);
 
 CREATE OR REPLACE TABLE __result1 AS
-SELECT t1.a, t2.b
-FROM t1
-CROSS JOIN t2;
+SELECT cj_t1.a, cj_t2.b
+FROM cj_t1
+CROSS JOIN cj_t2;
 
 CREATE OR REPLACE TABLE __expected1 (
     a INT64,