Skip to content

Commit

Permalink
remove ReducedGroupId
Browse files Browse the repository at this point in the history
Signed-off-by: Yuchen Liang <[email protected]>
  • Loading branch information
yliang412 committed Oct 14, 2024
1 parent bc17ae9 commit 3b83135
Showing 1 changed file with 27 additions and 68 deletions.
95 changes: 27 additions & 68 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ mod disjoint_set;

use std::{
collections::{hash_map::Entry, HashMap, HashSet},
fmt::Display,
sync::Arc,
};

Expand Down Expand Up @@ -60,28 +59,15 @@ pub(crate) struct Group {
pub(crate) properties: Arc<[Box<dyn Any + Send + Sync + 'static>]>,
}

/// Newtype over the raw `usize` index of a group.
///
/// Values are built by wrapping the inner index of a `GroupId` (see
/// `get_reduced_group_id`, which wraps the result of a disjoint-set `find`)
/// and are turned back into a `GroupId` with `as_group_id`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)]
struct ReducedGroupId(usize);

impl ReducedGroupId {
pub fn as_group_id(self) -> GroupId {
GroupId(self.0)
}
}

impl Display for ReducedGroupId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

pub struct Memo<T: RelNodeTyp> {
expr_id_to_group_id: HashMap<ExprId, GroupId>,
expr_id_to_expr_node: HashMap<ExprId, RelMemoNodeRef<T>>,
expr_node_to_expr_id: HashMap<RelMemoNode<T>, ExprId>,
groups: HashMap<ReducedGroupId, Group>,
/// Stores the mapping from "representative" group id to group.
groups: HashMap<GroupId, Group>,
group_expr_counter: usize,
disjoint_groups: DisjointSet<GroupId>,
/// Keeps track of disjoint sets of group ids.
disjoint_group_ids: DisjointSet<GroupId>,
property_builders: Arc<[Box<dyn PropertyBuilderAny<T>>]>,
}

Expand All @@ -93,17 +79,16 @@ impl<T: RelNodeTyp> Memo<T> {
expr_node_to_expr_id: HashMap::new(),
groups: HashMap::new(),
group_expr_counter: 0,
disjoint_groups: DisjointSet::new(),
disjoint_group_ids: DisjointSet::new(),
property_builders,
}
}

