diff --git a/src/core/src/cmd.rs b/src/core/src/cmd.rs index 115d2672af..d5bfb092ba 100644 --- a/src/core/src/cmd.rs +++ b/src/core/src/cmd.rs @@ -44,7 +44,11 @@ pub struct ComputeParameters { #[getset(get_copy = "pub", set = "pub")] #[builder(default = false)] - skipmer: bool, + skipm1n3: bool, + + #[getset(get_copy = "pub", set = "pub")] + #[builder(default = false)] + skipm2n3: bool, #[getset(get_copy = "pub", set = "pub")] #[builder(default = false)] @@ -169,12 +173,29 @@ pub fn build_template(params: &ComputeParameters) -> Vec { )); } - if params.skipmer { + if params.skipm1n3 { + ksigs.push(Sketch::LargeMinHash( + KmerMinHashBTree::builder() + .num(params.num_hashes) + .ksize(*k) + .hash_function(HashFunctions::Murmur64Skipm1n3) + .max_hash(max_hash) + .seed(params.seed) + .abunds(if params.track_abundance { + Some(Default::default()) + } else { + None + }) + .build(), + )); + } + + if params.skipm2n3 { ksigs.push(Sketch::LargeMinHash( KmerMinHashBTree::builder() .num(params.num_hashes) .ksize(*k) - .hash_function(HashFunctions::Murmur64Skipmer) + .hash_function(HashFunctions::Murmur64Skipm2n3) .max_hash(max_hash) .seed(params.seed) .abunds(if params.track_abundance { diff --git a/src/core/src/encodings.rs b/src/core/src/encodings.rs index 0aba272a5c..56077c80e1 100644 --- a/src/core/src/encodings.rs +++ b/src/core/src/encodings.rs @@ -31,7 +31,8 @@ pub enum HashFunctions { Murmur64Protein, Murmur64Dayhoff, Murmur64Hp, - Murmur64Skipmer, + Murmur64Skipm1n3, + Murmur64Skipm2n3, Custom(String), } @@ -51,8 +52,13 @@ impl HashFunctions { pub fn hp(&self) -> bool { *self == HashFunctions::Murmur64Hp } - pub fn skipmer(&self) -> bool { - *self == HashFunctions::Murmur64Skipmer + + pub fn skipm1n3(&self) -> bool { + *self == HashFunctions::Murmur64Skipm1n3 + } + + pub fn skipm2n3(&self) -> bool { + *self == HashFunctions::Murmur64Skipm2n3 } } @@ -66,7 +72,8 @@ impl std::fmt::Display for HashFunctions { HashFunctions::Murmur64Protein => "protein", HashFunctions::Murmur64Dayhoff => "dayhoff", HashFunctions::Murmur64Hp => "hp", - HashFunctions::Murmur64Skipmer => "skipmer", + HashFunctions::Murmur64Skipm1n3 => "skipm1n3", + HashFunctions::Murmur64Skipm2n3 => "skipm2n3", HashFunctions::Custom(v) => v, } ) @@ -82,7 +89,8 @@ impl TryFrom<&str> for HashFunctions { "dayhoff" => Ok(HashFunctions::Murmur64Dayhoff), "hp" => Ok(HashFunctions::Murmur64Hp), "protein" => Ok(HashFunctions::Murmur64Protein), - "skipmer" => Ok(HashFunctions::Murmur64Skipmer), + "skipm1n3" => Ok(HashFunctions::Murmur64Skipm1n3), + "skipm2n3" => Ok(HashFunctions::Murmur64Skipm2n3), v => unimplemented!("{v}"), } } diff --git a/src/core/src/signature.rs b/src/core/src/signature.rs index 93580170a2..5cc60ca299 100644 --- a/src/core/src/signature.rs +++ b/src/core/src/signature.rs @@ -10,6 +10,7 @@ use std::path::Path; use std::str; use cfg_if::cfg_if; +use itertools::Itertools; #[cfg(feature = "parallel")] use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -163,27 +164,149 @@ impl SigsTrait for Sketch { } } -// Iterator for converting sequence to hashes +#[derive(Debug, Clone)] +pub enum ReadingFrame { + DNA { + fw: Vec, + rc: Vec, + len: usize, // len gives max_index for kmer iterator + }, + Protein { + fw: Vec, + len: usize, + }, +} + +impl std::fmt::Display for ReadingFrame { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ReadingFrame::DNA { fw, rc, len } => { + let fw_str = String::from_utf8_lossy(fw).to_string(); + let rc_str = String::from_utf8_lossy(rc).to_string(); + write!( + f, + "Type: DNA ({}bp), Forward: {}, Reverse Complement: {}", + len, fw_str, rc_str + ) + } + ReadingFrame::Protein { fw, len } => { + let fw_str = String::from_utf8_lossy(fw).to_string(); + write!(f, "Type: Protein ({}aa), Forward: {}", len, fw_str) + } + } + } +} + +impl ReadingFrame { + pub fn new_dna(sequence: &[u8]) -> Self { + let fw = sequence.to_ascii_uppercase(); + let rc = revcomp(&fw); + let len = sequence.len(); + ReadingFrame::DNA { fw, rc, len } + } + + pub fn new_protein(sequence: &[u8], dayhoff: bool, hp: bool) -> Self { + let seq = sequence.to_ascii_uppercase(); + let fw: Vec = if dayhoff { + seq.iter().map(|&aa| aa_to_dayhoff(aa)).collect() + } else if hp { + seq.iter().map(|&aa| aa_to_hp(aa)).collect() + } else { + seq + }; + + let len = fw.len(); + ReadingFrame::Protein { fw, len } + } + + pub fn new_skipmer(sequence: &[u8], start: usize, m: usize, n: usize) -> Self { + let seq = sequence.to_ascii_uppercase(); + if start >= n { + panic!("Skipmer frame number must be < n ({})", n); + } + // do we need to round up? (+1) + let mut fw = Vec::with_capacity(((seq.len() * m) + 1) / n); + seq.iter().skip(start).enumerate().for_each(|(i, &base)| { + if i % n < m { + fw.push(base.to_ascii_uppercase()); + } + }); + + let len = fw.len(); + let rc = revcomp(&fw); + ReadingFrame::DNA { fw, rc, len } + } + + // this is the only one that doesn't uppercase in here b/c more efficient to uppercase externally :/ + pub fn new_translated(sequence: &[u8], frame_number: usize, dayhoff: bool, hp: bool) -> Self { + if frame_number > 2 { + panic!("Frame number must be 0, 1, or 2"); + } + + // Translate sequence into amino acids + let mut fw = Vec::with_capacity(sequence.len() / 3); + // NOTE: b/c of chunks(3), we only process full codons and ignore leftover bases (e.g. 1 or 2 at end of frame) + sequence + .iter() + .skip(frame_number) // Skip the initial bases for the frame + .take(sequence.len() - frame_number) // Adjust length based on skipped bases + .chunks(3) // Group into codons (triplets) using itertools + .into_iter() + .filter_map(|chunk| { + let codon: Vec = chunk.cloned().collect(); // Collect the chunk into a Vec + to_aa(&codon, dayhoff, hp).ok() // Translate the codon + }) + .for_each(|aa| fw.extend(aa)); // Extend `fw` with amino acids + + let len = fw.len(); + + // return protein reading frame + ReadingFrame::Protein { fw, len } + } + + /// Get the forward sequence. + #[inline] + pub fn fw(&self) -> &[u8] { + match self { + ReadingFrame::DNA { fw, .. } => fw, + ReadingFrame::Protein { fw, .. } => fw, + } + } + + /// Get the reverse complement sequence (if DNA). + #[inline] + pub fn rc(&self) -> &[u8] { + match self { + ReadingFrame::DNA { rc, .. } => rc, + _ => panic!("Reverse complement is only available for DNA frames"), + } + } + + #[inline] + pub fn length(&self) -> usize { + match self { + ReadingFrame::DNA { len, .. } => *len, + ReadingFrame::Protein { len, .. } => *len, + } + } + + /// Get the type of the frame as a string. + pub fn frame_type(&self) -> &'static str { + match self { + ReadingFrame::DNA { .. } => "DNA", + ReadingFrame::Protein { .. } => "Protein", + } + } +} + pub struct SeqToHashes { - sequence: Vec, - kmer_index: usize, k_size: usize, - max_index: usize, force: bool, - is_protein: bool, - hash_function: HashFunctions, seed: u64, - hashes_buffer: Vec, - - dna_configured: bool, - dna_rc: Vec, - dna_ksize: usize, - dna_len: usize, - dna_last_position_check: usize, - - prot_configured: bool, - aa_seq: Vec, - translate_iter_step: usize, + frames: Vec, + frame_index: usize, // Index of the current frame + kmer_index: usize, // Current k-mer index within the frame + last_position_check: usize, // Index of last base we validated } impl SeqToHashes { @@ -194,118 +317,141 @@ impl SeqToHashes { is_protein: bool, hash_function: HashFunctions, seed: u64, - ) -> SeqToHashes { + ) -> Self { let mut ksize: usize = k_size; - // Divide the kmer size by 3 if protein + // Adjust kmer size for protein-based hash functions if is_protein || hash_function.protein() || hash_function.dayhoff() || hash_function.hp() { ksize = k_size / 3; } - // By setting _max_index to 0, the iterator will return None and exit - let _max_index = if seq.len() >= ksize { - seq.len() - ksize + 1 + // Generate frames based on sequence type and hash function + let frames = if hash_function.dna() { + Self::dna_frames(seq) + } else if is_protein { + Self::protein_frames(seq, &hash_function) + } else if hash_function.protein() || hash_function.dayhoff() || hash_function.hp() { + Self::translated_frames(seq, &hash_function) + } else if hash_function.skipm1n3() || hash_function.skipm2n3() { + Self::skipmer_frames(seq, &hash_function, ksize) } else { - 0 + unimplemented!(); }; SeqToHashes { - // Here we convert the sequence to upper case - sequence: seq.to_ascii_uppercase(), k_size: ksize, - kmer_index: 0, - max_index: _max_index, force, - is_protein, - hash_function, seed, - hashes_buffer: Vec::with_capacity(1000), - dna_configured: false, - dna_rc: Vec::with_capacity(1000), - dna_ksize: 0, - dna_len: 0, - dna_last_position_check: 0, - prot_configured: false, - aa_seq: Vec::new(), - translate_iter_step: 0, + frames, + frame_index: 0, + kmer_index: 0, + last_position_check: 0, } } - fn validate_base(&self, base: u8, kmer: &[u8]) -> Option> { - if !VALID[base as usize] { - if !self.force { - return Some(Err(Error::InvalidDNA { - message: String::from_utf8(kmer.to_owned()).unwrap_or_default(), - })); - } else { - return Some(Ok(0)); // Skip this position if forced - } + /// generate frames from DNA: 1 DNA frame (fw+rc) + fn dna_frames(seq: &[u8]) -> Vec { + vec![ReadingFrame::new_dna(seq)] + } + + /// generate frames from protein: 1 protein frame + fn protein_frames(seq: &[u8], hash_function: &HashFunctions) -> Vec { + vec![ReadingFrame::new_protein( + seq, + hash_function.dayhoff(), + hash_function.hp(), + )] + } + + /// generate translated frames: 6 protein frames + fn translated_frames(seq: &[u8], hash_function: &HashFunctions) -> Vec { + // since we need to revcomp BEFORE making ReadingFrames, uppercase the sequence here + let sequence = seq.to_ascii_uppercase(); + let revcomp_sequence = revcomp(&sequence); + (0..3) + .flat_map(|frame_number| { + vec![ + ReadingFrame::new_translated( + &sequence, + frame_number, + hash_function.dayhoff(), + hash_function.hp(), + ), + ReadingFrame::new_translated( + &revcomp_sequence, + frame_number, + hash_function.dayhoff(), + hash_function.hp(), + ), + ] + }) + .collect() + } + + /// generate skipmer frames: 3 DNA frames (each with fw+rc) + fn skipmer_frames( + seq: &[u8], + hash_function: &HashFunctions, + ksize: usize, + ) -> Vec { + let (m, n) = if hash_function.skipm1n3() { + (1, 3) + } else { + (2, 3) + }; + if ksize < n { + unimplemented!() } - None // Base is valid, so return None to continue + (0..3) + .flat_map(|frame_number| vec![ReadingFrame::new_skipmer(seq, frame_number, m, n)]) + .collect() } -} -/* -Iterator that return a kmer hash for all modes except translate. -In translate mode: - - all the frames are processed at once and converted to hashes. - - all the hashes are stored in `hashes_buffer` - - after processing all the kmers, `translate_iter_step` is incremented - per iteration to iterate over all the indeces of the `hashes_buffer`. - - the iterator will die once `translate_iter_step` == length(hashes_buffer) -More info https://github.com/sourmash-bio/sourmash/pull/1946 -*/ + fn out_of_bounds(&self, frame: &ReadingFrame) -> bool { + self.kmer_index + self.k_size > frame.length() + } +} impl Iterator for SeqToHashes { type Item = Result; fn next(&mut self) -> Option { - if (self.kmer_index < self.max_index) || !self.hashes_buffer.is_empty() { - // Processing DNA or Translated DNA - if !self.is_protein { - // Setting the parameters only in the first iteration - if !self.dna_configured { - self.dna_ksize = self.k_size; - self.dna_len = self.sequence.len(); - if self.dna_len < self.dna_ksize - || (self.hash_function.protein() && self.dna_len < self.k_size * 3) - || (self.hash_function.dayhoff() && self.dna_len < self.k_size * 3) - || (self.hash_function.hp() && self.dna_len < self.k_size * 3) - || (self.hash_function.skipmer() - // add 1 to round up rather than down - && self.dna_len < (self.k_size + ((self.k_size + 1) / 2) - 1)) - { - return None; - } - // pre-calculate the reverse complement for the full sequence... - self.dna_rc = revcomp(&self.sequence); - self.dna_configured = true; - } + while self.frame_index < self.frames.len() { + let frame = &self.frames[self.frame_index]; + + // Do we need to move to the next frame? + if self.out_of_bounds(frame) { + self.frame_index += 1; + self.kmer_index = 0; // Reset for the next frame + self.last_position_check = 0; + continue; + } - // Processing DNA - if self.hash_function.dna() { - let kmer = &self.sequence[self.kmer_index..self.kmer_index + self.dna_ksize]; + let result = match frame { + ReadingFrame::DNA { .. } => { + let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size]; + let rc = frame.rc(); - // validate the bases - for j in std::cmp::max(self.kmer_index, self.dna_last_position_check) - ..self.kmer_index + self.dna_ksize + // Validate k-mer bases + for j in std::cmp::max(self.kmer_index, self.last_position_check) + ..self.kmer_index + self.k_size { - if !VALID[self.sequence[j] as usize] { + if !VALID[frame.fw()[j] as usize] { if !self.force { + // Return an error if force is false return Some(Err(Error::InvalidDNA { message: String::from_utf8(kmer.to_vec()).unwrap(), })); } else { + // Skip the invalid k-mer self.kmer_index += 1; - // Move the iterator to the next step return Some(Ok(0)); } } - self.dna_last_position_check += 1; + self.last_position_check += 1; } - // ... and then while moving the k-mer window forward for the sequence - // we move another window backwards for the RC. + // Compute canonical hash // For a ksize = 3, and a sequence AGTCGT (len = 6): // +-+---------+---------------+-------+ // seq RC |i|i + ksize|len - ksize - i|len - i| @@ -317,146 +463,21 @@ impl Iterator for SeqToHashes { // +-+---------+---------------+-------+ // (leaving this table here because I had to draw to // get the indices correctly) - - let krc = &self.dna_rc[self.dna_len - self.dna_ksize - self.kmer_index - ..self.dna_len - self.kmer_index]; + let krc = &rc[frame.length() - self.k_size - self.kmer_index + ..frame.length() - self.kmer_index]; let hash = crate::_hash_murmur(std::cmp::min(kmer, krc), self.seed); - self.kmer_index += 1; - Some(Ok(hash)) - } else if self.hash_function.skipmer() { - // check that we can actually build skipmers - if self.k_size < 3 { - unimplemented!() - // return None - } - let extended_length = self.dna_ksize + ((self.dna_ksize + 1) / 2) - 1; // add 1 to round up rather than down - - // Check bounds to ensure we don't exceed the sequence length - if self.kmer_index + extended_length > self.sequence.len() { - return None; - } - - // Build skipmer with DNA base validation - let mut kmer: Vec = Vec::with_capacity(self.dna_ksize); - for (_i, &base) in self.sequence - [self.kmer_index..self.kmer_index + extended_length] - .iter() - .enumerate() - .filter(|&(i, _)| i % 3 != 2) - .take(self.dna_ksize) - { - // Use the validate_base method to check the base - if let Some(result) = self.validate_base(base, &kmer) { - self.kmer_index += 1; // Move to the next position if skipping is forced - return Some(result); - } - kmer.push(base); - } - - // Generate reverse complement skipmer - let krc: Vec = self.dna_rc[self.dna_len - extended_length - self.kmer_index - ..self.dna_len - self.kmer_index] - .iter() - .enumerate() - .filter(|&(i, _)| i % 3 != 2) - .take(self.dna_ksize) - .map(|(_, &base)| base) - .collect(); - - let hash = crate::_hash_murmur(std::cmp::min(&kmer, &krc), self.seed); - self.kmer_index += 1; - Some(Ok(hash)) - } else if self.hashes_buffer.is_empty() && self.translate_iter_step == 0 { - // Processing protein by translating DNA - // TODO: Implement iterator over frames instead of hashes_buffer. - - for frame_number in 0..3 { - let substr: Vec = self - .sequence - .iter() - .cloned() - .skip(frame_number) - .take(self.sequence.len() - frame_number) - .collect(); - - let aa = to_aa( - &substr, - self.hash_function.dayhoff(), - self.hash_function.hp(), - ) - .unwrap(); - - aa.windows(self.k_size).for_each(|n| { - let hash = crate::_hash_murmur(n, self.seed); - self.hashes_buffer.push(hash); - }); - - let rc_substr: Vec = self - .dna_rc - .iter() - .cloned() - .skip(frame_number) - .take(self.dna_rc.len() - frame_number) - .collect(); - let aa_rc = to_aa( - &rc_substr, - self.hash_function.dayhoff(), - self.hash_function.hp(), - ) - .unwrap(); - - aa_rc.windows(self.k_size).for_each(|n| { - let hash = crate::_hash_murmur(n, self.seed); - self.hashes_buffer.push(hash); - }); - } - Some(Ok(0)) - } else { - if self.translate_iter_step == self.hashes_buffer.len() { - self.hashes_buffer.clear(); - self.kmer_index = self.max_index; - return Some(Ok(0)); - } - let curr_idx = self.translate_iter_step; - self.translate_iter_step += 1; - Some(Ok(self.hashes_buffer[curr_idx])) + Ok(hash) } - } else { - // Processing protein - // The kmer size is already divided by 3 - - if self.hash_function.protein() { - let aa_kmer = &self.sequence[self.kmer_index..self.kmer_index + self.k_size]; - let hash = crate::_hash_murmur(aa_kmer, self.seed); - self.kmer_index += 1; - Some(Ok(hash)) - } else { - if !self.prot_configured { - self.aa_seq = match &self.hash_function { - HashFunctions::Murmur64Dayhoff => { - self.sequence.iter().cloned().map(aa_to_dayhoff).collect() - } - HashFunctions::Murmur64Hp => { - self.sequence.iter().cloned().map(aa_to_hp).collect() - } - invalid => { - return Some(Err(Error::InvalidHashFunction { - function: format!("{}", invalid), - })); - } - }; - } - - let aa_kmer = &self.aa_seq[self.kmer_index..self.kmer_index + self.k_size]; - let hash = crate::_hash_murmur(aa_kmer, self.seed); - self.kmer_index += 1; - Some(Ok(hash)) + ReadingFrame::Protein { .. } => { + let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size]; + Ok(crate::_hash_murmur(kmer, self.seed)) } - } - } else { - // End the iterator - None + }; + + self.kmer_index += 1; // Advance k-mer index for valid k-mers + return Some(result); } + None // No more frames or k-mers } } @@ -973,6 +994,7 @@ impl TryInto for Signature { #[cfg(test)] mod test { + use std::fs::File; use std::io::{BufReader, Read}; use std::path::PathBuf; @@ -1103,12 +1125,12 @@ mod test { } #[test] - fn signature_skipmer_add_sequence() { + fn signature_skipm2n3_add_sequence() { let params = ComputeParameters::builder() .ksizes(vec![3, 4, 5, 6]) .num_hashes(3u32) .dna(false) - .skipmer(true) + .skipm2n3(true) .build(); let mut sig = Signature::from_params(¶ms); @@ -1118,18 +1140,53 @@ mod test { dbg!(&sig.signatures); assert_eq!(sig.signatures[0].size(), 3); assert_eq!(sig.signatures[1].size(), 3); - assert_eq!(sig.signatures[2].size(), 2); + eprintln!("{:?}", sig.signatures[2]); + assert_eq!(sig.signatures[2].size(), 3); assert_eq!(sig.signatures[3].size(), 1); } + #[test] + fn signature_skipm1n3_add_sequence() { + let params = ComputeParameters::builder() + .ksizes(vec![3, 4, 5, 6]) + .num_hashes(10u32) + .dna(false) + .skipm1n3(true) + .build(); + + let mut sig = Signature::from_params(¶ms); + sig.add_sequence(b"ATGCATGAATGAC", false).unwrap(); + + assert_eq!(sig.signatures.len(), 4); + dbg!(&sig.signatures); + assert_eq!(sig.signatures[0].size(), 5); + assert_eq!(sig.signatures[1].size(), 4); + assert_eq!(sig.signatures[2].size(), 1); + assert_eq!(sig.signatures[3].size(), 0); + } + #[test] #[should_panic(expected = "not implemented")] - fn signature_skipmer_add_sequence_too_small() { + fn signature_skipm2n3_add_sequence_too_small() { let params = ComputeParameters::builder() .ksizes(vec![2]) - .num_hashes(3u32) + .num_hashes(10u32) .dna(false) - .skipmer(true) + .skipm2n3(true) + .build(); + + let mut sig = Signature::from_params(¶ms); + sig.add_sequence(b"ATGCATGA", false).unwrap(); + } + + #[test] + #[should_panic(expected = "not implemented")] + fn signature_skipm1n3_add_sequence_too_small() { + let params = ComputeParameters::builder() + .ksizes(vec![2]) + .num_hashes(10u32) + .dna(false) + .skipm1n3(true) .build(); let mut sig = Signature::from_params(¶ms); @@ -1375,19 +1432,213 @@ mod test { } } + #[test] + fn test_seqtohashes_frames_dna() { + let sequence = b"AGTCGT"; + let hash_function = HashFunctions::Murmur64Dna; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = false; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + let frames = sth.frames.clone(); + + assert_eq!(frames.len(), 1); + assert_eq!(frames[0].fw(), sequence.as_slice()); + assert_eq!(frames[0].rc(), b"ACGACT".as_slice()); + } + + #[test] + fn test_seqtohashes_frames_is_protein() { + let sequence = b"MVLSPADKTNVKAAW"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = true; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + let frames = sth.frames.clone(); + + assert_eq!(frames.len(), 1); + assert_eq!(frames[0].fw(), sequence.as_slice()); + } + + #[test] + #[should_panic] + fn test_seqtohashes_frames_is_protein_try_access_rc() { + // test panic if trying to access rc + let sequence = b"MVLSPADKTNVKAAW"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = true; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + let frames = sth.frames.clone(); + + // protein frame doesn't have rc; this should panic + eprintln!("{:?}", frames[0].rc()); + } + + #[test] + fn test_seqtohashes_frames_is_protein_dayhoff() { + let sequence = b"MVLSPADKTNVKAAW"; + let dayhoff_seq = b"eeebbbcdbcedbbf"; + let hash_function = HashFunctions::Murmur64Dayhoff; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = true; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + let frames = sth.frames.clone(); + + assert_eq!(frames.len(), 1); + assert_eq!(frames[0].fw(), dayhoff_seq.as_slice()); + } + + #[test] + fn test_seqtohashes_frames_is_protein_hp() { + let sequence = b"MVLSPADKTNVKAAW"; + let hp_seq = b"hhhphhpppphphhh"; + let hash_function = HashFunctions::Murmur64Hp; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = true; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + let frames = sth.frames.clone(); + + assert_eq!(frames.len(), 1); + assert_eq!(frames[0].fw(), hp_seq.as_slice()); + } + + #[test] + fn test_seqtohashes_frames_translate_protein() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = false; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + let frames = sth.frames.clone(); + + assert_eq!(frames[0].fw(), b"SRRA".as_slice()); + assert_eq!(frames[1].fw(), b"SSTT".as_slice()); + assert_eq!(frames[2].fw(), b"VVE".as_slice()); + assert_eq!(frames[3].fw(), b"ARR".as_slice()); + assert_eq!(frames[4].fw(), b"SSS".as_slice()); + assert_eq!(frames[5].fw(), b"LDD".as_slice()); + } + + #[test] + fn test_seqtohashes_frames_skipmer_m1n3() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Skipm1n3; // Represents m=1, n=3 + let k_size = 3; // K-mer size is not directly relevant for skipmer frame validation + let seed = 42; // Seed is also irrelevant for frame structure + let force = false; + let is_protein = false; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + let frames = sth.frames.clone(); + + eprintln!("Frames: {:?}", frames); + + assert_eq!(frames.len(), 3); // Three skipmer frames + + // Expected skipmer sequences for m=1, n=3 (keep-1, skip-2) + assert_eq!(frames[0].fw(), b"ACCG".as_slice()); + assert_eq!(frames[0].rc(), b"CGGT".as_slice()); + + assert_eq!(frames[1].fw(), b"GGGC".as_slice()); + assert_eq!(frames[1].rc(), b"GCCC".as_slice()); + + assert_eq!(frames[2].fw(), b"TTAT".as_slice()); + assert_eq!(frames[2].rc(), b"ATAA".as_slice()); + } + + #[test] + fn test_seqtohashes_frames_skipmer_m2n3() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Skipm2n3; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = false; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + let frames = sth.frames; + eprintln!("Frames: {:?}", frames); + + assert_eq!(frames.len(), 3); // Three skipmer frames + + // Expected skipmer sequences for m=1, n=3 (keep-1, skip-2) + assert_eq!(frames[0].fw(), b"AGCGCGGC".as_slice()); + assert_eq!(frames[0].rc(), b"GCCGCGCT".as_slice()); + + assert_eq!(frames[1].fw(), b"GTGTGACT".as_slice()); + assert_eq!(frames[1].rc(), b"AGTCACAC".as_slice()); + + assert_eq!(frames[2].fw(), b"TCTCAGT".as_slice()); + assert_eq!(frames[2].rc(), b"ACTGAGA".as_slice()); + } + #[test] fn test_seqtohashes_dna() { + let sequence = b"AGTCGT"; + let hash_function = HashFunctions::Murmur64Dna; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = false; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + + // Expected k-mers from the forward and reverse complement sequence + let expected_kmers = vec![ + (b"AGT".to_vec(), b"ACT".to_vec()), + (b"GTC".to_vec(), b"GAC".to_vec()), + (b"TCG".to_vec(), b"CGA".to_vec()), + (b"CGT".to_vec(), b"ACG".to_vec()), + ]; + + // Compute expected hashes from expected kmers + let expected_hashes: Vec = expected_kmers + .iter() + .map(|(fw_kmer, rc_kmer)| crate::_hash_murmur(std::cmp::min(fw_kmer, rc_kmer), seed)) + .collect(); + + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" + ); + } + + #[test] + fn test_seqtohashes_dna_2() { let sequence = b"AGTCGTCA"; let k_size = 7; let seed = 42; let force = true; // Force skip over invalid bases if needed - + let is_protein = false; // Initialize SeqToHashes iterator using the new constructor let mut seq_to_hashes = SeqToHashes::new( sequence, k_size, force, - false, + is_protein, HashFunctions::Murmur64Dna, seed, ); @@ -1414,42 +1665,195 @@ mod test { } #[test] - fn test_seqtohashes_skipmer() { - let sequence = b"AGTCGTCA"; - // let rc_seq = b"TGACGACT"; - let k_size = 5; + fn test_seqtohashes_is_protein() { + let sequence = b"MVLSPADKTNVKAAW"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 3; let seed = 42; - let force = true; // Force skip over invalid bases if needed + let force = false; + let is_protein = true; + + let sth = SeqToHashes::new(sequence, k_size * 3, force, is_protein, hash_function, seed); + + // Expected k-mers for protein sequence + let expected_kmers = vec![ + b"MVL".to_vec(), + b"VLS".to_vec(), + b"LSP".to_vec(), + b"SPA".to_vec(), + b"PAD".to_vec(), + b"ADK".to_vec(), + b"DKT".to_vec(), + b"KTN".to_vec(), + b"TNV".to_vec(), + b"NVK".to_vec(), + b"VKA".to_vec(), + b"KAA".to_vec(), + b"AAW".to_vec(), + ]; + + // Compute hashes for expected k-mers + let expected_hashes: Vec = expected_kmers + .iter() + .map(|fw_kmer| crate::_hash_murmur(fw_kmer, 42)) + .collect(); - // Initialize SeqToHashes iterator using the new constructor - let mut seq_to_hashes = SeqToHashes::new( - sequence, - k_size, - force, - false, - HashFunctions::Murmur64Skipmer, - seed, + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!(sth_hashes, expected_hashes, "Hashes do not match in order"); + } + + #[test] + fn test_seqtohashes_translate() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 9; // needs to be *3 for protein + let seed = 42; + let force = false; + let is_protein = false; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + + let expected_kmers = vec![ + b"SRR".as_slice(), + b"RRA".as_slice(), + b"SST".as_slice(), + b"STT".as_slice(), + b"VVE".as_slice(), + b"ARR".as_slice(), + b"SSS".as_slice(), + b"LDD".as_slice(), + ]; + + // Compute expected hashes + let expected_hashes: Vec = expected_kmers + .iter() + .map(|fw_kmer| crate::_hash_murmur(fw_kmer, seed)) + .collect(); + + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" ); + } + + #[test] + fn test_seqtohashes_skipm1n3() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Skipm1n3; + let k_size = 3; + let is_protein = false; + let seed = 42; + let force = false; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + // Expected k-mers for skipmer (m=1, n=3) across all frames + let expected_kmers = vec![ + (b"ACC".as_slice(), b"GGT".as_slice()), + (b"CCG".as_slice(), b"CGG".as_slice()), + (b"GGG".as_slice(), b"CCC".as_slice()), + (b"GGC".as_slice(), b"GCC".as_slice()), + (b"TTA".as_slice(), b"TAA".as_slice()), + (b"TAT".as_slice(), b"ATA".as_slice()), + ]; + + // Compute expected hashes + let expected_hashes: Vec = expected_kmers + .iter() + .map(|(fw_kmer, rc_kmer)| crate::_hash_murmur(std::cmp::min(fw_kmer, rc_kmer), seed)) + .collect(); - // Define expected hashes for the skipmer configuration. - let expected_kmers = ["AGCGC", "GTGTA"]; - // rc of the k-mer, not of the sequence, then skipmerized. Correct? - let expected_krc = ["GCGCT", "TACAC"]; + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); - // Compute expected hashes by hashing each k-mer with its reverse complement + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" + ); + } + + #[test] + fn test_seq2hashes_skipm2n3() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Skipm2n3; + let k_size = 7; + let is_protein = false; + let seed = 42; + let force = false; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + + // Expected k-mers for skipmer (m=2, n=3) + let expected_kmers = vec![ + (b"AGCGCGG".as_slice(), b"CCGCGCT".as_slice()), + (b"GCGCGGC".as_slice(), b"GCCGCGC".as_slice()), + (b"GTGTGAC".as_slice(), b"GTCACAC".as_slice()), + (b"TGTGACT".as_slice(), b"AGTCACA".as_slice()), + (b"TCTCAGT".as_slice(), b"ACTGAGA".as_slice()), + ]; + + // Compute expected hashes let expected_hashes: Vec = expected_kmers .iter() - .zip(expected_krc.iter()) - .map(|(kmer, krc)| { - // Convert both kmer and krc to byte slices and pass to _hash_murmur - crate::_hash_murmur(std::cmp::min(kmer.as_bytes(), krc.as_bytes()), seed) - }) + .map(|(fw_kmer, rc_kmer)| crate::_hash_murmur(std::cmp::min(fw_kmer, rc_kmer), seed)) .collect(); - // Compare each produced hash from the iterator with the expected hash - for expected_hash in expected_hashes { - let hash = seq_to_hashes.next().unwrap().ok().unwrap(); - assert_eq!(hash, expected_hash, "Mismatch in skipmer hash"); + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" + ); + } + + #[test] + fn test_seqtohashes_skipm2n3_2() { + let sequence = b"AGTCGTCA"; + let hash_function = HashFunctions::Murmur64Skipm2n3; + let k_size = 5; + let seed = 42; + let force = true; + let is_protein = false; + + let sth = SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed); + let frames = sth.frames.clone(); + for fr in frames { + eprintln!("{}", fr); } + + let expected_kmers = vec![ + (b"AGCGC".as_slice(), b"GCGCT".as_slice()), + (b"GCGCA".as_slice(), b"TGCGC".as_slice()), + (b"GTGTA".as_slice(), b"TACAC".as_slice()), + ]; + + // Compute expected hashes + let expected_hashes: Vec = expected_kmers + .iter() + .map(|(fw_kmer, rc_kmer)| crate::_hash_murmur(std::cmp::min(fw_kmer, rc_kmer), seed)) + .collect(); + + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" + ); } } diff --git a/src/core/src/sketch/minhash.rs b/src/core/src/sketch/minhash.rs index 3c8f3b4240..d8ca2337a0 100644 --- a/src/core/src/sketch/minhash.rs +++ b/src/core/src/sketch/minhash.rs @@ -153,7 +153,8 @@ impl<'de> Deserialize<'de> for KmerMinHash { "dayhoff" => HashFunctions::Murmur64Dayhoff, "hp" => HashFunctions::Murmur64Hp, "dna" => HashFunctions::Murmur64Dna, - "skipmer" => HashFunctions::Murmur64Skipmer, + "skipm1n3" => HashFunctions::Murmur64Skipm1n3, + "skipm2n3" => HashFunctions::Murmur64Skipm2n3, _ => unimplemented!(), // TODO: throw error here }; diff --git a/tests/test_sourmash_sketch.py b/tests/test_sourmash_sketch.py index 88332b1945..a9af0fdd24 100644 --- a/tests/test_sourmash_sketch.py +++ b/tests/test_sourmash_sketch.py @@ -345,6 +345,7 @@ def test_protein_override_bad_rust_foo(): siglist = factory() assert len(siglist) == 1 sig = siglist[0] + print(sig.minhash.ksize) # try adding something testdata1 = utils.get_test_data("ecoli.faa") @@ -354,7 +355,7 @@ def test_protein_override_bad_rust_foo(): with pytest.raises(ValueError) as exc: sig.add_protein(record.sequence) - assert 'Invalid hash function: "DNA"' in str(exc) + assert "invalid DNA character in input k-mer: MRVLKFGGTS" in str(exc) def test_dayhoff_defaults():