Skip to content

Commit

Permalink
Infer JOIN types
Browse files Browse the repository at this point in the history
Co-authored-by: Dave Shirley <[email protected]>
  • Loading branch information
emk and dave-shirley-faraday committed Nov 2, 2023
1 parent 3ff6ce7 commit df88ee8
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1564,7 +1564,7 @@ pub enum JoinOperation {
impl JoinOperation {
/// The FROM item on the right-hand side of this join.
#[allow(clippy::wrong_self_convention)]
pub fn from_item(&self) -> &FromItem {
pub fn from_item_mut(&mut self) -> &mut FromItem {
match self {
JoinOperation::ConditionJoin { from_item, .. } => from_item,
JoinOperation::CrossJoin { from_item, .. } => from_item,
Expand Down
47 changes: 36 additions & 11 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use std::collections::HashSet;

use crate::{
ast::{self, Name},
ast::{self, ConditionJoinOperator, Name},
errors::{Error, Result},
scope::{ColumnSet, ColumnSetScope, Scope, ScopeGet, ScopeHandle},
tokenizer::{Ident, Literal, LiteralValue, Spanned},
Expand Down Expand Up @@ -33,7 +33,7 @@ use crate::{

/// Types which support inference.
pub trait InferTypes {
type Scope: ScopeGet;
type Scope;

/// The result of type inference. For expressions, this will be a type.
/// For top-level statements, this will be a new scope.
Expand Down Expand Up @@ -362,15 +362,10 @@ impl InferTypes for ast::FromClause {
type Scope = ScopeHandle;
type Output = ColumnSetScope;

fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
let ast::FromClause {
from_item,
join_operations,
..
} = self;
let scope = from_item.infer_types(scope)?;
if !join_operations.is_empty() {
return Err(nyi(self, "join operations"));
fn infer_types(&mut self, outer_scope: &Self::Scope) -> Result<Self::Output> {
let mut scope = self.from_item.infer_types(outer_scope)?;
for op in &mut self.join_operations {
scope = op.infer_types(&(outer_scope.clone(), scope))?;
}
Ok(scope)
}
Expand Down Expand Up @@ -398,6 +393,36 @@ impl InferTypes for ast::FromItem {
}
}

impl InferTypes for ast::JoinOperation {
type Scope = (ScopeHandle, ColumnSetScope);
type Output = ColumnSetScope;

fn infer_types(&mut self, (outer_scope, scope): &Self::Scope) -> Result<Self::Output> {
let from_type = self.from_item_mut().infer_types(outer_scope)?;
scope.clone().try_transform(|column_set| match self {
ast::JoinOperation::ConditionJoin {
operator: Some(ConditionJoinOperator::Using { column_names, .. }, ..),
..
} => {
let column_names = column_names
.node_iter()
.map(|ident| ident.clone().into())
.collect::<Vec<_>>();
column_set.join_using(from_type.column_set(), &column_names)
}
ast::JoinOperation::ConditionJoin {
operator: Some(ConditionJoinOperator::On { expression, .. }, ..),
..
} => {
let expr_ty = expression.infer_types(scope)?;
expr_ty.expect_subtype_of(&ArgumentType::bool(), expression)?;
Ok(column_set.join(from_type.column_set()))
}
_ => Ok(column_set.join(from_type.column_set())),
})
}
}

impl InferTypes for ast::Expression {
type Scope = ColumnSetScope;
type Output = ArgumentType;
Expand Down
6 changes: 3 additions & 3 deletions tests/sql/functions/aggregate/nested_sum.sql
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
-- SUM(SUM(x)) is a thing. This affects the design of the type system.

create temp table t1 (
create temp table nested_sum_t1 (
g1 STRING,
g2 STRING,
x INT64
);

insert into t1 values
insert into nested_sum_t1 values
('a', 'x', 1),
('a', 'y', 2),
('b', 'x', 3),
('b', 'y', 4);

CREATE OR REPLACE TABLE __result1 AS
SELECT g1, g2, SUM(SUM(x)) OVER (PARTITION BY g2) AS `sum`
FROM t1
FROM nested_sum_t1
GROUP BY g1, g2;

CREATE OR REPLACE TABLE __expected1 (
Expand Down
14 changes: 7 additions & 7 deletions tests/sql/queries/joins/cross.sql
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
-- CROSS JOIN

CREATE TEMP TABLE t1 (a INT64);
INSERT INTO t1 VALUES (1), (2);
CREATE TEMP TABLE cj_t1 (a INT64);
INSERT INTO cj_t1 VALUES (1), (2);

CREATE TEMP TABLE t2 (b INT64);
INSERT INTO t2 VALUES (3), (4);
CREATE TEMP TABLE cj_t2 (b INT64);
INSERT INTO cj_t2 VALUES (3), (4);

CREATE OR REPLACE TABLE __result1 AS
SELECT t1.a, t2.b
FROM t1
CROSS JOIN t2;
SELECT cj_t1.a, cj_t2.b
FROM cj_t1
CROSS JOIN cj_t2;

CREATE OR REPLACE TABLE __expected1 (
a INT64,
Expand Down

0 comments on commit df88ee8

Please sign in to comment.