-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ref: refactor disjoint set union - replaced `reserve_exact()` with `Vec::with_capacity()`, a more idiomatic approach for reserving memory for the nodes vector. - add documentation - rewrite tests * style: use `Self` https://rust-lang.github.io/rust-clippy/master/index.html#indexing_slicing --------- Co-authored-by: Piotr Idzik <[email protected]>
- Loading branch information
Showing
1 changed file
with
121 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,95 +1,148 @@ | ||
//! This module implements the Disjoint Set Union (DSU), also known as Union-Find, | ||
//! which is an efficient data structure for keeping track of a set of elements | ||
//! partitioned into disjoint (non-overlapping) subsets. | ||
|
||
/// Represents a node in the Disjoint Set Union (DSU) structure, which | ||
/// keeps track of the parent-child relationships in the disjoint sets. | ||
pub struct DSUNode { | ||
/// The index of the node's parent, or the node's own index if it is a root. | ||
parent: usize, | ||
/// The size of the set rooted at this node, used for union by size. | ||
size: usize, | ||
} | ||
|
||
/// Disjoint Set Union (Union-Find) data structure, particularly useful for | ||
/// managing dynamic connectivity problems such as determining | ||
/// if two elements are in the same subset or merging two subsets. | ||
pub struct DisjointSetUnion { | ||
/// DSU nodes indexed by element, storing each element's parent and set size. | ||
nodes: Vec<DSUNode>, | ||
} | ||
|
||
// We are using both path compression and union by size | ||
impl DisjointSetUnion { | ||
// Create n+1 sets [0, n] | ||
pub fn new(n: usize) -> DisjointSetUnion { | ||
let mut nodes = Vec::new(); | ||
nodes.reserve_exact(n + 1); | ||
for i in 0..=n { | ||
nodes.push(DSUNode { parent: i, size: 1 }); | ||
/// Initializes `num_elements + 1` disjoint sets; each element starts as its own parent. | ||
/// | ||
/// # Parameters | ||
/// | ||
/// - `num_elements`: The largest element index to manage (elements `0` to `num_elements` inclusive). | ||
/// | ||
/// # Returns | ||
/// | ||
/// A new instance of `DisjointSetUnion` with `num_elements + 1` independent sets. | ||
pub fn new(num_elements: usize) -> DisjointSetUnion { | ||
let mut nodes = Vec::with_capacity(num_elements + 1); | ||
for idx in 0..=num_elements { | ||
nodes.push(DSUNode { | ||
parent: idx, | ||
size: 1, | ||
}); | ||
} | ||
DisjointSetUnion { nodes } | ||
|
||
Self { nodes } | ||
} | ||
pub fn find_set(&mut self, v: usize) -> usize { | ||
if v == self.nodes[v].parent { | ||
return v; | ||
|
||
/// Finds the representative (root) of the set containing `element` with path compression. | ||
/// | ||
/// Path compression ensures that future queries are faster by directly linking | ||
/// all nodes in the path to the root. | ||
/// | ||
/// # Parameters | ||
/// | ||
/// - `element`: The element whose set representative is being found. | ||
/// | ||
/// # Returns | ||
/// | ||
/// The root representative of the set containing `element`. | ||
pub fn find_set(&mut self, element: usize) -> usize { | ||
if element != self.nodes[element].parent { | ||
self.nodes[element].parent = self.find_set(self.nodes[element].parent); | ||
} | ||
self.nodes[v].parent = self.find_set(self.nodes[v].parent); | ||
self.nodes[v].parent | ||
self.nodes[element].parent | ||
} | ||
// Returns the new component of the merged sets, | ||
// or usize::MAX if they were the same. | ||
pub fn merge(&mut self, u: usize, v: usize) -> usize { | ||
let mut a = self.find_set(u); | ||
let mut b = self.find_set(v); | ||
if a == b { | ||
|
||
/// Merges the sets containing `first_elem` and `sec_elem` using union by size. | ||
/// | ||
/// The smaller set is always attached to the root of the larger set to ensure balanced trees. | ||
/// | ||
/// # Parameters | ||
/// | ||
/// - `first_elem`: The first element whose set is to be merged. | ||
/// - `sec_elem`: The second element whose set is to be merged. | ||
/// | ||
/// # Returns | ||
/// | ||
/// The root of the merged set, or `usize::MAX` if both elements are already in the same set. | ||
pub fn merge(&mut self, first_elem: usize, sec_elem: usize) -> usize { | ||
let mut first_root = self.find_set(first_elem); | ||
let mut sec_root = self.find_set(sec_elem); | ||
|
||
if first_root == sec_root { | ||
// Already in the same set, no merge required | ||
return usize::MAX; | ||
} | ||
if self.nodes[a].size < self.nodes[b].size { | ||
std::mem::swap(&mut a, &mut b); | ||
|
||
// Union by size: attach the smaller tree under the larger tree | ||
if self.nodes[first_root].size < self.nodes[sec_root].size { | ||
std::mem::swap(&mut first_root, &mut sec_root); | ||
} | ||
self.nodes[b].parent = a; | ||
self.nodes[a].size += self.nodes[b].size; | ||
a | ||
|
||
self.nodes[sec_root].parent = first_root; | ||
self.nodes[first_root].size += self.nodes[sec_root].size; | ||
|
||
first_root | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn create_acyclic_graph() { | ||
fn test_disjoint_set_union() { | ||
let mut dsu = DisjointSetUnion::new(10); | ||
// Add edges such that vertices 1..=9 are connected | ||
// and vertex 10 is not connected to the other ones | ||
let edges: Vec<(usize, usize)> = vec![ | ||
(1, 2), // + | ||
(2, 1), | ||
(2, 3), // + | ||
(1, 3), | ||
(4, 5), // + | ||
(7, 8), // + | ||
(4, 8), // + | ||
(3, 8), // + | ||
(1, 9), // + | ||
(2, 9), | ||
(3, 9), | ||
(4, 9), | ||
(5, 9), | ||
(6, 9), // + | ||
(7, 9), | ||
]; | ||
let expected_edges: Vec<(usize, usize)> = vec![ | ||
(1, 2), | ||
(2, 3), | ||
(4, 5), | ||
(7, 8), | ||
(4, 8), | ||
(3, 8), | ||
(1, 9), | ||
(6, 9), | ||
]; | ||
let mut added_edges: Vec<(usize, usize)> = Vec::new(); | ||
for (u, v) in edges { | ||
if dsu.merge(u, v) < usize::MAX { | ||
added_edges.push((u, v)); | ||
} | ||
// Now they should be the same | ||
assert!(dsu.merge(u, v) == usize::MAX); | ||
} | ||
assert_eq!(added_edges, expected_edges); | ||
let comp_1 = dsu.find_set(1); | ||
for i in 2..=9 { | ||
assert_eq!(comp_1, dsu.find_set(i)); | ||
} | ||
assert_ne!(comp_1, dsu.find_set(10)); | ||
|
||
dsu.merge(1, 2); | ||
dsu.merge(2, 3); | ||
dsu.merge(1, 9); | ||
dsu.merge(4, 5); | ||
dsu.merge(7, 8); | ||
dsu.merge(4, 8); | ||
dsu.merge(6, 9); | ||
|
||
assert_eq!(dsu.find_set(1), dsu.find_set(2)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(3)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(6)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(9)); | ||
|
||
assert_eq!(dsu.find_set(4), dsu.find_set(5)); | ||
assert_eq!(dsu.find_set(4), dsu.find_set(7)); | ||
assert_eq!(dsu.find_set(4), dsu.find_set(8)); | ||
|
||
assert_ne!(dsu.find_set(1), dsu.find_set(10)); | ||
assert_ne!(dsu.find_set(4), dsu.find_set(10)); | ||
|
||
dsu.merge(3, 4); | ||
|
||
assert_eq!(dsu.find_set(1), dsu.find_set(2)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(3)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(6)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(9)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(4)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(5)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(7)); | ||
assert_eq!(dsu.find_set(1), dsu.find_set(8)); | ||
|
||
assert_ne!(dsu.find_set(1), dsu.find_set(10)); | ||
|
||
dsu.merge(10, 1); | ||
assert_eq!(dsu.find_set(10), dsu.find_set(1)); | ||
assert_eq!(dsu.find_set(10), dsu.find_set(2)); | ||
assert_eq!(dsu.find_set(10), dsu.find_set(3)); | ||
assert_eq!(dsu.find_set(10), dsu.find_set(4)); | ||
assert_eq!(dsu.find_set(10), dsu.find_set(5)); | ||
assert_eq!(dsu.find_set(10), dsu.find_set(6)); | ||
assert_eq!(dsu.find_set(10), dsu.find_set(7)); | ||
assert_eq!(dsu.find_set(10), dsu.find_set(8)); | ||
assert_eq!(dsu.find_set(10), dsu.find_set(9)); | ||
} | ||
} |