diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index 5def21ce..62ad52b7 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -1,4 +1,4 @@ -mod disjoint_set; +mod disjoint_group; use std::{ collections::{hash_map::Entry, HashMap, HashSet}, @@ -6,7 +6,7 @@ use std::{ }; use anyhow::{bail, Result}; -use disjoint_set::{DisjointSet, UnionFind}; +use disjoint_group::set::{DisjointSet, UnionFind}; use itertools::Itertools; use std::any::Any; diff --git a/optd-core/src/cascades/memo/disjoint_group.rs b/optd-core/src/cascades/memo/disjoint_group.rs new file mode 100644 index 00000000..d7e52559 --- /dev/null +++ b/optd-core/src/cascades/memo/disjoint_group.rs @@ -0,0 +1 @@ +pub mod set; diff --git a/optd-core/src/cascades/memo/disjoint_set.rs b/optd-core/src/cascades/memo/disjoint_group/set.rs similarity index 68% rename from optd-core/src/cascades/memo/disjoint_set.rs rename to optd-core/src/cascades/memo/disjoint_group/set.rs index 7989fc3d..35cc6b81 100644 --- a/optd-core/src/cascades/memo/disjoint_set.rs +++ b/optd-core/src/cascades/memo/disjoint_group/set.rs @@ -1,10 +1,4 @@ -use std::{ - collections::HashMap, - fmt::Debug, - hash::Hash, - ops::{Deref, DerefMut}, - sync::{atomic::AtomicUsize, Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, -}; +use std::{collections::HashMap, fmt::Debug, hash::Hash}; /// A data structure for efficiently maintaining disjoint sets of `T`. pub struct DisjointSet { @@ -17,9 +11,9 @@ pub struct DisjointSet { /// /// Alternatively, we could do no path compression at `find`, /// and only do path compression when we were doing union. - node_parents: Arc>>, + node_parents: HashMap, /// Number of disjoint sets. - num_sets: AtomicUsize, + num_sets: usize, } pub trait UnionFind @@ -31,11 +25,11 @@ where /// or `None` if one the node is not present. /// /// The smaller representative is selected as the new representative. - fn union(&self, a: &T, b: &T) -> Option<[T; 2]>; + fn union(&mut self, a: &T, b: &T) -> Option<[T; 2]>; /// Gets the representative node of the set that `node` is in. /// Path compression is performed while finding the representative. - fn find_path_compress(&self, node: &T) -> Option; + fn find_path_compress(&mut self, node: &T) -> Option; /// Gets the representative node of the set that `node` is in. fn find(&self, node: &T) -> Option; @@ -47,55 +41,48 @@ where { pub fn new() -> Self { DisjointSet { - node_parents: Arc::new(RwLock::new(HashMap::new())), - num_sets: AtomicUsize::new(0), + node_parents: HashMap::new(), + num_sets: 0, } } pub fn size(&self) -> usize { - let g = self.node_parents.read().unwrap(); - g.len() + self.node_parents.len() } pub fn num_sets(&self) -> usize { - self.num_sets.load(std::sync::atomic::Ordering::Relaxed) + self.num_sets } pub fn add(&mut self, node: T) { use std::collections::hash_map::Entry; - let mut g = self.node_parents.write().unwrap(); - if let Entry::Vacant(entry) = g.entry(node) { + if let Entry::Vacant(entry) = self.node_parents.entry(node) { entry.insert(node); - drop(g); - self.num_sets - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + self.num_sets += 1; } } - fn get_parent(g: &impl Deref>, node: &T) -> Option { - g.get(node).copied() + fn get_parent(&self, node: &T) -> Option { + self.node_parents.get(node).copied() } - fn set_parent(g: &mut impl DerefMut>, node: T, parent: T) { - g.insert(node, parent); + fn set_parent(&mut self, node: T, parent: T) { + self.node_parents.insert(node, parent); } /// Recursively find the parent of the `node` until reaching the representative of the set. /// A node is the representative if the its parent is the node itself. /// /// We utilize "path compression" to shorten the height of parent forest. - fn find_path_compress_inner( - g: &mut RwLockWriteGuard<'_, HashMap>, - node: &T, - ) -> Option { - let mut parent = Self::get_parent(g, node)?; + fn find_path_compress_inner(&mut self, node: &T) -> Option { + let mut parent = self.get_parent(node)?; if *node != parent { - parent = Self::find_path_compress_inner(g, &parent)?; + parent = self.find_path_compress_inner(&parent)?; // Path compression. - Self::set_parent(g, *node, parent); + self.set_parent(*node, parent); } Some(parent) @@ -103,11 +90,11 @@ where /// Recursively find the parent of the `node` until reaching the representative of the set. /// A node is the representative if the its parent is the node itself. - fn find_inner(g: &RwLockReadGuard<'_, HashMap>, node: &T) -> Option { - let mut parent = Self::get_parent(g, node)?; + fn find_inner(&self, node: &T) -> Option { + let mut parent = self.get_parent(node)?; if *node != parent { - parent = Self::find_inner(g, &parent)?; + parent = self.find_inner(&parent)?; } Some(parent) @@ -118,7 +105,7 @@ impl UnionFind for DisjointSet where T: Ord + Hash + Copy + Debug, { - fn union(&self, a: &T, b: &T) -> Option<[T; 2]> { + fn union(&mut self, a: &T, b: &T) -> Option<[T; 2]> { use std::cmp::Ordering; // Gets the represenatives for set containing `a`. @@ -127,20 +114,16 @@ where // Gets the represenatives for set containing `b`. let b_rep = self.find_path_compress(&b)?; - let mut g = self.node_parents.write().unwrap(); - // Node with smaller value becomes the representative. let res = match a_rep.cmp(&b_rep) { Ordering::Less => { - Self::set_parent(&mut g, b_rep, a_rep); - self.num_sets - .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.set_parent(b_rep, a_rep); + self.num_sets -= 1; [a_rep, b_rep] } Ordering::Greater => { - Self::set_parent(&mut g, a_rep, b_rep); - self.num_sets - .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.set_parent(a_rep, b_rep); + self.num_sets -= 1; [b_rep, a_rep] } Ordering::Equal => [a_rep, b_rep], @@ -149,15 +132,13 @@ where } /// See [`Self::find_inner`] for implementation detail. - fn find_path_compress(&self, node: &T) -> Option { - let mut g = self.node_parents.write().unwrap(); - Self::find_path_compress_inner(&mut g, node) + fn find_path_compress(&mut self, node: &T) -> Option { + self.find_path_compress_inner(node) } /// See [`Self::find_inner`] for implementation detail. fn find(&self, node: &T) -> Option { - let g = self.node_parents.read().unwrap(); - Self::find_inner(&g, node) + self.find_inner(node) } }