diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index 62ad52b7..c4bbfb0a 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -6,7 +6,10 @@ use std::{ }; use anyhow::{bail, Result}; -use disjoint_group::set::{DisjointSet, UnionFind}; +use disjoint_group::{ + set::{DisjointSet, UnionFind}, + DisjointGroupMap, +}; use itertools::Itertools; use std::any::Any; @@ -195,6 +198,7 @@ impl Memo { props } + // TODO(yuchen): make internal to disjoint group fn clear_exprs_in_group(&mut self, group_id: GroupId) { self.groups.remove(&group_id); } @@ -202,6 +206,7 @@ impl Memo { /// If group_id exists, it adds expr_id to the existing group /// Otherwise, it creates a new group of that group_id and insert expr_id into the new group fn add_expr_to_group(&mut self, expr_id: ExprId, group_id: GroupId, memo_node: RelMemoNode) { + // TODO(yuchen): use entry API if let Entry::Occupied(mut entry) = self.groups.entry(group_id) { let group = entry.get_mut(); group.group_exprs.insert(expr_id); @@ -214,6 +219,7 @@ impl Memo { }; group.group_exprs.insert(expr_id); self.disjoint_group_ids.add(group_id); + // TODO(yuchen): use insert self.groups.insert(group_id, group); } @@ -228,6 +234,7 @@ impl Memo { ) -> bool { let replace_group_id = self.get_reduced_group_id(replace_group_id); + // TODO(yuchen): use disjoint group entry API if let Entry::Occupied(mut entry) = self.groups.entry(replace_group_id) { let group = entry.get_mut(); if !group.group_exprs.contains(&expr_id) { diff --git a/optd-core/src/cascades/memo/disjoint_group.rs b/optd-core/src/cascades/memo/disjoint_group.rs index d7e52559..57c22567 100644 --- a/optd-core/src/cascades/memo/disjoint_group.rs +++ b/optd-core/src/cascades/memo/disjoint_group.rs @@ -1 +1,111 @@ +use std::{ + collections::{hash_map, HashMap}, + ops::Index, +}; + +use set::{DisjointSet, UnionFind}; + +use crate::cascades::GroupId; + +use super::Group; + pub mod set; + +const MISMATCH_ERROR: &str = "`groups` and `id_map` report unmatched group membership"; + +pub(crate) struct DisjointGroupMap { + id_map: DisjointSet, + groups: HashMap, +} + +impl DisjointGroupMap { + /// Creates a new disjoint group instance. + pub fn new() -> Self { + DisjointGroupMap { + id_map: DisjointSet::new(), + groups: HashMap::new(), + } + } + + pub fn get(&self, id: &GroupId) -> Option<&Group> { + self.id_map + .find(id) + .map(|rep| self.groups.get(&rep).expect(MISMATCH_ERROR)) + } + + pub fn get_mut(&mut self, id: &GroupId) -> Option<&mut Group> { + self.id_map + .find(id) + .map(|rep| self.groups.get_mut(&rep).expect(MISMATCH_ERROR)) + } + + unsafe fn insert_new(&mut self, id: GroupId, group: Group) { + self.id_map.add(id); + self.groups.insert(id, group); + } + + pub fn entry(&mut self, id: GroupId) -> GroupEntry<'_> { + use hash_map::Entry::*; + let rep = self.id_map.find(&id).unwrap_or(id); + let id_entry = self.id_map.entry(rep); + let group_entry = self.groups.entry(rep); + match (id_entry, group_entry) { + (Occupied(_), Occupied(inner)) => GroupEntry::Occupied(OccupiedGroupEntry { inner }), + (Vacant(id), Vacant(inner)) => GroupEntry::Vacant(VacantGroupEntry { id, inner }), + _ => unreachable!("{MISMATCH_ERROR}"), + } + } +} + +pub enum GroupEntry<'a> { + Occupied(OccupiedGroupEntry<'a>), + Vacant(VacantGroupEntry<'a>), +} + +pub struct OccupiedGroupEntry<'a> { + inner: hash_map::OccupiedEntry<'a, GroupId, Group>, +} + +pub struct VacantGroupEntry<'a> { + id: hash_map::VacantEntry<'a, GroupId, GroupId>, + inner: hash_map::VacantEntry<'a, GroupId, Group>, +} + +impl<'a> OccupiedGroupEntry<'a> { + pub fn id(&self) -> &GroupId { + self.inner.key() + } + + pub fn get(&self) -> &Group { + self.inner.get() + } + + pub fn get_mut(&mut self) -> &mut Group { + self.inner.get_mut() + } +} + +impl<'a> VacantGroupEntry<'a> { + pub fn id(&self) -> &GroupId { + self.inner.key() + } + + pub fn insert(self, group: Group) -> &'a mut Group { + let id = *self.id(); + self.id.insert(id); + self.inner.insert(group) + } +} + +impl Index<&GroupId> for DisjointGroupMap { + type Output = Group; + + fn index(&self, index: &GroupId) -> &Self::Output { + let rep = self + .id_map + .find(index) + .expect("no group found for group id"); + + self.groups.get(&rep).expect(MISMATCH_ERROR) + } +} diff --git a/optd-core/src/cascades/memo/disjoint_group/set.rs b/optd-core/src/cascades/memo/disjoint_group/set.rs index 35cc6b81..c343e59e 100644 --- a/optd-core/src/cascades/memo/disjoint_group/set.rs +++ b/optd-core/src/cascades/memo/disjoint_group/set.rs @@ -1,4 +1,8 @@ -use std::{collections::HashMap, fmt::Debug, hash::Hash}; +use std::{ + collections::{hash_map, HashMap}, + fmt::Debug, + hash::Hash, +}; /// A data structure for efficiently maintaining disjoint sets of `T`. pub struct DisjointSet { @@ -37,7 +41,7 @@ where impl DisjointSet where - T: Ord + Hash + Copy, + T: Ord + Hash + Copy + Debug, { pub fn new() -> Self { DisjointSet { @@ -99,6 +103,11 @@ where Some(parent) } + + /// Gets the given node corresponding entry in `node_parents` map for in-place manipulation. + pub(super) fn entry(&mut self, node: T) -> hash_map::Entry<'_, T, T> { + self.node_parents.entry(node) + } } impl UnionFind for DisjointSet