Skip to content

Commit

Permalink
Init add()
Browse files Browse the repository at this point in the history
  • Loading branch information
Adamtaranto committed Oct 2, 2024
1 parent 4abe6c2 commit 5942485
Showing 1 changed file with 148 additions and 0 deletions.
148 changes: 148 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::hash_map::IntoIter;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::sync::atomic::{AtomicU64, Ordering};
//use std::path::Path;

// External crate imports
Expand Down Expand Up @@ -759,6 +760,153 @@ impl KmerCountTable {
// Calculate and return cosine similarity.
dot_product as f64 / (magnitude_self * magnitude_other)
}

/// Add counts from another KmerCountTable to this one.
///
/// # Arguments
///
/// * `other` - The KmerCountTable to add from
///
/// # Returns
///
/// Returns a PyResult with a tuple containing:
/// * The number of k-mer counts added
/// * The number of new keys added to hash_to_kmer (if store_kmers is true)
#[pyo3(signature = (other))]
pub fn add(&mut self, other: &KmerCountTable) -> PyResult<(u64, u64)> {
// Check if ksizes match
if self.ksize != other.ksize {
return Err(PyValueError::new_err(
"KmerCountTables must have the same ksize",
));
}

// Use atomic counters for thread-safe updates
let total_counts_added = AtomicU64::new(0);
let new_keys_added = AtomicU64::new(0);

// Create a vector to store updates
let updates: Vec<_> = other
.counts
.par_iter()
.map(|(&hash, &count)| (hash, count))
.collect();

// Apply updates sequentially
for (hash, count) in updates {
let current_count = self.counts.entry(hash).or_insert(0);
*current_count += count;
total_counts_added.fetch_add(count, Ordering::Relaxed);
}

// Update consumed bases
self.consumed += other.consumed;

// Handle hash_to_kmer updates if store_kmers is true
if self.store_kmers {
if other.store_kmers {
// Both tables have store_kmers = true, so we can import
let new_kmers: Vec<_> = other
.hash_to_kmer
.as_ref()
.unwrap()
.par_iter()
.map(|(&hash, kmer)| (hash, kmer.clone()))
.collect();

for (hash, kmer) in new_kmers {
if self
.hash_to_kmer
.as_mut()
.unwrap()
.insert(hash, kmer)
.is_none()
{
new_keys_added.fetch_add(1, Ordering::Relaxed);
}
}
} else {
// Warning: incoming table doesn't store kmers
eprintln!("Warning: Incoming table does not store k-mers, but target table does. K-mer information for new hashes will be missing.");
}
}

// Get final counts
let total_added = total_counts_added.load(Ordering::Relaxed);
let new_keys = new_keys_added.load(Ordering::Relaxed);

// Print summary
println!("Added {} k-mer counts to the table", total_added);
if self.store_kmers {
println!("Added {} new keys to hash_to_kmer table", new_keys);
}

Ok((total_added, new_keys))
}

/// Subtract counts from another KmerCountTable from this one.
///
/// # Arguments
///
/// * `other` - The KmerCountTable to subtract
///
/// # Returns
///
/// Returns a PyResult with the number of k-mer counts removed from the table
#[pyo3(signature = (other))]
pub fn subtract(&mut self, other: &KmerCountTable) -> PyResult<u64> {
// Check if ksizes match
if self.ksize != other.ksize {
return Err(PyValueError::new_err(
"KmerCountTables must have the same ksize",
));
}

// Use atomic counter for thread-safe updates
let total_counts_removed = AtomicU64::new(0);

// Create a vector to store updates
let updates: Vec<_> = other
.counts
.par_iter()
.filter_map(|(&hash, &count)| {
// Only include hashes that exist in self.counts
self.counts
.get(&hash)
.map(|&self_count| (hash, count, self_count))
})
.collect();

// Apply updates sequentially
for (hash, other_count, self_count) in updates {
let new_count = if self_count > other_count {
self_count - other_count
} else {
0
};

let removed = self_count - new_count;
total_counts_removed.fetch_add(removed, Ordering::Relaxed);

if new_count == 0 {
self.counts.remove(&hash);
} else {
self.counts.insert(hash, new_count);
}
}

// Update consumed bases
// Ensure consumed doesn't go negative
self.consumed = self.consumed.saturating_sub(other.consumed);

// Get final count of removed k-mers
let total_removed = total_counts_removed.load(Ordering::Relaxed);

// Print summary
println!("Removed {} k-mer counts from the table", total_removed);

Ok(total_removed)
}
}

#[pyclass]
Expand Down

0 comments on commit 5942485

Please sign in to comment.