Skip to content

Commit

Permalink
implement DisjointGroupMap
Browse files Browse the repository at this point in the history
Signed-off-by: Yuchen Liang <[email protected]>
  • Loading branch information
yliang412 committed Oct 14, 2024
1 parent fd1f49e commit 4129dff
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 3 deletions.
9 changes: 8 additions & 1 deletion optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use std::{
};

use anyhow::{bail, Result};
use disjoint_group::set::{DisjointSet, UnionFind};
use disjoint_group::{
set::{DisjointSet, UnionFind},
DisjointGroupMap,
};
use itertools::Itertools;
use std::any::Any;

Expand Down Expand Up @@ -195,13 +198,15 @@ impl<T: RelNodeTyp> Memo<T> {
props
}

// TODO(yuchen): make internal to disjoint group
fn clear_exprs_in_group(&mut self, group_id: GroupId) {
self.groups.remove(&group_id);
}

/// If group_id exists, it adds expr_id to the existing group
/// Otherwise, it creates a new group of that group_id and insert expr_id into the new group
fn add_expr_to_group(&mut self, expr_id: ExprId, group_id: GroupId, memo_node: RelMemoNode<T>) {
// TODO(yuchen): use entry API
if let Entry::Occupied(mut entry) = self.groups.entry(group_id) {
let group = entry.get_mut();
group.group_exprs.insert(expr_id);
Expand All @@ -214,6 +219,7 @@ impl<T: RelNodeTyp> Memo<T> {
};
group.group_exprs.insert(expr_id);
self.disjoint_group_ids.add(group_id);
// TODO(yuchen): use insert
self.groups.insert(group_id, group);
}

Expand All @@ -228,6 +234,7 @@ impl<T: RelNodeTyp> Memo<T> {
) -> bool {
let replace_group_id = self.get_reduced_group_id(replace_group_id);

// TODO(yuchen): use disjoint group entry API
if let Entry::Occupied(mut entry) = self.groups.entry(replace_group_id) {
let group = entry.get_mut();
if !group.group_exprs.contains(&expr_id) {
Expand Down
110 changes: 110 additions & 0 deletions optd-core/src/cascades/memo/disjoint_group.rs
Original file line number Diff line number Diff line change
@@ -1 +1,111 @@
use std::{
collections::{hash_map, HashMap},
ops::Index,
};

use set::{DisjointSet, UnionFind};

use crate::cascades::GroupId;

use super::Group;

pub mod set;

const MISMATCH_ERROR: &str = "`groups` and `id_map` report unmatched group membership";

pub(crate) struct DisjointGroupMap {
id_map: DisjointSet<GroupId>,
groups: HashMap<GroupId, Group>,
}

impl DisjointGroupMap {
/// Creates a new disjoint group instance.
pub fn new() -> Self {
DisjointGroupMap {
id_map: DisjointSet::new(),
groups: HashMap::new(),
}
}

pub fn get(&self, id: &GroupId) -> Option<&Group> {
self.id_map
.find(id)
.map(|rep| self.groups.get(&rep).expect(MISMATCH_ERROR))
}

pub fn get_mut(&mut self, id: &GroupId) -> Option<&mut Group> {
self.id_map
.find(id)
.map(|rep| self.groups.get_mut(&rep).expect(MISMATCH_ERROR))
}

unsafe fn insert_new(&mut self, id: GroupId, group: Group) {
self.id_map.add(id);
self.groups.insert(id, group);
}

pub fn entry(&mut self, id: GroupId) -> GroupEntry<'_> {
use hash_map::Entry::*;
let rep = self.id_map.find(&id).unwrap_or(id);
let id_entry = self.id_map.entry(rep);
let group_entry = self.groups.entry(rep);
match (id_entry, group_entry) {
(Occupied(_), Occupied(inner)) => GroupEntry::Occupied(OccupiedGroupEntry { inner }),
(Vacant(id), Vacant(inner)) => GroupEntry::Vacant(VacantGroupEntry { id, inner }),
_ => unreachable!("{MISMATCH_ERROR}"),
}
}
}

pub enum GroupEntry<'a> {
Occupied(OccupiedGroupEntry<'a>),
Vacant(VacantGroupEntry<'a>),
}

pub struct OccupiedGroupEntry<'a> {
inner: hash_map::OccupiedEntry<'a, GroupId, Group>,
}

pub struct VacantGroupEntry<'a> {
id: hash_map::VacantEntry<'a, GroupId, GroupId>,
inner: hash_map::VacantEntry<'a, GroupId, Group>,
}

impl<'a> OccupiedGroupEntry<'a> {
pub fn id(&self) -> &GroupId {
self.inner.key()
}

pub fn get(&self) -> &Group {
self.inner.get()
}

pub fn get_mut(&mut self) -> &mut Group {
self.inner.get_mut()
}
}

impl<'a> VacantGroupEntry<'a> {
pub fn id(&self) -> &GroupId {
self.inner.key()
}

pub fn insert(self, group: Group) -> &'a mut Group {
let id = *self.id();
self.id.insert(id);
self.inner.insert(group)
}
}

impl Index<&GroupId> for DisjointGroupMap {
type Output = Group;

fn index(&self, index: &GroupId) -> &Self::Output {
let rep = self
.id_map
.find(index)
.expect("no group found for group id");

self.groups.get(&rep).expect(MISMATCH_ERROR)
}
}
13 changes: 11 additions & 2 deletions optd-core/src/cascades/memo/disjoint_group/set.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::{collections::HashMap, fmt::Debug, hash::Hash};
use std::{
collections::{hash_map, HashMap},
fmt::Debug,
hash::Hash,
};

/// A data structure for efficiently maintaining disjoint sets of `T`.
pub struct DisjointSet<T> {
Expand Down Expand Up @@ -37,7 +41,7 @@ where

impl<T> DisjointSet<T>
where
T: Ord + Hash + Copy,
T: Ord + Hash + Copy + Debug,
{
pub fn new() -> Self {
DisjointSet {
Expand Down Expand Up @@ -99,6 +103,11 @@ where

Some(parent)
}

/// Gets the given node corresponding entry in `node_parents` map for in-place manipulation.
pub(super) fn entry(&mut self, node: T) -> hash_map::Entry<'_, T, T> {
self.node_parents.entry(node)
}
}

impl<T> UnionFind<T> for DisjointSet<T>
Expand Down

0 comments on commit 4129dff

Please sign in to comment.