Skip to content

Commit

Permalink
remove sync on set itself
Browse files Browse the repository at this point in the history
Signed-off-by: Yuchen Liang <[email protected]>
  • Loading branch information
yliang412 committed Oct 14, 2024
1 parent 3b83135 commit fd1f49e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 51 deletions.
4 changes: 2 additions & 2 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
mod disjoint_set;
mod disjoint_group;

use std::{
collections::{hash_map::Entry, HashMap, HashSet},
sync::Arc,
};

use anyhow::{bail, Result};
use disjoint_set::{DisjointSet, UnionFind};
use disjoint_group::set::{DisjointSet, UnionFind};
use itertools::Itertools;
use std::any::Any;

Expand Down
1 change: 1 addition & 0 deletions optd-core/src/cascades/memo/disjoint_group.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod set;
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
use std::{
collections::HashMap,
fmt::Debug,
hash::Hash,
ops::{Deref, DerefMut},
sync::{atomic::AtomicUsize, Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
};
use std::{collections::HashMap, fmt::Debug, hash::Hash};

/// A data structure for efficiently maintaining disjoint sets of `T`.
pub struct DisjointSet<T> {
Expand All @@ -17,9 +11,9 @@ pub struct DisjointSet<T> {
///
/// Alternatively, we could do no path compression at `find`,
/// and only do path compression when we were doing union.
node_parents: Arc<RwLock<HashMap<T, T>>>,
node_parents: HashMap<T, T>,
/// Number of disjoint sets.
num_sets: AtomicUsize,
num_sets: usize,
}

pub trait UnionFind<T>
Expand All @@ -31,11 +25,11 @@ where
/// or `None` if one the node is not present.
///
/// The smaller representative is selected as the new representative.
fn union(&self, a: &T, b: &T) -> Option<[T; 2]>;
fn union(&mut self, a: &T, b: &T) -> Option<[T; 2]>;

/// Gets the representative node of the set that `node` is in.
/// Path compression is performed while finding the representative.
fn find_path_compress(&self, node: &T) -> Option<T>;
fn find_path_compress(&mut self, node: &T) -> Option<T>;

/// Gets the representative node of the set that `node` is in.
fn find(&self, node: &T) -> Option<T>;
Expand All @@ -47,67 +41,60 @@ where
{
pub fn new() -> Self {
DisjointSet {
node_parents: Arc::new(RwLock::new(HashMap::new())),
num_sets: AtomicUsize::new(0),
node_parents: HashMap::new(),
num_sets: 0,
}
}

pub fn size(&self) -> usize {
let g = self.node_parents.read().unwrap();
g.len()
self.node_parents.len()
}

pub fn num_sets(&self) -> usize {
self.num_sets.load(std::sync::atomic::Ordering::Relaxed)
self.num_sets
}

pub fn add(&mut self, node: T) {
use std::collections::hash_map::Entry;

let mut g = self.node_parents.write().unwrap();
if let Entry::Vacant(entry) = g.entry(node) {
if let Entry::Vacant(entry) = self.node_parents.entry(node) {
entry.insert(node);
drop(g);
self.num_sets
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
self.num_sets += 1;
}
}

fn get_parent(g: &impl Deref<Target = HashMap<T, T>>, node: &T) -> Option<T> {
g.get(node).copied()
fn get_parent(&self, node: &T) -> Option<T> {
self.node_parents.get(node).copied()
}

fn set_parent(g: &mut impl DerefMut<Target = HashMap<T, T>>, node: T, parent: T) {
g.insert(node, parent);
fn set_parent(&mut self, node: T, parent: T) {
self.node_parents.insert(node, parent);
}

/// Recursively find the parent of the `node` until reaching the representative of the set.
/// A node is the representative if the its parent is the node itself.
///
/// We utilize "path compression" to shorten the height of parent forest.
fn find_path_compress_inner(
g: &mut RwLockWriteGuard<'_, HashMap<T, T>>,
node: &T,
) -> Option<T> {
let mut parent = Self::get_parent(g, node)?;
fn find_path_compress_inner(&mut self, node: &T) -> Option<T> {
let mut parent = self.get_parent(node)?;

if *node != parent {
parent = Self::find_path_compress_inner(g, &parent)?;
parent = self.find_path_compress_inner(&parent)?;

// Path compression.
Self::set_parent(g, *node, parent);
self.set_parent(*node, parent);
}

Some(parent)
}

/// Recursively find the parent of the `node` until reaching the representative of the set.
/// A node is the representative if the its parent is the node itself.
fn find_inner(g: &RwLockReadGuard<'_, HashMap<T, T>>, node: &T) -> Option<T> {
let mut parent = Self::get_parent(g, node)?;
fn find_inner(&self, node: &T) -> Option<T> {
let mut parent = self.get_parent(node)?;

if *node != parent {
parent = Self::find_inner(g, &parent)?;
parent = self.find_inner(&parent)?;
}

Some(parent)
Expand All @@ -118,7 +105,7 @@ impl<T> UnionFind<T> for DisjointSet<T>
where
T: Ord + Hash + Copy + Debug,
{
fn union(&self, a: &T, b: &T) -> Option<[T; 2]> {
fn union(&mut self, a: &T, b: &T) -> Option<[T; 2]> {
use std::cmp::Ordering;

// Gets the represenatives for set containing `a`.
Expand All @@ -127,20 +114,16 @@ where
// Gets the represenatives for set containing `b`.
let b_rep = self.find_path_compress(&b)?;

let mut g = self.node_parents.write().unwrap();

// Node with smaller value becomes the representative.
let res = match a_rep.cmp(&b_rep) {
Ordering::Less => {
Self::set_parent(&mut g, b_rep, a_rep);
self.num_sets
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
self.set_parent(b_rep, a_rep);
self.num_sets -= 1;
[a_rep, b_rep]
}
Ordering::Greater => {
Self::set_parent(&mut g, a_rep, b_rep);
self.num_sets
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
self.set_parent(a_rep, b_rep);
self.num_sets -= 1;
[b_rep, a_rep]
}
Ordering::Equal => [a_rep, b_rep],
Expand All @@ -149,15 +132,13 @@ where
}

/// See [`Self::find_inner`] for implementation detail.
fn find_path_compress(&self, node: &T) -> Option<T> {
let mut g = self.node_parents.write().unwrap();
Self::find_path_compress_inner(&mut g, node)
fn find_path_compress(&mut self, node: &T) -> Option<T> {
self.find_path_compress_inner(node)
}

/// See [`Self::find_inner`] for implementation detail.
fn find(&self, node: &T) -> Option<T> {
let g = self.node_parents.read().unwrap();
Self::find_inner(&g, node)
self.find_inner(node)
}
}

Expand Down

0 comments on commit fd1f49e

Please sign in to comment.