diff --git a/src/lib.rs b/src/lib.rs index 59108b8..f90502f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,19 +1,19 @@ // Standard library imports use std::collections::hash_map::IntoIter; use std::collections::{HashMap, HashSet}; +use std::fs::File; +use std::io::{BufWriter, Write}; // External crate imports use anyhow::{anyhow, Result}; use log::debug; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::PyResult; +use rayon::prelude::*; use sourmash::encodings::HashFunctions; use sourmash::signature::SeqToHashes; -use pyo3::PyResult; -use std::fs::File; -use std::io::{BufWriter, Write}; - // Set version variable const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -440,6 +440,55 @@ impl KmerCountTable { self.counts.insert(hashval, count); Ok(()) } + + // Jaccard + + /// Cosine similarity between two `KmerCountTable` objects. + /// + /// # Arguments + /// * `other` - The second `KmerCountTable` to compare against. + /// + /// # Returns + /// The cosine similarity between the two tables as a float value between 0 and 1. + pub fn cosine(&self, other: &KmerCountTable) -> f64 { + // Early return if either table is empty. + if self.counts.is_empty() || other.counts.is_empty() { + return 0.0; + } + + // Calculate the dot product in parallel. + let dot_product: u64 = self + .counts + .par_iter() + .filter_map(|(&hash, &count1)| { + // Only include in the dot product if both tables have the k-mer. + other.counts.get(&hash).map(|&count2| count1 * count2) + }) + .sum(); + + // Calculate magnitudes in parallel for both tables. + let magnitude_self: f64 = self + .counts + .par_iter() + .map(|(_, v)| (*v as f64).powi(2)) // Access the value, square it + .sum::() + .sqrt(); + + let magnitude_other: f64 = other + .counts + .par_iter() + .map(|(_, v)| (*v as f64).powi(2)) // Access the value, square it + .sum::() + .sqrt(); + + // If either magnitude is zero (no k-mers), return 0 to avoid division by zero. + if magnitude_self == 0.0 || magnitude_other == 0.0 { + return 0.0; + } + + // Calculate and return cosine similarity. + dot_product as f64 / (magnitude_self * magnitude_other) + } } // Iterator implementation for KmerCountTable