diff --git a/src/lib.rs b/src/lib.rs index 9d08d59..fb122ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,6 +34,23 @@ struct KmerCountTable { hash_to_kmer: Option>, } +// CTB: convert to just a closure. +fn _do_consume( + seq: &str, + start: usize, + end: usize, + ksize: u8, + store_kmers: bool, + skip_bad_kmers: bool, +) -> Option { + let mut t = KmerCountTable::new(ksize, store_kmers); + + let subseq = &seq[start..end]; + t._consume(subseq, skip_bad_kmers) + .expect("fail in sub consume"); + Some(t) +} + #[pymethods] impl KmerCountTable { /// Constructor for KmerCountTable @@ -542,6 +559,10 @@ impl KmerCountTable { // exit with error. #[pyo3(signature = (seq, skip_bad_kmers=true))] pub fn consume(&mut self, seq: String, skip_bad_kmers: bool) -> PyResult { + self._consume(seq.as_str(), skip_bad_kmers) + } + + fn _consume(&mut self, seq: &str, skip_bad_kmers: bool) -> PyResult { // Incoming seq len let new_len = seq.len(); // Init tally for consumed kmers @@ -605,6 +626,73 @@ impl KmerCountTable { Ok(n) } + #[pyo3(signature = (seq, chunk_size, skip_bad_kmers=true))] + pub fn parallel_consume( + &mut self, + seq: String, + chunk_size: u64, + skip_bad_kmers: bool, + ) -> PyResult { + // figure out the number of chunks, given the desired chunk size. + let seq_len = seq.len() as u64; + let mut num_chunks: u64 = seq_len / chunk_size; + + let mut final_chunk: bool = false; + if seq_len % chunk_size > 0 { + num_chunks = num_chunks - 1; + final_chunk = true; + } + + // build a vec of (start, end) pairs. + let mut coord_pairs: Vec<(u64, u64)> = vec![]; + + for i in 0..num_chunks { + let start = i * chunk_size; + let end = (i + 1) * chunk_size; + coord_pairs.push((start, end)); + } + if final_chunk { + coord_pairs.push((num_chunks * chunk_size, seq_len)); + } + + eprintln!("chunk size: {}, num chunks: {}", chunk_size, num_chunks); + eprintln!("{:?}", coord_pairs); + + // create reference to seq + let s = seq.as_str(); + + // build KmerCountTables in parallel + let tables: Vec = coord_pairs + .par_iter() + .filter_map(|(start, end)| { + _do_consume( + s, + *start as usize, + *end as usize, + self.ksize, + self.store_kmers, + skip_bad_kmers, + ) + }) + .collect(); + + // now, merge the tables. + let mut total_consumed = 0; + for t in tables.into_iter() { + self.counts.extend(t.counts); + + if self.store_kmers { + let my_hash_to_kmer = self.hash_to_kmer.as_mut().unwrap(); + let t_hash_to_kmer = t.hash_to_kmer.expect("hash_to_kmer is None!?"); + my_hash_to_kmer.extend(t_hash_to_kmer); + } + total_consumed += t.consumed; + } + self.consumed = total_consumed; + + Ok(total_consumed) + } + // Helper method to get hash set of k-mers fn hash_set(&self) -> HashSet { self.counts.keys().cloned().collect() @@ -688,7 +776,7 @@ impl KmerCountTable { let mut v: Vec<(String, u64)> = vec![]; // Create the iterator - let mut iter = KmersAndHashesIter::new(seq, self.ksize as usize, skip_bad_kmers); + let mut iter = KmersAndHashesIter::new(seq.as_str(), self.ksize as usize, skip_bad_kmers); // Collect the k-mers and their hashes while let Some(result) = iter.next() { @@ -778,7 +866,7 @@ impl KmerCountTableIterator { } pub struct KmersAndHashesIter { - seq: String, // The sequence to iterate over + seq: String, // The sequence to iterate over seq_rc: String, // reverse complement sequence ksize: usize, // K-mer size pos: usize, // Current position in the sequence @@ -788,9 +876,9 @@ pub struct KmersAndHashesIter { } impl KmersAndHashesIter { - pub fn new(seq: String, ksize: usize, skip_bad_kmers: bool) -> Self { + pub fn new(seq: &str, ksize: usize, skip_bad_kmers: bool) -> Self { let seq = seq.to_ascii_uppercase(); // Ensure uppercase for uniformity - let seqb = seq.as_bytes().to_vec(); // Convert to bytes for hashing + let seqb = seq.as_bytes().to_vec(); // Convert to bytes for revcomp let seqb_rc = revcomp(&seqb); let seq_rc = std::str::from_utf8(&seqb_rc) .expect("invalid utf-8 sequence for rev comp") @@ -833,7 +921,7 @@ impl Iterator for KmersAndHashesIter { // Extract the current k-mer and its reverse complement let substr = &self.seq[start..start + ksize]; - let substr_rc = &self.seq_rc[rpos..rpos+ksize]; + let substr_rc = &self.seq_rc[rpos..rpos + ksize]; // Get the next hash value from the hasher let hashval = self.hasher.next().expect("should not run out of hashes");