Skip to content

Commit

Permalink
parallelize hash_to_kmer update
Browse files Browse the repository at this point in the history
  • Loading branch information
Adamtaranto committed Oct 14, 2024
1 parent f79f64e commit 2571e49
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::collections::hash_map::IntoIter;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{
atomic::{AtomicU64, Ordering},
Mutex,
};
//use std::path::Path;

// External crate imports
Expand Down Expand Up @@ -774,9 +777,18 @@ impl KmerCountTable {
///
/// Returns a PyResult with a tuple containing:
/// * The number of k-mer counts added
/// * The number of new keys added to hash_to_kmer (if store_kmers is true)
/// * The number of new keys added
#[pyo3(signature = (other))]
pub fn add(&mut self, other: &KmerCountTable) -> PyResult<(u64, u64)> {
// TODO: A borrow error is raised when `&mut self` and `other: &KmerCountTable`
// refer to the same object, so the check below never gets a chance to run.
// Check if the other table is the same object as self
//if std::ptr::eq(self, other) {
// return Err(PyValueError::new_err(
// "Cannot add KmerCountTable to itself.",
// ));
//}

// Check if ksizes match
if self.ksize != other.ksize {
return Err(PyValueError::new_err(
Expand Down Expand Up @@ -813,21 +825,19 @@ impl KmerCountTable {
if self.store_kmers {
if other.store_kmers {
// Both tables have store_kmers = true, so we can import
let new_kmers: Vec<_> = other
let hash_to_kmer_mutex = Mutex::new(self.hash_to_kmer.as_mut().unwrap());

other
.hash_to_kmer
.as_ref()
.unwrap()
.par_iter()
.map(|(&hash, kmer)| (hash, kmer.clone()))
.collect();

for (hash, kmer) in new_kmers {
self.hash_to_kmer
.as_mut()
.unwrap()
.entry(hash)
.or_insert(kmer);
}
.for_each(|(&hash, kmer)| {
let mut hash_to_kmer_lock = hash_to_kmer_mutex.lock().unwrap();
hash_to_kmer_lock
.entry(hash)
.or_insert_with(|| kmer.clone());
});
} else {
// Warning: incoming table doesn't store kmers
eprintln!("Warning: Incoming table does not store k-mers, but target table does. K-mer information for new hashes will be missing.");
Expand All @@ -843,7 +853,7 @@ impl KmerCountTable {
println!("Added {} new keys to the table", new_keys);

Ok((total_added, new_keys))
}
}
}

#[pyclass]
Expand Down

0 comments on commit 2571e49

Please sign in to comment.