diff --git a/src/lib.rs b/src/lib.rs index 0ca4c29..9971343 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -509,53 +509,16 @@ impl KmerCountTable { seq: String, skip_bad_kmers: bool, ) -> PyResult> { - // TODO: optimize RC calculation - // TODO: confirm that there are no more hashes left? unreachable? - let seq = seq.to_ascii_uppercase(); - let seqb = seq.as_bytes(); - - let mut hasher = SeqToHashes::new( - seqb, - self.ksize.into(), - skip_bad_kmers, - false, - HashFunctions::Murmur64Dna, - 42, - ); + let mut v: Vec<(String, u64)> = vec![]; - let ksize = self.ksize as usize; - let end: usize = seq.len() - ksize + 1; + // Create the iterator + let mut iter = KmersAndHashesIter::new(seq, self.ksize as usize, skip_bad_kmers); - let mut v: Vec<(String, u64)> = vec![]; - for start in 0..end { - let substr = &seq[start..start + ksize]; - // CTB: this calculates RC each time, instead of doing so - // using a sliding window. It's easy and works, so I'm - // starting here :). - let substr_b_rc = revcomp(&seqb[start..start + ksize]); - let substr_rc = - std::str::from_utf8(&substr_b_rc).expect("invalid utf-8 sequence for rev comp"); - let hashval = hasher.next().expect("should not run out of hashes"); - - // Three options: - // * good kmer, all is well, store canonical k-mer and hashval; - // * bad k-mer allowed by skip_bad_kmers, and signaled by - // hashval == 0): return empty string & 0; - // * bad k-mer not allowed, raise error - if let Ok(hashval) = hashval { - if hashval > 0 { - let canonical_kmer = if substr < substr_rc { - substr - } else { - substr_rc - }; - v.push((canonical_kmer.to_string(), hashval)); - } else { - v.push(("".to_owned(), 0)); - } - } else { - let msg = format!("bad k-mer at position {}: {}", start, substr); - return Err(PyValueError::new_err(msg)); + // Collect the k-mers and their hashes + while let Some(result) = iter.next() { + match result { + Ok((kmer, hash)) => v.push((kmer, hash)), + Err(e) => return Err(e), } } @@ -638,6 +601,104 @@ impl KmerCountTableIterator { } } +pub struct KmersAndHashesIter { + seq: String, // The sequence to iterate over + seqb: Vec, // Sequence bytes + ksize: usize, // K-mer size + pos: usize, // Current position in the sequence + end: usize, // The end position for k-mer extraction + hasher: SeqToHashes, // Iterator for generating hashes + skip_bad_kmers: bool, // Flag to skip bad k-mers +} + +impl KmersAndHashesIter { + pub fn new(seq: String, ksize: usize, skip_bad_kmers: bool) -> Self { + let seq = seq.to_ascii_uppercase(); // Ensure uppercase for uniformity + let seqb = seq.as_bytes().to_vec(); // Convert to bytes for hashing + let end = seq.len() - ksize + 1; // Calculate the endpoint for k-mer extraction + + let hasher = SeqToHashes::new( + &seqb, + ksize.into(), + skip_bad_kmers, + false, // Other flags, e.g., reverse complement + HashFunctions::Murmur64Dna, + 42, // Seed for hashing + ); + + Self { + seq, + seqb, + ksize, + pos: 0, // Start at the beginning of the sequence + end, + hasher, + skip_bad_kmers, + } + } +} + +impl Iterator for KmersAndHashesIter { + type Item = PyResult<(String, u64)>; + + fn next(&mut self) -> Option { + // Check if we've reached the end of the sequence + if self.pos >= self.end { + return None; + } + + let start = self.pos; + let ksize = self.ksize; + + // Extract the current k-mer and its reverse complement + let substr = &self.seq[start..start + ksize]; + // CTB: this calculates RC each time, instead of doing so + // using a sliding window. It's easy and works, so I'm + // starting here :). + let substr_b_rc = revcomp(&self.seqb[start..start + ksize]); + let substr_rc = + std::str::from_utf8(&substr_b_rc).expect("invalid utf-8 sequence for rev comp"); + + // Get the next hash value from the hasher + let hashval = self.hasher.next().expect("should not run out of hashes"); + + // Increment position for the next k-mer + self.pos += 1; + + // Handle hash value logic (similar to original function) + // Three options: + // * good kmer, all is well, store canonical k-mer and hashval; + // * bad k-mer allowed by skip_bad_kmers, and signaled by + // hashval == 0): return empty string & 0; + // * bad k-mer not allowed, raise error + if let Ok(hashval) = hashval { + if hashval > 0 { + // Select the canonical k-mer (lexicographically smaller between forward and reverse complement) + let canonical_kmer = if substr < substr_rc { + substr + } else { + substr_rc + }; + Some(Ok((canonical_kmer.to_string(), hashval))) + } else { + // Use the `skip_bad_kmers` flag here to nix unused warning. + // TODO: Decide what to do about bad kmers. + if self.skip_bad_kmers { + // If the hash is 0, this is a bad k-mer + Some(Ok(("".to_string(), 0))) + } else { + // If an error occurs, return an error with the position and k-mer + let msg = format!("bad k-mer at position {}: {}", start, substr); + Some(Err(PyValueError::new_err(msg))) + } + } + } else { + let msg = format!("bad k-mer at position {}: {}", start, substr); + Some(Err(PyValueError::new_err(msg))) + } + } +} + // Python module definition #[pymodule] fn oxli(m: &Bound<'_, PyModule>) -> PyResult<()> {