Skip to content

Commit

Permalink
start working on parallel table construction from sequence chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
ctb committed Oct 28, 2024
1 parent 667610b commit a2794cd
Showing 1 changed file with 93 additions and 5 deletions.
98 changes: 93 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ struct KmerCountTable {
hash_to_kmer: Option<HashMap<u64, String>>,
}

// CTB: convert to just a closure.
//
// Count the k-mers of `seq[start..end]` into a brand-new table.
//
// A fresh `KmerCountTable` is created with the given `ksize` and
// `store_kmers` settings, the byte range `start..end` of `seq` is fed to
// its `_consume`, and the populated table is handed back.
//
// The result is always `Some(..)`; failures do not produce `None` but panic:
//   * if `start..end` is not a valid slice range of `seq`, or
//   * if `_consume` returns an error ("fail in sub consume").
fn _do_consume(
    seq: &str,
    start: usize,
    end: usize,
    ksize: u8,
    store_kmers: bool,
    skip_bad_kmers: bool,
) -> Option<KmerCountTable> {
    let mut table = KmerCountTable::new(ksize, store_kmers);

    // Only the requested window of the sequence is counted.
    table
        ._consume(&seq[start..end], skip_bad_kmers)
        .expect("fail in sub consume");

    Some(table)
}

#[pymethods]
impl KmerCountTable {
/// Constructor for KmerCountTable
Expand Down Expand Up @@ -542,6 +559,10 @@ impl KmerCountTable {
// exit with error.
#[pyo3(signature = (seq, skip_bad_kmers=true))]
pub fn consume(&mut self, seq: String, skip_bad_kmers: bool) -> PyResult<u64> {
    // Python-facing entry point: takes ownership of the incoming string and
    // delegates to the borrowing implementation, returning its result as-is.
    let borrowed: &str = &seq;
    self._consume(borrowed, skip_bad_kmers)
}

fn _consume(&mut self, seq: &str, skip_bad_kmers: bool) -> PyResult<u64> {
// Incoming seq len
let new_len = seq.len();
// Init tally for consumed kmers
Expand Down Expand Up @@ -605,6 +626,73 @@ impl KmerCountTable {
Ok(n)
}

#[pyo3(signature = (seq, chunk_size, skip_bad_kmers=true))]
pub fn parallel_consume(
&mut self,
seq: String,
chunk_size: u64,
skip_bad_kmers: bool,
) -> PyResult<u64> {
// figure out the number of chunks, given the desired chunk size.
let seq_len = seq.len() as u64;
let mut num_chunks: u64 = seq_len / chunk_size;

let mut final_chunk: bool = false;
if seq_len % chunk_size > 0 {
num_chunks = num_chunks - 1;
final_chunk = true;
}

// build a vec of (start, end) pairs.
let mut coord_pairs: Vec<(u64, u64)> = vec![];

for i in 0..num_chunks {
let start = i * chunk_size;
let end = (i + 1) * chunk_size;
coord_pairs.push((start, end));
}
if final_chunk {
coord_pairs.push((num_chunks * chunk_size, seq_len));
}

eprintln!("chunk size: {}, num chunks: {}", chunk_size, num_chunks);
eprintln!("{:?}", coord_pairs);

// create reference to seq
let s = seq.as_str();

// build KmerCountTables in parallel
let tables: Vec<KmerCountTable> = coord_pairs
.par_iter()
.filter_map(|(start, end)| {
_do_consume(
s,
*start as usize,
*end as usize,
self.ksize,
self.store_kmers,
skip_bad_kmers,
)
})
.collect();

// now, merge the tables.
let mut total_consumed = 0;
for t in tables.into_iter() {
self.counts.extend(t.counts);

if self.store_kmers {
let my_hash_to_kmer = self.hash_to_kmer.as_mut().unwrap();
let t_hash_to_kmer = t.hash_to_kmer.expect("hash_to_kmer is None!?");
my_hash_to_kmer.extend(t_hash_to_kmer);
}
total_consumed += t.consumed;
}
self.consumed = total_consumed;

Ok(total_consumed)
}

// Helper method to get hash set of k-mers
fn hash_set(&self) -> HashSet<u64> {
self.counts.keys().cloned().collect()
Expand Down Expand Up @@ -688,7 +776,7 @@ impl KmerCountTable {
let mut v: Vec<(String, u64)> = vec![];

// Create the iterator
let mut iter = KmersAndHashesIter::new(seq, self.ksize as usize, skip_bad_kmers);
let mut iter = KmersAndHashesIter::new(seq.as_str(), self.ksize as usize, skip_bad_kmers);

// Collect the k-mers and their hashes
while let Some(result) = iter.next() {
Expand Down Expand Up @@ -778,7 +866,7 @@ impl KmerCountTableIterator {
}

pub struct KmersAndHashesIter {
seq: String, // The sequence to iterate over
seq: String, // The sequence to iterate over
seq_rc: String, // reverse complement sequence
ksize: usize, // K-mer size
pos: usize, // Current position in the sequence
Expand All @@ -788,9 +876,9 @@ pub struct KmersAndHashesIter {
}

impl KmersAndHashesIter {
pub fn new(seq: String, ksize: usize, skip_bad_kmers: bool) -> Self {
pub fn new(seq: &str, ksize: usize, skip_bad_kmers: bool) -> Self {
let seq = seq.to_ascii_uppercase(); // Ensure uppercase for uniformity
let seqb = seq.as_bytes().to_vec(); // Convert to bytes for hashing
let seqb = seq.as_bytes().to_vec(); // Convert to bytes for revcomp
let seqb_rc = revcomp(&seqb);
let seq_rc = std::str::from_utf8(&seqb_rc)
.expect("invalid utf-8 sequence for rev comp")
Expand Down Expand Up @@ -833,7 +921,7 @@ impl Iterator for KmersAndHashesIter {

// Extract the current k-mer and its reverse complement
let substr = &self.seq[start..start + ksize];
let substr_rc = &self.seq_rc[rpos..rpos+ksize];
let substr_rc = &self.seq_rc[rpos..rpos + ksize];

// Get the next hash value from the hasher
let hashval = self.hasher.next().expect("should not run out of hashes");
Expand Down

0 comments on commit a2794cd

Please sign in to comment.