From 2571e490bcef99c98698564578088aa4ebd35eec Mon Sep 17 00:00:00 2001 From: Adam Taranto Date: Mon, 14 Oct 2024 15:13:58 +1100 Subject: [PATCH] parallelize hash_to_kmer update --- src/lib.rs | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 137632a..71810e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,10 @@ use std::collections::hash_map::IntoIter; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{BufReader, BufWriter, Write}; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Mutex, +}; //use std::path::Path; // External crate imports @@ -774,9 +777,18 @@ impl KmerCountTable { /// /// Returns a PyResult with a tuple containing: /// * The number of k-mer counts added - /// * The number of new keys added to hash_to_kmer (if store_kmers is true) + /// * The number of new keys added #[pyo3(signature = (other))] pub fn add(&mut self, other: &KmerCountTable) -> PyResult<(u64, u64)> { + // TODO: Borrow error raised when initialising "&mut self, other: &KmerCountTable" + // on same object. Below doesn't get a chance to run. + // Check if the other table is the same object as self + //if std::ptr::eq(self, other) { + // return Err(PyValueError::new_err( + // "Cannot add KmerCountTable to itself.", + // )); + //} + // Check if ksizes match if self.ksize != other.ksize { return Err(PyValueError::new_err( @@ -813,21 +825,19 @@ impl KmerCountTable { if self.store_kmers { if other.store_kmers { // Both tables have store_kmers = true, so we can import - let new_kmers: Vec<_> = other + let hash_to_kmer_mutex = Mutex::new(self.hash_to_kmer.as_mut().unwrap()); + + other .hash_to_kmer .as_ref() .unwrap() .par_iter() - .map(|(&hash, kmer)| (hash, kmer.clone())) - .collect(); - - for (hash, kmer) in new_kmers { - self.hash_to_kmer - .as_mut() - .unwrap() - .entry(hash) - .or_insert(kmer); - } + .for_each(|(&hash, kmer)| { + let mut hash_to_kmer_lock = hash_to_kmer_mutex.lock().unwrap(); + hash_to_kmer_lock + .entry(hash) + .or_insert_with(|| kmer.clone()); + }); } else { // Warning: incoming table doesn't store kmers eprintln!("Warning: Incoming table does not store k-mers, but target table does. K-mer information for new hashes will be missing."); @@ -843,7 +853,7 @@ impl KmerCountTable { println!("Added {} new keys to the table", new_keys); Ok((total_added, new_keys)) - } + } } #[pyclass]