From bc17ae90837980480ef35b716c208d0f7c871f7f Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Mon, 14 Oct 2024 10:51:45 -0400 Subject: [PATCH] core: use union-find to merge groups Also improve tpch sqlplannertest run from ~89s to ~9s. Signed-off-by: Yuchen Liang --- optd-core/src/cascades/memo.rs | 38 +-- optd-core/src/cascades/memo/disjoint_set.rs | 243 ++++++++++++++++++++ optd-sqlplannertest/tests/tpch.planner.sql | 116 +++++----- 3 files changed, 321 insertions(+), 76 deletions(-) create mode 100644 optd-core/src/cascades/memo/disjoint_set.rs diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index a194f71c..2503e697 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -1,3 +1,5 @@ +mod disjoint_set; + use std::{ collections::{hash_map::Entry, HashMap, HashSet}, fmt::Display, @@ -5,6 +7,7 @@ use std::{ }; use anyhow::{bail, Result}; +use disjoint_set::{DisjointSet, UnionFind}; use itertools::Itertools; use std::any::Any; @@ -78,7 +81,7 @@ pub struct Memo { expr_node_to_expr_id: HashMap, ExprId>, groups: HashMap, group_expr_counter: usize, - merged_groups: HashMap, + disjoint_groups: DisjointSet, property_builders: Arc<[Box>]>, } @@ -90,7 +93,7 @@ impl Memo { expr_node_to_expr_id: HashMap::new(), groups: HashMap::new(), group_expr_counter: 0, - merged_groups: HashMap::new(), + disjoint_groups: DisjointSet::new(), property_builders, } } @@ -99,6 +102,7 @@ impl Memo { fn next_group_id(&mut self) -> ReducedGroupId { let id = self.group_expr_counter; self.group_expr_counter += 1; + self.disjoint_groups.add(GroupId(id)); ReducedGroupId(id) } @@ -118,18 +122,20 @@ impl Memo { return group_a; } + let [rep, other] = self + .disjoint_groups + .union(&group_a.as_group_id(), &group_b.as_group_id()) + .unwrap(); + // Copy all expressions from group a to group b - let group_a_exprs = self.get_all_exprs_in_group(group_a.as_group_id()); - for expr_id in group_a_exprs { + let other_exprs = self.get_all_exprs_in_group(other); + for expr_id in other_exprs { let expr_node = self.expr_id_to_expr_node.get(&expr_id).unwrap(); - self.add_expr_to_group(expr_id, group_b, expr_node.as_ref().clone()); + self.add_expr_to_group(expr_id, ReducedGroupId(rep.0), expr_node.as_ref().clone()); } - self.merged_groups - .insert(group_a.as_group_id(), group_b.as_group_id()); - // Remove all expressions from group a (so we don't accidentally access it) - self.clear_exprs_in_group(group_a); + self.clear_exprs_in_group(ReducedGroupId(other.0)); group_b } @@ -145,11 +151,9 @@ impl Memo { self.expr_id_to_group_id[&expr_id] } - fn get_reduced_group_id(&self, mut group_id: GroupId) -> ReducedGroupId { - while let Some(next_group_id) = self.merged_groups.get(&group_id) { - group_id = *next_group_id; - } - ReducedGroupId(group_id.0) + fn get_reduced_group_id(&self, group_id: GroupId) -> ReducedGroupId { + let reduced = self.disjoint_groups.find(&group_id).unwrap(); + ReducedGroupId(reduced.0) } /// Add or get an expression into the memo, returns the group id and the expr id. If `GroupId` is `None`, @@ -164,10 +168,8 @@ impl Memo { self.merge_group(grp_a, grp_b); }; - let (group_id, expr_id) = self.add_new_group_expr_inner( - rel_node, - add_to_group_id.map(|x| self.get_reduced_group_id(x)), - ); + let add_to_group_id = add_to_group_id.map(|x| self.get_reduced_group_id(x)); + let (group_id, expr_id) = self.add_new_group_expr_inner(rel_node, add_to_group_id); (group_id.as_group_id(), expr_id) } diff --git a/optd-core/src/cascades/memo/disjoint_set.rs b/optd-core/src/cascades/memo/disjoint_set.rs new file mode 100644 index 00000000..7989fc3d --- /dev/null +++ b/optd-core/src/cascades/memo/disjoint_set.rs @@ -0,0 +1,243 @@ +use std::{ + collections::HashMap, + fmt::Debug, + hash::Hash, + ops::{Deref, DerefMut}, + sync::{atomic::AtomicUsize, Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, +}; + +/// A data structure for efficiently maintaining disjoint sets of `T`. +pub struct DisjointSet { + /// Mapping from node to its parent. + /// + /// # Design + /// We use a `mutex` instead of reader-writer lock so that + /// we always need write permission to perform `path compression` + /// during "finds". + /// + /// Alternatively, we could do no path compression at `find`, + /// and only do path compression when we were doing union. + node_parents: Arc>>, + /// Number of disjoint sets. + num_sets: AtomicUsize, +} + +pub trait UnionFind +where + T: Ord, +{ + /// Unions the set containing `a` and the set containing `b`. + /// Returns the new representative followed by the other node, + /// or `None` if one the node is not present. + /// + /// The smaller representative is selected as the new representative. + fn union(&self, a: &T, b: &T) -> Option<[T; 2]>; + + /// Gets the representative node of the set that `node` is in. + /// Path compression is performed while finding the representative. + fn find_path_compress(&self, node: &T) -> Option; + + /// Gets the representative node of the set that `node` is in. + fn find(&self, node: &T) -> Option; +} + +impl DisjointSet +where + T: Ord + Hash + Copy, +{ + pub fn new() -> Self { + DisjointSet { + node_parents: Arc::new(RwLock::new(HashMap::new())), + num_sets: AtomicUsize::new(0), + } + } + + pub fn size(&self) -> usize { + let g = self.node_parents.read().unwrap(); + g.len() + } + + pub fn num_sets(&self) -> usize { + self.num_sets.load(std::sync::atomic::Ordering::Relaxed) + } + + pub fn add(&mut self, node: T) { + use std::collections::hash_map::Entry; + + let mut g = self.node_parents.write().unwrap(); + if let Entry::Vacant(entry) = g.entry(node) { + entry.insert(node); + drop(g); + self.num_sets + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + } + + fn get_parent(g: &impl Deref>, node: &T) -> Option { + g.get(node).copied() + } + + fn set_parent(g: &mut impl DerefMut>, node: T, parent: T) { + g.insert(node, parent); + } + + /// Recursively find the parent of the `node` until reaching the representative of the set. + /// A node is the representative if the its parent is the node itself. + /// + /// We utilize "path compression" to shorten the height of parent forest. + fn find_path_compress_inner( + g: &mut RwLockWriteGuard<'_, HashMap>, + node: &T, + ) -> Option { + let mut parent = Self::get_parent(g, node)?; + + if *node != parent { + parent = Self::find_path_compress_inner(g, &parent)?; + + // Path compression. + Self::set_parent(g, *node, parent); + } + + Some(parent) + } + + /// Recursively find the parent of the `node` until reaching the representative of the set. + /// A node is the representative if the its parent is the node itself. + fn find_inner(g: &RwLockReadGuard<'_, HashMap>, node: &T) -> Option { + let mut parent = Self::get_parent(g, node)?; + + if *node != parent { + parent = Self::find_inner(g, &parent)?; + } + + Some(parent) + } +} + +impl UnionFind for DisjointSet +where + T: Ord + Hash + Copy + Debug, +{ + fn union(&self, a: &T, b: &T) -> Option<[T; 2]> { + use std::cmp::Ordering; + + // Gets the represenatives for set containing `a`. + let a_rep = self.find_path_compress(&a)?; + + // Gets the represenatives for set containing `b`. + let b_rep = self.find_path_compress(&b)?; + + let mut g = self.node_parents.write().unwrap(); + + // Node with smaller value becomes the representative. + let res = match a_rep.cmp(&b_rep) { + Ordering::Less => { + Self::set_parent(&mut g, b_rep, a_rep); + self.num_sets + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + [a_rep, b_rep] + } + Ordering::Greater => { + Self::set_parent(&mut g, a_rep, b_rep); + self.num_sets + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + [b_rep, a_rep] + } + Ordering::Equal => [a_rep, b_rep], + }; + Some(res) + } + + /// See [`Self::find_inner`] for implementation detail. + fn find_path_compress(&self, node: &T) -> Option { + let mut g = self.node_parents.write().unwrap(); + Self::find_path_compress_inner(&mut g, node) + } + + /// See [`Self::find_inner`] for implementation detail. + fn find(&self, node: &T) -> Option { + let g = self.node_parents.read().unwrap(); + Self::find_inner(&g, node) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + fn minmax(v1: T, v2: T) -> [T; 2] + where + T: Ord, + { + if v1 <= v2 { + [v1, v2] + } else { + [v2, v1] + } + } + + fn test_union_find(inputs: Vec) + where + T: Ord + Hash + Copy + Debug, + { + let mut set = DisjointSet::new(); + + for input in inputs.iter() { + set.add(*input); + } + + for input in inputs.iter() { + let rep = set.find(input); + assert_eq!( + rep, + Some(*input), + "representive should be node itself for singleton" + ); + } + assert_eq!(set.size(), 10); + assert_eq!(set.num_sets(), 10); + + for input in inputs.iter() { + set.union(input, input).unwrap(); + let rep = set.find(input); + assert_eq!( + rep, + Some(*input), + "representive should be node itself for singleton" + ); + } + assert_eq!(set.size(), 10); + assert_eq!(set.num_sets(), 10); + + for (x, y) in inputs.iter().zip(inputs.iter().rev()) { + let y_rep = set.find(&y).unwrap(); + let [rep, other] = set.union(x, y).expect(&format!( + "union should be successful between {:?} and {:?}", + x, y, + )); + if rep != other { + assert_eq!([rep, other], minmax(*x, y_rep)); + } + } + + for (x, y) in inputs.iter().zip(inputs.iter().rev()) { + let rep = set.find(x); + + let expected = x.min(y); + assert_eq!(rep, Some(*expected)); + } + assert_eq!(set.size(), 10); + assert_eq!(set.num_sets(), 5); + } + + #[test] + fn test_union_find_i32() { + test_union_find(Vec::from_iter(0..10)); + } + + #[test] + fn test_union_find_group() { + test_union_find(Vec::from_iter((0..10).map(|i| crate::cascades::GroupId(i)))); + } +} diff --git a/optd-sqlplannertest/tests/tpch.planner.sql b/optd-sqlplannertest/tests/tpch.planner.sql index 8bf88051..e7a86ef9 100644 --- a/optd-sqlplannertest/tests/tpch.planner.sql +++ b/optd-sqlplannertest/tests/tpch.planner.sql @@ -627,29 +627,29 @@ PhysicalSort │ ├── Cast { cast_to: Decimal128(20, 0), expr: 1(i64) } │ └── #23 ├── groups: [ #41 ] - └── PhysicalHashJoin { join_type: Inner, left_keys: [ #19, #3 ], right_keys: [ #0, #3 ] } - ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #1 ] } - │ ├── PhysicalScan { table: customer } - │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ ├── PhysicalFilter - │ │ ├── cond:And - │ │ │ ├── Geq - │ │ │ │ ├── #4 - │ │ │ │ └── Cast { cast_to: Date32, expr: "2023-01-01" } - │ │ │ └── Lt - │ │ │ ├── #4 - │ │ │ └── Cast { cast_to: Date32, expr: "2024-01-01" } - │ │ └── PhysicalScan { table: orders } - │ └── PhysicalScan { table: lineitem } - └── PhysicalHashJoin { join_type: Inner, left_keys: [ #9 ], right_keys: [ #0 ] } - ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ] } - │ ├── PhysicalScan { table: supplier } - │ └── PhysicalScan { table: nation } - └── PhysicalFilter - ├── cond:Eq - │ ├── #1 - │ └── "Asia" - └── PhysicalScan { table: region } + └── PhysicalHashJoin { join_type: Inner, left_keys: [ #42 ], right_keys: [ #0 ] } + ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #36 ], right_keys: [ #0 ] } + │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #19, #3 ], right_keys: [ #0, #3 ] } + │ │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #1 ] } + │ │ │ ├── PhysicalScan { table: customer } + │ │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } + │ │ │ ├── PhysicalFilter + │ │ │ │ ├── cond:And + │ │ │ │ │ ├── Geq + │ │ │ │ │ │ ├── #4 + │ │ │ │ │ │ └── Cast { cast_to: Date32, expr: "2023-01-01" } + │ │ │ │ │ └── Lt + │ │ │ │ │ ├── #4 + │ │ │ │ │ └── Cast { cast_to: Date32, expr: "2024-01-01" } + │ │ │ │ └── PhysicalScan { table: orders } + │ │ │ └── PhysicalScan { table: lineitem } + │ │ └── PhysicalScan { table: supplier } + │ └── PhysicalScan { table: nation } + └── PhysicalFilter + ├── cond:Eq + │ ├── #1 + │ └── "Asia" + └── PhysicalScan { table: region } */ -- TPC-H Q6 @@ -864,12 +864,12 @@ PhysicalSort ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ] } │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #2 ] } │ │ ├── PhysicalScan { table: supplier } - │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #17 ], right_keys: [ #0 ] } - │ │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ │ │ ├── PhysicalFilter { cond: Between { expr: #10, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } } } - │ │ │ │ └── PhysicalScan { table: lineitem } - │ │ │ └── PhysicalScan { table: orders } - │ │ └── PhysicalScan { table: customer } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } + │ │ ├── PhysicalFilter { cond: Between { expr: #10, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } } } + │ │ │ └── PhysicalScan { table: lineitem } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ] } + │ │ ├── PhysicalScan { table: orders } + │ │ └── PhysicalScan { table: customer } │ └── PhysicalScan { table: nation } └── PhysicalScan { table: nation } */ @@ -1033,14 +1033,14 @@ PhysicalSort │ │ │ │ │ └── "ECONOMY ANODIZED STEEL" │ │ │ │ └── PhysicalScan { table: part } │ │ │ └── PhysicalScan { table: supplier } - │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #17 ], right_keys: [ #0 ] } - │ │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ │ │ ├── PhysicalScan { table: lineitem } - │ │ │ └── PhysicalFilter { cond: Between { expr: #4, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } } } - │ │ │ └── PhysicalScan { table: orders } - │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ] } - │ │ ├── PhysicalScan { table: customer } - │ │ └── PhysicalScan { table: nation } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } + │ │ ├── PhysicalScan { table: lineitem } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ] } + │ │ ├── PhysicalFilter { cond: Between { expr: #4, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } } } + │ │ │ └── PhysicalScan { table: orders } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ] } + │ │ ├── PhysicalScan { table: customer } + │ │ └── PhysicalScan { table: nation } │ └── PhysicalScan { table: nation } └── PhysicalFilter ├── cond:Eq @@ -1167,16 +1167,16 @@ PhysicalSort │ ├── #35 │ └── #20 └── PhysicalHashJoin { join_type: Inner, left_keys: [ #12 ], right_keys: [ #0 ] } - ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #9, #0 ], right_keys: [ #2, #1 ] } - │ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true } - │ │ ├── PhysicalFilter { cond: Like { expr: #1, pattern: "%green%", negated: false, case_insensitive: false } } - │ │ │ └── PhysicalScan { table: part } - │ │ └── PhysicalScan { table: supplier } - │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #2, #1 ], right_keys: [ #1, #0 ] } - │ │ ├── PhysicalScan { table: lineitem } - │ │ └── PhysicalScan { table: partsupp } - │ └── PhysicalScan { table: orders } + ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #16 ], right_keys: [ #0 ] } + │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #9, #0 ], right_keys: [ #2, #1 ] } + │ │ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true } + │ │ │ ├── PhysicalFilter { cond: Like { expr: #1, pattern: "%green%", negated: false, case_insensitive: false } } + │ │ │ │ └── PhysicalScan { table: part } + │ │ │ └── PhysicalScan { table: supplier } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2, #1 ], right_keys: [ #1, #0 ] } + │ │ ├── PhysicalScan { table: lineitem } + │ │ └── PhysicalScan { table: partsupp } + │ └── PhysicalScan { table: orders } └── PhysicalScan { table: nation } */ @@ -1298,16 +1298,16 @@ PhysicalSort │ ├── #35 │ └── #20 └── PhysicalHashJoin { join_type: Inner, left_keys: [ #12 ], right_keys: [ #0 ] } - ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #9, #0 ], right_keys: [ #2, #1 ] } - │ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true } - │ │ ├── PhysicalFilter { cond: Like { expr: #1, pattern: "%green%", negated: false, case_insensitive: false } } - │ │ │ └── PhysicalScan { table: part } - │ │ └── PhysicalScan { table: supplier } - │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #2, #1 ], right_keys: [ #1, #0 ] } - │ │ ├── PhysicalScan { table: lineitem } - │ │ └── PhysicalScan { table: partsupp } - │ └── PhysicalScan { table: orders } + ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #16 ], right_keys: [ #0 ] } + │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #9, #0 ], right_keys: [ #2, #1 ] } + │ │ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true } + │ │ │ ├── PhysicalFilter { cond: Like { expr: #1, pattern: "%green%", negated: false, case_insensitive: false } } + │ │ │ │ └── PhysicalScan { table: part } + │ │ │ └── PhysicalScan { table: supplier } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2, #1 ], right_keys: [ #1, #0 ] } + │ │ ├── PhysicalScan { table: lineitem } + │ │ └── PhysicalScan { table: partsupp } + │ └── PhysicalScan { table: orders } └── PhysicalScan { table: nation } */ @@ -2057,7 +2057,7 @@ PhysicalProjection ├── cond:And │ ├── Eq │ │ ├── #2 - │ │ └── #0 + │ │ └── #4 │ └── Lt │ ├── Cast { cast_to: Decimal128(30, 15), expr: #0 } │ └── #3