/// Get the next group id. Group ids and expr ids share the same counter, which makes debugging easier.
fn next_group_id(&mut self) -> ReducedGroupId {
let id = self.group_expr_counter;
fn next_group_id(&mut self) -> GroupId {
let id = GroupId(self.group_expr_counter);
self.group_expr_counter += 1;
self.disjoint_groups.add(GroupId(id));
ReducedGroupId(id)
id
}

/// Get the next expr id. Group ids and expr ids share the same counter, which makes debugging easier.
Expand All @@ -113,47 +98,32 @@ impl<T: RelNodeTyp> Memo<T> {
ExprId(id)
}

fn merge_group_inner(
&mut self,
group_a: ReducedGroupId,
group_b: ReducedGroupId,
) -> ReducedGroupId {
pub fn merge_group(&mut self, group_a: GroupId, group_b: GroupId) -> GroupId {
if group_a == group_b {
return group_a;
}

let [rep, other] = self
.disjoint_groups
.union(&group_a.as_group_id(), &group_b.as_group_id())
.unwrap();
let [rep, other] = self.disjoint_group_ids.union(&group_a, &group_b).unwrap();

// Copy all expressions from group a to group b
// Copy all expressions from group other to its representative
let other_exprs = self.get_all_exprs_in_group(other);
for expr_id in other_exprs {
let expr_node = self.expr_id_to_expr_node.get(&expr_id).unwrap();
self.add_expr_to_group(expr_id, ReducedGroupId(rep.0), expr_node.as_ref().clone());
self.add_expr_to_group(expr_id, rep, expr_node.as_ref().clone());
}

// Remove all expressions from group a (so we don't accidentally access it)
self.clear_exprs_in_group(ReducedGroupId(other.0));
// Remove all expressions from group other (so we don't accidentally access it)
self.clear_exprs_in_group(other);

group_b
}

pub fn merge_group(&mut self, group_a: GroupId, group_b: GroupId) -> GroupId {
let group_a_reduced = self.get_reduced_group_id(group_a);
let group_b_reduced = self.get_reduced_group_id(group_b);
self.merge_group_inner(group_a_reduced, group_b_reduced)
.as_group_id()
}

fn get_group_id_of_expr_id(&self, expr_id: ExprId) -> GroupId {
self.expr_id_to_group_id[&expr_id]
}

fn get_reduced_group_id(&self, group_id: GroupId) -> ReducedGroupId {
let reduced = self.disjoint_groups.find(&group_id).unwrap();
ReducedGroupId(reduced.0)
fn get_reduced_group_id(&self, group_id: GroupId) -> GroupId {
self.disjoint_group_ids.find(&group_id).unwrap()
}

/// Add or get an expression into the memo, returns the group id and the expr id. If `GroupId` is `None`,
Expand All @@ -170,7 +140,7 @@ impl<T: RelNodeTyp> Memo<T> {

let add_to_group_id = add_to_group_id.map(|x| self.get_reduced_group_id(x));
let (group_id, expr_id) = self.add_new_group_expr_inner(rel_node, add_to_group_id);
(group_id.as_group_id(), expr_id)
(group_id, expr_id)
}

pub fn get_expr_info(&self, rel_node: RelNodeRef<T>) -> (GroupId, ExprId) {
Expand Down Expand Up @@ -225,18 +195,13 @@ impl<T: RelNodeTyp> Memo<T> {
props
}

fn clear_exprs_in_group(&mut self, group_id: ReducedGroupId) {
fn clear_exprs_in_group(&mut self, group_id: GroupId) {
self.groups.remove(&group_id);
}

/// If `group_id` already exists, adds `expr_id` to the existing group.
/// Otherwise, creates a new group with that `group_id` and inserts `expr_id` into it.
fn add_expr_to_group(
&mut self,
expr_id: ExprId,
group_id: ReducedGroupId,
memo_node: RelMemoNode<T>,
) {
fn add_expr_to_group(&mut self, expr_id: ExprId, group_id: GroupId, memo_node: RelMemoNode<T>) {
if let Entry::Occupied(mut entry) = self.groups.entry(group_id) {
let group = entry.get_mut();
group.group_exprs.insert(expr_id);
Expand All @@ -248,6 +213,7 @@ impl<T: RelNodeTyp> Memo<T> {
properties: self.infer_properties(memo_node).into(),
};
group.group_exprs.insert(expr_id);
self.disjoint_group_ids.add(group_id);
self.groups.insert(group_id, group);
}

Expand Down Expand Up @@ -298,9 +264,8 @@ impl<T: RelNodeTyp> Memo<T> {
and make sure if it does not do any transformation, it should return an empty vec!");
}
let group_id = self.get_group_id_of_expr_id(new_expr_id);
let group_id = self.get_reduced_group_id(group_id);

self.merge_group_inner(replace_group_id, group_id);
self.merge_group(replace_group_id, group_id);
return false;
}

Expand All @@ -316,8 +281,8 @@ impl<T: RelNodeTyp> Memo<T> {
fn add_new_group_expr_inner(
&mut self,
rel_node: RelNodeRef<T>,
add_to_group_id: Option<ReducedGroupId>,
) -> (ReducedGroupId, ExprId) {
add_to_group_id: Option<GroupId>,
) -> (GroupId, ExprId) {
let children_group_ids = rel_node
.children
.iter()
Expand All @@ -338,7 +303,7 @@ impl<T: RelNodeTyp> Memo<T> {
let group_id = self.get_group_id_of_expr_id(expr_id);
let group_id = self.get_reduced_group_id(group_id);
if let Some(add_to_group_id) = add_to_group_id {
self.merge_group_inner(add_to_group_id, group_id);
self.merge_group(add_to_group_id, group_id);
}
return (group_id, expr_id);
}
Expand All @@ -350,8 +315,7 @@ impl<T: RelNodeTyp> Memo<T> {
};
self.expr_id_to_expr_node
.insert(expr_id, memo_node.clone().into());
self.expr_id_to_group_id
.insert(expr_id, group_id.as_group_id());
self.expr_id_to_group_id.insert(expr_id, group_id);
self.expr_node_to_expr_id.insert(memo_node.clone(), expr_id);
self.add_expr_to_group(expr_id, group_id, memo_node);
(group_id, expr_id)
Expand All @@ -364,7 +328,7 @@ impl<T: RelNodeTyp> Memo<T> {
.expr_id_to_group_id
.get(&expr_id)
.expect("expr not found in group mapping");
self.get_reduced_group_id(*group_id).as_group_id()
self.get_reduced_group_id(*group_id)
}

/// Get the memoized representation of a node.
Expand Down Expand Up @@ -465,12 +429,7 @@ impl<T: RelNodeTyp> Memo<T> {
}

pub fn get_all_group_ids(&self) -> Vec<GroupId> {
let mut ids = self
.groups
.keys()
.copied()
.map(|x| x.as_group_id())
.collect_vec();
let mut ids = self.groups.keys().copied().collect_vec();
ids.sort();
ids
}
Expand Down

0 comments on commit 3b83135

Please sign in to comment.