diff --git a/src/core/src/signature.rs b/src/core/src/signature.rs index 84e9d91477..abffbc19f7 100644 --- a/src/core/src/signature.rs +++ b/src/core/src/signature.rs @@ -10,6 +10,7 @@ use std::path::Path; use std::str; use cfg_if::cfg_if; +use itertools::Itertools; #[cfg(feature = "parallel")] use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -171,7 +172,7 @@ pub enum ReadingFrame { len: usize, // len gives max_index for kmer iterator }, Protein { - fw: Vec, // Only forward frame + fw: Vec, len: usize, }, } @@ -198,58 +199,64 @@ impl std::fmt::Display for ReadingFrame { impl ReadingFrame { pub fn new_dna(sequence: &[u8]) -> Self { - let fw = sequence.to_vec(); - let rc = revcomp(sequence); + let fw = sequence.to_ascii_uppercase(); + let rc = revcomp(&fw); let len = sequence.len(); ReadingFrame::DNA { fw, rc, len } } pub fn new_protein(sequence: &[u8], dayhoff: bool, hp: bool) -> Self { + let seq = sequence.to_ascii_uppercase(); let fw: Vec = if dayhoff { - sequence.iter().map(|&aa| aa_to_dayhoff(aa)).collect() + seq.iter().map(|&aa| aa_to_dayhoff(aa)).collect() } else if hp { - sequence.iter().map(|&aa| aa_to_hp(aa)).collect() + seq.iter().map(|&aa| aa_to_hp(aa)).collect() } else { - sequence.to_vec() // protein, as-is. + seq }; let len = fw.len(); ReadingFrame::Protein { fw, len } } - pub fn new_skipmer(seq: &[u8], start: usize, m: usize, n: usize) -> Self { + pub fn new_skipmer(sequence: &[u8], start: usize, m: usize, n: usize) -> Self { + let seq = sequence.to_ascii_uppercase(); if start >= n { panic!("Skipmer frame number must be < n ({})", n); } - // Generate forward skipmer frame - let fw: Vec = seq - .iter() - .skip(start) - .enumerate() - .filter_map(|(i, &base)| if i % n < m { Some(base) } else { None }) - .collect(); + // do we need to round up? (+1) + let mut fw = Vec::with_capacity(((seq.len() * m) + 1) / n); + seq.iter().skip(start).enumerate().for_each(|(i, &base)| { + if i % n < m { + fw.push(base.to_ascii_uppercase()); + } + }); let len = fw.len(); let rc = revcomp(&fw); ReadingFrame::DNA { fw, rc, len } } + // this is the only one that doesn't uppercase in here b/c more efficient to uppercase externally :/ pub fn new_translated(sequence: &[u8], frame_number: usize, dayhoff: bool, hp: bool) -> Self { if frame_number > 2 { panic!("Frame number must be 0, 1, or 2"); } - // translate sequence - let fw: Vec = sequence + // Translate sequence into amino acids + let mut fw = Vec::with_capacity(sequence.len() / 3); + // NOTE: b/c of chunks(3), we only process full codons and ignore leftover bases (e.g. 1 or 2 at end of frame) + sequence .iter() - .cloned() - .skip(frame_number) // skip the initial bases for the frame - .take(sequence.len() - frame_number) // adjust length based on skipped bases - .collect::>() // collect the DNA subsequence - .chunks(3) // group into codons (triplets) - .filter_map(|codon| to_aa(codon, dayhoff, hp).ok()) // translate each codon - .flatten() // flatten the nested results into a single sequence - .collect(); + .skip(frame_number) // Skip the initial bases for the frame + .take(sequence.len() - frame_number) // Adjust length based on skipped bases + .chunks(3) // Group into codons (triplets) using itertools + .into_iter() + .filter_map(|chunk| { + let codon: Vec = chunk.cloned().collect(); // Collect the chunk into a Vec + to_aa(&codon, dayhoff, hp).ok() // Translate the codon + }) + .for_each(|aa| fw.extend(aa)); // Extend `fw` with amino acids let len = fw.len(); @@ -258,6 +265,7 @@ impl ReadingFrame { } /// Get the forward sequence. + #[inline] pub fn fw(&self) -> &[u8] { match self { ReadingFrame::DNA { fw, .. } => fw, @@ -266,6 +274,7 @@ impl ReadingFrame { } /// Get the reverse complement sequence (if DNA). + #[inline] pub fn rc(&self) -> &[u8] { match self { ReadingFrame::DNA { rc, .. } => rc, @@ -273,6 +282,7 @@ impl ReadingFrame { } } + #[inline] pub fn length(&self) -> usize { match self { ReadingFrame::DNA { len, .. } => *len, @@ -294,8 +304,9 @@ pub struct SeqToHashes { force: bool, seed: u64, frames: Vec, - frame_index: usize, // Index of the current frame - kmer_index: usize, // Current k-mer index within the frame + frame_index: usize, // Index of the current frame + kmer_index: usize, // Current k-mer index within the frame + last_position_check: usize, // Index of last base we validated } impl SeqToHashes { @@ -314,19 +325,17 @@ impl SeqToHashes { ksize = k_size / 3; } - // uppercase the sequence. this clones the data bc &[u8] is immutable? - // TODO: could we avoid this by changing revcomp/VALID/etc? - let sequence = seq.to_ascii_uppercase(); - // Generate frames based on sequence type and hash function - let frames = if is_protein { - Self::protein_frames(&sequence, &hash_function) + let frames = if hash_function.dna() { + Self::dna_frames(&seq) + } else if is_protein { + Self::protein_frames(&seq, &hash_function) } else if hash_function.protein() || hash_function.dayhoff() || hash_function.hp() { - Self::translated_frames(&sequence, &hash_function) + Self::translated_frames(&seq, &hash_function) } else if hash_function.skipm1n3() || hash_function.skipm2n3() { - Self::skipmer_frames(&sequence, &hash_function, ksize) + Self::skipmer_frames(&seq, &hash_function, ksize) } else { - Self::dna_frames(&sequence) + unimplemented!(); }; SeqToHashes { @@ -336,6 +345,7 @@ impl SeqToHashes { frames, frame_index: 0, kmer_index: 0, + last_position_check: 0, } } @@ -355,12 +365,14 @@ impl SeqToHashes { /// generate translated frames: 6 protein frames fn translated_frames(seq: &[u8], hash_function: &HashFunctions) -> Vec { - let revcomp_sequence = revcomp(seq); + // since we need to revcomp BEFORE making ReadingFrames, uppercase the sequence here + let sequence = seq.to_ascii_uppercase(); + let revcomp_sequence = revcomp(&sequence); (0..3) .flat_map(|frame_number| { vec![ ReadingFrame::new_translated( - seq, + &sequence, frame_number, hash_function.dayhoff(), hash_function.hp(), @@ -398,59 +410,6 @@ impl SeqToHashes { fn out_of_bounds(&self, frame: &ReadingFrame) -> bool { self.kmer_index + self.k_size > frame.length() } - - // check all bases are valid - fn validate_dna_kmer(&self, kmer: &[u8]) -> Result { - for &nt in kmer { - if !VALID[nt as usize] { - if self.force { - // Return `false` to indicate invalid k-mer, but do not error out - return Ok(false); - } else { - return Err(Error::InvalidDNA { - message: String::from_utf8_lossy(kmer).to_string(), - }); - } - } - } - Ok(true) // All bases are valid - } - - /// Process a DNA k-mer, including canonicalization and validation - fn dna_hash(&self, frame: &ReadingFrame) -> Result { - let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size]; - let rc = frame.rc(); - - // Validate the k-mer. Skip if invalid and force is true - if !self.validate_dna_kmer(kmer)? { - return Ok(0); // Skip invalid k-mer - } - - // For a ksize = 3, and a sequence AGTCGT (len = 6): - // +-+---------+---------------+-------+ - // seq RC |i|i + ksize|len - ksize - i|len - i| - // AGTCGT ACGACT +-+---------+---------------+-------+ - // +-> +-> |0| 2 | 3 | 6 | - // +-> +-> |1| 3 | 2 | 5 | - // +-> +-> |2| 4 | 1 | 4 | - // +-> +-> |3| 5 | 0 | 3 | - // +-+---------+---------------+-------+ - // (leaving this table here because I had to draw to - // get the indices correctly) - let reverse_index = frame.length() - self.k_size - self.kmer_index; - let krc = &rc[reverse_index..reverse_index + self.k_size]; - - // Compute canonical hash - let canonical_kmer = std::cmp::min(kmer, krc); - let hash = crate::_hash_murmur(canonical_kmer, self.seed); - - Ok(hash) - } - - fn protein_hash(&self, frame: &ReadingFrame) -> u64 { - let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size]; - crate::_hash_murmur(kmer, self.seed) // build and return hash - } } impl Iterator for SeqToHashes { @@ -464,22 +423,60 @@ impl Iterator for SeqToHashes { if self.out_of_bounds(frame) { self.frame_index += 1; self.kmer_index = 0; // Reset for the next frame + self.last_position_check = 0; continue; } - // Delegate to DNA or protein processing let result = match frame { - ReadingFrame::DNA { .. } => match self.dna_hash(frame) { - Ok(hash) => Ok(hash), // Valid hash - Err(err) => Err(err), // Error - }, - ReadingFrame::Protein { .. } => Ok(self.protein_hash(frame)), + ReadingFrame::DNA { .. } => { + let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size]; + let rc = frame.rc(); + + // Validate k-mer bases + for j in std::cmp::max(self.kmer_index, self.last_position_check) + ..self.kmer_index + self.k_size + { + if !VALID[frame.fw()[j] as usize] { + if !self.force { + // Return an error if force is false + return Some(Err(Error::InvalidDNA { + message: String::from_utf8(kmer.to_vec()).unwrap(), + })); + } else { + // Skip the invalid k-mer + self.kmer_index += 1; + return Some(Ok(0)); + } + } + self.last_position_check += 1; + } + + // Compute canonical hash + // For a ksize = 3, and a sequence AGTCGT (len = 6): + // +-+---------+---------------+-------+ + // seq RC |i|i + ksize|len - ksize - i|len - i| + // AGTCGT ACGACT +-+---------+---------------+-------+ + // +-> +-> |0| 2 | 3 | 6 | + // +-> +-> |1| 3 | 2 | 5 | + // +-> +-> |2| 4 | 1 | 4 | + // +-> +-> |3| 5 | 0 | 3 | + // +-+---------+---------------+-------+ + // (leaving this table here because I had to draw to + // get the indices correctly) + let krc = &rc[frame.length() - self.k_size - self.kmer_index + ..frame.length() - self.kmer_index]; + let hash = crate::_hash_murmur(std::cmp::min(kmer, krc), self.seed); + Ok(hash) + } + ReadingFrame::Protein { .. } => { + let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size]; + Ok(crate::_hash_murmur(kmer, self.seed)) + } }; - self.kmer_index += 1; // Advance k-mer index + self.kmer_index += 1; // Advance k-mer index for valid k-mers return Some(result); } - None // No more frames or k-mers } } diff --git a/tests/test_sourmash_sketch.py b/tests/test_sourmash_sketch.py index 88332b1945..b95b081a75 100644 --- a/tests/test_sourmash_sketch.py +++ b/tests/test_sourmash_sketch.py @@ -345,6 +345,7 @@ def test_protein_override_bad_rust_foo(): siglist = factory() assert len(siglist) == 1 sig = siglist[0] + print(sig.minhash.ksize) # try adding something testdata1 = utils.get_test_data("ecoli.faa") @@ -354,7 +355,12 @@ def test_protein_override_bad_rust_foo(): with pytest.raises(ValueError) as exc: sig.add_protein(record.sequence) - assert 'Invalid hash function: "DNA"' in str(exc) + # assert 'Invalid hash function: "DNA"' in str(exc) + + # this case now ends up in the "DNA" section of SeqToHashes, + # so we run into the invalid k-mer error + # instead of invalid Hash Function. + assert "invalid DNA character in input k-mer: MRVLKFGGTS" in str(exc) def test_dayhoff_defaults():