Skip to content

Commit

Permalink
Refactor Disjoint Set Union (#802)
Browse files Browse the repository at this point in the history
* ref: refactor disjoint set union
- replaced `reserve_exact()` with `Vec::with_capacity()`, a more idiomatic approach for reserving memory for the nodes vector.
- add documentation
- rewrite tests

* style: use `Self`

https://rust-lang.github.io/rust-clippy/master/index.html#indexing_slicing

---------

Co-authored-by: Piotr Idzik <[email protected]>
  • Loading branch information
sozelfist and vil02 authored Oct 3, 2024
1 parent bc8d6fa commit be27f2c
Showing 1 changed file with 121 additions and 68 deletions.
189 changes: 121 additions & 68 deletions src/graph/disjoint_set_union.rs
Original file line number Diff line number Diff line change
@@ -1,95 +1,148 @@
//! This module implements the Disjoint Set Union (DSU), also known as Union-Find,
//! which is an efficient data structure for keeping track of a set of elements
//! partitioned into disjoint (non-overlapping) subsets.

/// Represents a node in the Disjoint Set Union (DSU) structure, which
/// keeps track of the parent-child relationships in the disjoint sets.
#[derive(Debug, Clone)]
pub struct DSUNode {
    /// The index of the node's parent, or itself if it's the root.
    parent: usize,
    /// The size of the set rooted at this node, used for union by size.
    size: usize,
}

/// Disjoint Set Union (Union-Find) data structure, particularly useful for
/// managing dynamic connectivity problems such as determining
/// if two elements are in the same subset or merging two subsets.
#[derive(Debug, Clone)]
pub struct DisjointSetUnion {
    /// List of DSU nodes where each element's parent and size are tracked.
    /// The element's value is its index in this vector.
    nodes: Vec<DSUNode>,
}

// We are using both path compression and union by size
impl DisjointSetUnion {
// Create n+1 sets [0, n]
pub fn new(n: usize) -> DisjointSetUnion {
let mut nodes = Vec::new();
nodes.reserve_exact(n + 1);
for i in 0..=n {
nodes.push(DSUNode { parent: i, size: 1 });
/// Initializes `n + 1` disjoint sets, each element is its own parent.
///
/// # Parameters
///
/// - `n`: The number of elements to manage (`0` to `n` inclusive).
///
/// # Returns
///
/// A new instance of `DisjointSetUnion` with `n + 1` independent sets.
pub fn new(num_elements: usize) -> DisjointSetUnion {
let mut nodes = Vec::with_capacity(num_elements + 1);
for idx in 0..=num_elements {
nodes.push(DSUNode {
parent: idx,
size: 1,
});
}
DisjointSetUnion { nodes }

Self { nodes }
}
pub fn find_set(&mut self, v: usize) -> usize {
if v == self.nodes[v].parent {
return v;

/// Finds the representative (root) of the set containing `element` with path compression.
///
/// Path compression ensures that future queries are faster by directly linking
/// all nodes in the path to the root.
///
/// # Parameters
///
/// - `element`: The element whose set representative is being found.
///
/// # Returns
///
/// The root representative of the set containing `element`.
pub fn find_set(&mut self, element: usize) -> usize {
if element != self.nodes[element].parent {
self.nodes[element].parent = self.find_set(self.nodes[element].parent);
}
self.nodes[v].parent = self.find_set(self.nodes[v].parent);
self.nodes[v].parent
self.nodes[element].parent
}
// Returns the new component of the merged sets,
// or usize::MAX if they were the same.
pub fn merge(&mut self, u: usize, v: usize) -> usize {
let mut a = self.find_set(u);
let mut b = self.find_set(v);
if a == b {

/// Merges the sets containing `first_elem` and `sec_elem` using union by size.
///
/// The smaller set is always attached to the root of the larger set to ensure balanced trees.
///
/// # Parameters
///
/// - `first_elem`: The first element whose set is to be merged.
/// - `sec_elem`: The second element whose set is to be merged.
///
/// # Returns
///
/// The root of the merged set, or `usize::MAX` if both elements are already in the same set.
pub fn merge(&mut self, first_elem: usize, sec_elem: usize) -> usize {
let mut first_root = self.find_set(first_elem);
let mut sec_root = self.find_set(sec_elem);

if first_root == sec_root {
// Already in the same set, no merge required
return usize::MAX;
}
if self.nodes[a].size < self.nodes[b].size {
std::mem::swap(&mut a, &mut b);

// Union by size: attach the smaller tree under the larger tree
if self.nodes[first_root].size < self.nodes[sec_root].size {
std::mem::swap(&mut first_root, &mut sec_root);
}
self.nodes[b].parent = a;
self.nodes[a].size += self.nodes[b].size;
a

self.nodes[sec_root].parent = first_root;
self.nodes[first_root].size += self.nodes[sec_root].size;

first_root
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds a list of edges through `merge` and checks its return value:
    /// only edges joining two different sets report a root, and repeating
    /// any edge reports `usize::MAX`.
    #[test]
    fn create_acyclic_graph() {
        let mut dsu = DisjointSetUnion::new(10);
        // Add edges such that vertices 1..=9 are connected
        // and vertex 10 is not connected to the other ones
        let edges: [(usize, usize); 15] = [
            (1, 2), // +
            (2, 1),
            (2, 3), // +
            (1, 3),
            (4, 5), // +
            (7, 8), // +
            (4, 8), // +
            (3, 8), // +
            (1, 9), // +
            (2, 9),
            (3, 9),
            (4, 9),
            (5, 9),
            (6, 9), // +
            (7, 9),
        ];
        let expected_edges: Vec<(usize, usize)> = vec![
            (1, 2),
            (2, 3),
            (4, 5),
            (7, 8),
            (4, 8),
            (3, 8),
            (1, 9),
            (6, 9),
        ];
        let mut added_edges: Vec<(usize, usize)> = Vec::new();
        for (u, v) in edges {
            if dsu.merge(u, v) != usize::MAX {
                added_edges.push((u, v));
            }
            // Now they should be the same
            assert_eq!(dsu.merge(u, v), usize::MAX);
        }
        assert_eq!(added_edges, expected_edges);
        let comp_1 = dsu.find_set(1);
        for i in 2..=9 {
            assert_eq!(comp_1, dsu.find_set(i));
        }
        assert_ne!(comp_1, dsu.find_set(10));
    }

    /// Builds two separate components, joins them, then joins the last
    /// isolated element, checking set membership after every stage.
    #[test]
    fn test_disjoint_set_union() {
        let mut dsu = DisjointSetUnion::new(10);

        dsu.merge(1, 2);
        dsu.merge(2, 3);
        dsu.merge(1, 9);
        dsu.merge(4, 5);
        dsu.merge(7, 8);
        dsu.merge(4, 8);
        dsu.merge(6, 9);

        assert_eq!(dsu.find_set(1), dsu.find_set(2));
        assert_eq!(dsu.find_set(1), dsu.find_set(3));
        assert_eq!(dsu.find_set(1), dsu.find_set(6));
        assert_eq!(dsu.find_set(1), dsu.find_set(9));

        assert_eq!(dsu.find_set(4), dsu.find_set(5));
        assert_eq!(dsu.find_set(4), dsu.find_set(7));
        assert_eq!(dsu.find_set(4), dsu.find_set(8));

        assert_ne!(dsu.find_set(1), dsu.find_set(10));
        assert_ne!(dsu.find_set(4), dsu.find_set(10));

        dsu.merge(3, 4);

        assert_eq!(dsu.find_set(1), dsu.find_set(2));
        assert_eq!(dsu.find_set(1), dsu.find_set(3));
        assert_eq!(dsu.find_set(1), dsu.find_set(6));
        assert_eq!(dsu.find_set(1), dsu.find_set(9));
        assert_eq!(dsu.find_set(1), dsu.find_set(4));
        assert_eq!(dsu.find_set(1), dsu.find_set(5));
        assert_eq!(dsu.find_set(1), dsu.find_set(7));
        assert_eq!(dsu.find_set(1), dsu.find_set(8));

        assert_ne!(dsu.find_set(1), dsu.find_set(10));

        dsu.merge(10, 1);
        assert_eq!(dsu.find_set(10), dsu.find_set(1));
        assert_eq!(dsu.find_set(10), dsu.find_set(2));
        assert_eq!(dsu.find_set(10), dsu.find_set(3));
        assert_eq!(dsu.find_set(10), dsu.find_set(4));
        assert_eq!(dsu.find_set(10), dsu.find_set(5));
        assert_eq!(dsu.find_set(10), dsu.find_set(6));
        assert_eq!(dsu.find_set(10), dsu.find_set(7));
        assert_eq!(dsu.find_set(10), dsu.find_set(8));
        assert_eq!(dsu.find_set(10), dsu.find_set(9));
    }
}

0 comments on commit be27f2c

Please sign in to comment.