From 988bea632296bac556276aa7833bc6791af8aa2a Mon Sep 17 00:00:00 2001 From: Truong Nhan Nguyen <80200848+sozelfist@users.noreply.github.com> Date: Thu, 7 Nov 2024 04:25:20 +0700 Subject: [PATCH] Refactor Minimum Spanning Tree (#837) * ref: refactor mst * chore: fix clippy --- src/graph/minimum_spanning_tree.rs | 241 ++++++++++++++++------------- 1 file changed, 131 insertions(+), 110 deletions(-) diff --git a/src/graph/minimum_spanning_tree.rs b/src/graph/minimum_spanning_tree.rs index 5efcb625e0c..9d36cafb303 100644 --- a/src/graph/minimum_spanning_tree.rs +++ b/src/graph/minimum_spanning_tree.rs @@ -1,24 +1,22 @@ -use super::DisjointSetUnion; +//! This module implements Kruskal's algorithm to find the Minimum Spanning Tree (MST) +//! of an undirected, weighted graph using a Disjoint Set Union (DSU) for cycle detection. -#[derive(Debug)] -pub struct Edge { - source: i64, - destination: i64, - cost: i64, -} +use crate::graph::DisjointSetUnion; -impl PartialEq for Edge { - fn eq(&self, other: &Self) -> bool { - self.source == other.source - && self.destination == other.destination - && self.cost == other.cost - } +/// Represents an edge in the graph with a source, destination, and associated cost. +#[derive(Debug, PartialEq, Eq)] +pub struct Edge { + /// The starting vertex of the edge. + source: usize, + /// The ending vertex of the edge. + destination: usize, + /// The cost associated with the edge. + cost: usize, } -impl Eq for Edge {} - impl Edge { - fn new(source: i64, destination: i64, cost: i64) -> Self { + /// Creates a new edge with the specified source, destination, and cost. + pub fn new(source: usize, destination: usize, cost: usize) -> Self { Self { source, destination, @@ -27,112 +25,135 @@ impl Edge { } } -pub fn kruskal(mut edges: Vec, number_of_vertices: i64) -> (i64, Vec) { - let mut dsu = DisjointSetUnion::new(number_of_vertices as usize); - - edges.sort_unstable_by(|a, b| a.cost.cmp(&b.cost)); - let mut total_cost: i64 = 0; - let mut final_edges: Vec = Vec::new(); - let mut merge_count: i64 = 0; - for edge in edges.iter() { - if merge_count >= number_of_vertices - 1 { +/// Executes Kruskal's algorithm to compute the Minimum Spanning Tree (MST) of a graph. +/// +/// # Parameters +/// +/// - `edges`: A vector of `Edge` instances representing all edges in the graph. +/// - `num_vertices`: The total number of vertices in the graph. +/// +/// # Returns +/// +/// An `Option` containing a tuple with: +/// +/// - The total cost of the MST (usize). +/// - A vector of edges that are included in the MST. +/// +/// Returns `None` if the graph is disconnected. +/// +/// # Complexity +/// +/// The time complexity is O(E log E), where E is the number of edges. +pub fn kruskal(mut edges: Vec, num_vertices: usize) -> Option<(usize, Vec)> { + let mut dsu = DisjointSetUnion::new(num_vertices); + let mut mst_cost: usize = 0; + let mut mst_edges: Vec = Vec::with_capacity(num_vertices - 1); + + // Sort edges by cost in ascending order + edges.sort_unstable_by_key(|edge| edge.cost); + + for edge in edges { + if mst_edges.len() == num_vertices - 1 { break; } - let source: i64 = edge.source; - let destination: i64 = edge.destination; - if dsu.merge(source as usize, destination as usize) < usize::MAX { - merge_count += 1; - let cost: i64 = edge.cost; - total_cost += cost; - let final_edge: Edge = Edge::new(source, destination, cost); - final_edges.push(final_edge); + // Attempt to merge the sets containing the edge’s vertices + if dsu.merge(edge.source, edge.destination) != usize::MAX { + mst_cost += edge.cost; + mst_edges.push(edge); } } - (total_cost, final_edges) + + // Return MST if it includes exactly num_vertices - 1 edges, otherwise None for disconnected graphs + (mst_edges.len() == num_vertices - 1).then_some((mst_cost, mst_edges)) } #[cfg(test)] mod tests { use super::*; - #[test] - fn test_seven_vertices_eleven_edges() { - let edges = vec![ - Edge::new(0, 1, 7), - Edge::new(0, 3, 5), - Edge::new(1, 2, 8), - Edge::new(1, 3, 9), - Edge::new(1, 4, 7), - Edge::new(2, 4, 5), - Edge::new(3, 4, 15), - Edge::new(3, 5, 6), - Edge::new(4, 5, 8), - Edge::new(4, 6, 9), - Edge::new(5, 6, 11), - ]; - - let number_of_vertices: i64 = 7; - - let expected_total_cost = 39; - let expected_used_edges = vec![ - Edge::new(0, 3, 5), - Edge::new(2, 4, 5), - Edge::new(3, 5, 6), - Edge::new(0, 1, 7), - Edge::new(1, 4, 7), - Edge::new(4, 6, 9), - ]; - - let (actual_total_cost, actual_final_edges) = kruskal(edges, number_of_vertices); - - assert_eq!(actual_total_cost, expected_total_cost); - assert_eq!(actual_final_edges, expected_used_edges); + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (edges, num_vertices, expected_result) = $test_case; + let actual_result = kruskal(edges, num_vertices); + assert_eq!(actual_result, expected_result); + } + )* + }; } - #[test] - fn test_ten_vertices_twenty_edges() { - let edges = vec![ - Edge::new(0, 1, 3), - Edge::new(0, 3, 6), - Edge::new(0, 4, 9), - Edge::new(1, 2, 2), - Edge::new(1, 3, 4), - Edge::new(1, 4, 9), - Edge::new(2, 3, 2), - Edge::new(2, 5, 8), - Edge::new(2, 6, 9), - Edge::new(3, 6, 9), - Edge::new(4, 5, 8), - Edge::new(4, 9, 18), - Edge::new(5, 6, 7), - Edge::new(5, 8, 9), - Edge::new(5, 9, 10), - Edge::new(6, 7, 4), - Edge::new(6, 8, 5), - Edge::new(7, 8, 1), - Edge::new(7, 9, 4), - Edge::new(8, 9, 3), - ]; - - let number_of_vertices: i64 = 10; - - let expected_total_cost = 38; - let expected_used_edges = vec![ - Edge::new(7, 8, 1), - Edge::new(1, 2, 2), - Edge::new(2, 3, 2), - Edge::new(0, 1, 3), - Edge::new(8, 9, 3), - Edge::new(6, 7, 4), - Edge::new(5, 6, 7), - Edge::new(2, 5, 8), - Edge::new(4, 5, 8), - ]; - - let (actual_total_cost, actual_final_edges) = kruskal(edges, number_of_vertices); - - assert_eq!(actual_total_cost, expected_total_cost); - assert_eq!(actual_final_edges, expected_used_edges); + test_cases! { + test_seven_vertices_eleven_edges: ( + vec![ + Edge::new(0, 1, 7), + Edge::new(0, 3, 5), + Edge::new(1, 2, 8), + Edge::new(1, 3, 9), + Edge::new(1, 4, 7), + Edge::new(2, 4, 5), + Edge::new(3, 4, 15), + Edge::new(3, 5, 6), + Edge::new(4, 5, 8), + Edge::new(4, 6, 9), + Edge::new(5, 6, 11), + ], + 7, + Some((39, vec![ + Edge::new(0, 3, 5), + Edge::new(2, 4, 5), + Edge::new(3, 5, 6), + Edge::new(0, 1, 7), + Edge::new(1, 4, 7), + Edge::new(4, 6, 9), + ])) + ), + test_ten_vertices_twenty_edges: ( + vec![ + Edge::new(0, 1, 3), + Edge::new(0, 3, 6), + Edge::new(0, 4, 9), + Edge::new(1, 2, 2), + Edge::new(1, 3, 4), + Edge::new(1, 4, 9), + Edge::new(2, 3, 2), + Edge::new(2, 5, 8), + Edge::new(2, 6, 9), + Edge::new(3, 6, 9), + Edge::new(4, 5, 8), + Edge::new(4, 9, 18), + Edge::new(5, 6, 7), + Edge::new(5, 8, 9), + Edge::new(5, 9, 10), + Edge::new(6, 7, 4), + Edge::new(6, 8, 5), + Edge::new(7, 8, 1), + Edge::new(7, 9, 4), + Edge::new(8, 9, 3), + ], + 10, + Some((38, vec![ + Edge::new(7, 8, 1), + Edge::new(1, 2, 2), + Edge::new(2, 3, 2), + Edge::new(0, 1, 3), + Edge::new(8, 9, 3), + Edge::new(6, 7, 4), + Edge::new(5, 6, 7), + Edge::new(2, 5, 8), + Edge::new(4, 5, 8), + ])) + ), + test_disconnected_graph: ( + vec![ + Edge::new(0, 1, 4), + Edge::new(0, 2, 6), + Edge::new(3, 4, 2), + ], + 5, + None + ), } }