From 5680efc994362dcdcd6b6dfd5f58b49cce254817 Mon Sep 17 00:00:00 2001 From: Tessa Pierce Ward Date: Fri, 20 Dec 2024 09:47:49 -0800 Subject: [PATCH] MRG: add skipmers; switch to reading frame approach for translation, skipmers (#3395) This PR enables skipmers **ONLY in the rust code**. - enables two skipmer types: m1n3, m2n3 - switches `SeqToHashes` to use reading frame struct, which simplifies/unifies the code across the different methods. The reading frame code handles any modifications needed - i.e. translation or skipping. Then we just kmerize the reading frame as usual. The main difference for translation is that we no longer need to store a buffer of all hashes from the reading frames. Since this changes the `SeqToHashes` strategy a bit, there's one python test where we now see a different error (modified). Future thoughts: - with the new structure, it would be straightforward to add validation for protein k-mers. I guess I'm not entirely sure what happens to those atm... Skipmer References: - [Skip-mers: increasing entropy and sensitivity to detect conserved genic regions with simple cyclic q-grams](https://www.biorxiv.org/content/10.1101/179960.abstract) - [Extracting and Evaluating Features from RNA Virus Sequences to Predict Host Species Susceptibility Using Deep Learning](https://dl.acm.org/doi/abs/10.1145/3473258.3473271) --- .github/dependabot.yml | 2 + include/sourmash.h | 3 + src/core/Cargo.toml | 4 +- src/core/src/cmd.rs | 42 ++ src/core/src/encodings.rs | 105 +++ src/core/src/errors.rs | 15 + src/core/src/ffi/minhash.rs | 15 +- src/core/src/signature.rs | 1131 +++++++++++++++++++++++++++----- src/core/src/sketch/minhash.rs | 2 + src/core/tests/minhash.rs | 8 +- tests/test_sourmash_sketch.py | 3 +- 11 files changed, 1150 insertions(+), 180 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index d01a55a915..50a37271e5 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -17,6 +17,8 @@ updates: - dependency-name: "wasm-bindgen" - dependency-name: "once_cell" - dependency-name: "chrono" + - dependency-name: "js-sys" + - dependency-name: "web-sys" - package-ecosystem: "github-actions" directory: "/" schedule: diff --git a/include/sourmash.h b/include/sourmash.h index f650035dec..8e80d22834 100644 --- a/include/sourmash.h +++ b/include/sourmash.h @@ -38,6 +38,9 @@ enum SourmashErrorCode { SOURMASH_ERROR_CODE_INVALID_PROT = 1102, SOURMASH_ERROR_CODE_INVALID_CODON_LENGTH = 1103, SOURMASH_ERROR_CODE_INVALID_HASH_FUNCTION = 1104, + SOURMASH_ERROR_CODE_INVALID_SKIPMER_FRAME = 1105, + SOURMASH_ERROR_CODE_INVALID_SKIPMER_SIZE = 1106, + SOURMASH_ERROR_CODE_INVALID_TRANSLATE_FRAME = 1107, SOURMASH_ERROR_CODE_READ_DATA = 1201, SOURMASH_ERROR_CODE_STORAGE = 1202, SOURMASH_ERROR_CODE_HLL_PRECISION_BOUNDS = 1301, diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 221fd240a5..f5b41078d9 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -96,8 +96,8 @@ skip_feature_sets = [ ## Wasm section. Crates only used for WASM, as well as specific configurations [target.'cfg(all(target_arch = "wasm32", target_os="unknown"))'.dependencies] -js-sys = "0.3.74" -web-sys = { version = "0.3.74", features = ["console", "File", "FileReaderSync"] } +js-sys = "0.3.72" +web-sys = { version = "0.3.72", features = ["console", "File", "FileReaderSync"] } wasm-bindgen = "0.2.89" getrandom = { version = "0.2", features = ["js"] } diff --git a/src/core/src/cmd.rs b/src/core/src/cmd.rs index b59e99b912..d5bfb092ba 100644 --- a/src/core/src/cmd.rs +++ b/src/core/src/cmd.rs @@ -42,6 +42,14 @@ pub struct ComputeParameters { #[builder(default = false)] hp: bool, + #[getset(get_copy = "pub", set = "pub")] + #[builder(default = false)] + skipm1n3: bool, + + #[getset(get_copy = "pub", set = "pub")] + #[builder(default = false)] + skipm2n3: bool, + #[getset(get_copy = "pub", set = "pub")] #[builder(default = false)] singleton: bool, @@ -165,6 +173,40 @@ pub fn build_template(params: &ComputeParameters) -> Vec { )); } + if params.skipm1n3 { + ksigs.push(Sketch::LargeMinHash( + KmerMinHashBTree::builder() + .num(params.num_hashes) + .ksize(*k) + .hash_function(HashFunctions::Murmur64Skipm1n3) + .max_hash(max_hash) + .seed(params.seed) + .abunds(if params.track_abundance { + Some(Default::default()) + } else { + None + }) + .build(), + )); + } + + if params.skipm2n3 { + ksigs.push(Sketch::LargeMinHash( + KmerMinHashBTree::builder() + .num(params.num_hashes) + .ksize(*k) + .hash_function(HashFunctions::Murmur64Skipm2n3) + .max_hash(max_hash) + .seed(params.seed) + .abunds(if params.track_abundance { + Some(Default::default()) + } else { + None + }) + .build(), + )); + } + if params.dna { ksigs.push(Sketch::LargeMinHash( KmerMinHashBTree::builder() diff --git a/src/core/src/encodings.rs b/src/core/src/encodings.rs index 8375699cb0..06af5d03c1 100644 --- a/src/core/src/encodings.rs +++ b/src/core/src/encodings.rs @@ -31,6 +31,8 @@ pub enum HashFunctions { Murmur64Protein, Murmur64Dayhoff, Murmur64Hp, + Murmur64Skipm1n3, + Murmur64Skipm2n3, Custom(String), } @@ -50,6 +52,14 @@ impl HashFunctions { pub fn hp(&self) -> bool { *self == HashFunctions::Murmur64Hp } + + pub fn skipm1n3(&self) -> bool { + *self == HashFunctions::Murmur64Skipm1n3 + } + + pub fn skipm2n3(&self) -> bool { + *self == HashFunctions::Murmur64Skipm2n3 + } } impl std::fmt::Display for HashFunctions { @@ -62,6 +72,8 @@ impl std::fmt::Display for HashFunctions { HashFunctions::Murmur64Protein => "protein", HashFunctions::Murmur64Dayhoff => "dayhoff", HashFunctions::Murmur64Hp => "hp", + HashFunctions::Murmur64Skipm1n3 => "skipm1n3", + HashFunctions::Murmur64Skipm2n3 => "skipm2n3", HashFunctions::Custom(v) => v, } ) @@ -77,6 +89,8 @@ impl TryFrom<&str> for HashFunctions { "dayhoff" => Ok(HashFunctions::Murmur64Dayhoff), "hp" => Ok(HashFunctions::Murmur64Hp), "protein" => Ok(HashFunctions::Murmur64Protein), + "skipm1n3" => Ok(HashFunctions::Murmur64Skipm1n3), + "skipm2n3" => Ok(HashFunctions::Murmur64Skipm2n3), v => unimplemented!("{v}"), } } @@ -506,6 +520,7 @@ impl<'a> Iterator for Indices<'a> { #[cfg(test)] mod test { use super::*; + use std::convert::TryFrom; #[test] fn colors_update() { @@ -573,4 +588,94 @@ mod test { assert_eq!(colors.len(), 2); } + + #[test] + fn test_dna_method() { + assert!(HashFunctions::Murmur64Dna.dna()); + assert!(!HashFunctions::Murmur64Protein.dna()); + assert!(!HashFunctions::Murmur64Dayhoff.dna()); + } + + #[test] + fn test_protein_method() { + assert!(HashFunctions::Murmur64Protein.protein()); + assert!(!HashFunctions::Murmur64Dna.protein()); + assert!(!HashFunctions::Murmur64Dayhoff.protein()); + } + + #[test] + fn test_dayhoff_method() { + assert!(HashFunctions::Murmur64Dayhoff.dayhoff()); + assert!(!HashFunctions::Murmur64Dna.dayhoff()); + assert!(!HashFunctions::Murmur64Protein.dayhoff()); + } + + #[test] + fn test_hp_method() { + assert!(HashFunctions::Murmur64Hp.hp()); + assert!(!HashFunctions::Murmur64Dna.hp()); + assert!(!HashFunctions::Murmur64Protein.hp()); + } + + #[test] + fn test_skipm1n3_method() { + assert!(HashFunctions::Murmur64Skipm1n3.skipm1n3()); + assert!(!HashFunctions::Murmur64Dna.skipm1n3()); + assert!(!HashFunctions::Murmur64Protein.skipm1n3()); + } + + #[test] + fn test_skipm2n3_method() { + assert!(HashFunctions::Murmur64Skipm2n3.skipm2n3()); + assert!(!HashFunctions::Murmur64Dna.skipm2n3()); + assert!(!HashFunctions::Murmur64Protein.skipm2n3()); + } + + #[test] + fn test_display_hashfunctions() { + assert_eq!(HashFunctions::Murmur64Dna.to_string(), "DNA"); + assert_eq!(HashFunctions::Murmur64Protein.to_string(), "protein"); + assert_eq!(HashFunctions::Murmur64Dayhoff.to_string(), "dayhoff"); + assert_eq!(HashFunctions::Murmur64Hp.to_string(), "hp"); + assert_eq!(HashFunctions::Murmur64Skipm1n3.to_string(), "skipm1n3"); + assert_eq!(HashFunctions::Murmur64Skipm2n3.to_string(), "skipm2n3"); + assert_eq!( + HashFunctions::Custom("custom_string".into()).to_string(), + "custom_string" + ); + } + + #[test] + fn test_try_from_str_valid() { + assert_eq!( + HashFunctions::try_from("dna").unwrap(), + HashFunctions::Murmur64Dna + ); + assert_eq!( + HashFunctions::try_from("protein").unwrap(), + HashFunctions::Murmur64Protein + ); + assert_eq!( + HashFunctions::try_from("dayhoff").unwrap(), + HashFunctions::Murmur64Dayhoff + ); + assert_eq!( + HashFunctions::try_from("hp").unwrap(), + HashFunctions::Murmur64Hp + ); + assert_eq!( + HashFunctions::try_from("skipm1n3").unwrap(), + HashFunctions::Murmur64Skipm1n3 + ); + assert_eq!( + HashFunctions::try_from("skipm2n3").unwrap(), + HashFunctions::Murmur64Skipm2n3 + ); + } + + #[test] + #[should_panic(expected = "not implemented: unknown")] + fn test_try_from_str_invalid() { + HashFunctions::try_from("unknown").unwrap(); + } } diff --git a/src/core/src/errors.rs b/src/core/src/errors.rs index e8ce3e68aa..870d846f45 100644 --- a/src/core/src/errors.rs +++ b/src/core/src/errors.rs @@ -55,6 +55,15 @@ pub enum SourmashError { #[error("Codon is invalid length: {message}")] InvalidCodonLength { message: String }, + #[error("Skipmer ksize must be >= n ({n}), but got ksize: {ksize}")] + InvalidSkipmerSize { ksize: usize, n: usize }, + + #[error("Skipmer frame number must be < n ({n}), but got start: {start}")] + InvalidSkipmerFrame { start: usize, n: usize }, + + #[error("Frame number must be 0, 1, or 2, but got {frame_number}")] + InvalidTranslateFrame { frame_number: usize }, + #[error("Set error rate to a value smaller than 0.367696 and larger than 0.00203125")] HLLPrecisionBounds, @@ -128,6 +137,9 @@ pub enum SourmashErrorCode { InvalidProt = 11_02, InvalidCodonLength = 11_03, InvalidHashFunction = 11_04, + InvalidSkipmerFrame = 11_05, + InvalidSkipmerSize = 11_06, + InvalidTranslateFrame = 11_07, // index-related errors ReadData = 12_01, Storage = 12_02, @@ -170,6 +182,9 @@ impl SourmashErrorCode { SourmashError::InvalidProt { .. } => SourmashErrorCode::InvalidProt, SourmashError::InvalidCodonLength { .. } => SourmashErrorCode::InvalidCodonLength, SourmashError::InvalidHashFunction { .. } => SourmashErrorCode::InvalidHashFunction, + SourmashError::InvalidSkipmerFrame { .. } => SourmashErrorCode::InvalidSkipmerFrame, + SourmashError::InvalidSkipmerSize { .. } => SourmashErrorCode::InvalidSkipmerSize, + SourmashError::InvalidTranslateFrame { .. } => SourmashErrorCode::InvalidTranslateFrame, SourmashError::ReadDataError { .. } => SourmashErrorCode::ReadData, SourmashError::StorageError { .. } => SourmashErrorCode::Storage, SourmashError::HLLPrecisionBounds { .. } => SourmashErrorCode::HLLPrecisionBounds, diff --git a/src/core/src/ffi/minhash.rs b/src/core/src/ffi/minhash.rs index 55879dd060..baad446a0f 100644 --- a/src/core/src/ffi/minhash.rs +++ b/src/core/src/ffi/minhash.rs @@ -73,15 +73,26 @@ Result<*const u64> { let mut output: Vec = Vec::with_capacity(insize); + // Call SeqToHashes::new and handle errors + let ready_hashes = SeqToHashes::new( + buf, + mh.ksize(), + force, + is_protein, + mh.hash_function(), + mh.seed(), + )?; + + if force && bad_kmers_as_zeroes{ - for hash_value in SeqToHashes::new(buf, mh.ksize(), force, is_protein, mh.hash_function(), mh.seed()){ + for hash_value in ready_hashes{ match hash_value{ Ok(x) => output.push(x), Err(err) => return Err(err), } } }else{ - for hash_value in SeqToHashes::new(buf, mh.ksize(), force, is_protein, mh.hash_function(), mh.seed()){ + for hash_value in ready_hashes { match hash_value{ Ok(0) => continue, Ok(x) => output.push(x), diff --git a/src/core/src/signature.rs b/src/core/src/signature.rs index a3971a8637..b2498acd57 100644 --- a/src/core/src/signature.rs +++ b/src/core/src/signature.rs @@ -10,12 +10,14 @@ use std::path::Path; use std::str; use cfg_if::cfg_if; +use itertools::Itertools; #[cfg(feature = "parallel")] use rayon::prelude::*; use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; use crate::encodings::{aa_to_dayhoff, aa_to_hp, revcomp, to_aa, HashFunctions, VALID}; +use crate::errors::SourmashError; use crate::prelude::*; use crate::sketch::minhash::KmerMinHash; use crate::sketch::Sketch; @@ -43,7 +45,7 @@ pub trait SigsTrait { false, self.hash_function(), self.seed(), - ); + )?; for hash_value in ready_hashes { match hash_value { @@ -65,7 +67,7 @@ pub trait SigsTrait { true, self.hash_function(), self.seed(), - ); + )?; for hash_value in ready_hashes { match hash_value { @@ -163,27 +165,159 @@ impl SigsTrait for Sketch { } } -// Iterator for converting sequence to hashes +#[derive(Debug, Clone)] +pub enum ReadingFrame { + DNA { + fw: Vec, + rc: Vec, + len: usize, // len gives max_index for kmer iterator + }, + Protein { + fw: Vec, + len: usize, + }, +} + +impl std::fmt::Display for ReadingFrame { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ReadingFrame::DNA { fw, rc, len } => { + let fw_str = String::from_utf8_lossy(fw).to_string(); + let rc_str = String::from_utf8_lossy(rc).to_string(); + write!( + f, + "Type: DNA ({}bp), Forward: {}, Reverse Complement: {}", + len, fw_str, rc_str + ) + } + ReadingFrame::Protein { fw, len } => { + let fw_str = String::from_utf8_lossy(fw).to_string(); + write!(f, "Type: Protein ({}aa), Forward: {}", len, fw_str) + } + } + } +} + +impl ReadingFrame { + pub fn new_dna(sequence: &[u8]) -> Self { + let fw = sequence.to_ascii_uppercase(); + let rc = revcomp(&fw); + let len = sequence.len(); + ReadingFrame::DNA { fw, rc, len } + } + + pub fn new_protein(sequence: &[u8], dayhoff: bool, hp: bool) -> Self { + let seq = sequence.to_ascii_uppercase(); + let fw: Vec = if dayhoff { + seq.iter().map(|&aa| aa_to_dayhoff(aa)).collect() + } else if hp { + seq.iter().map(|&aa| aa_to_hp(aa)).collect() + } else { + seq + }; + + let len = fw.len(); + ReadingFrame::Protein { fw, len } + } + + pub fn new_skipmer( + sequence: &[u8], + start: usize, + m: usize, + n: usize, + ) -> Result { + let seq = sequence.to_ascii_uppercase(); + if start >= n { + return Err(SourmashError::InvalidSkipmerFrame { start, n }); + } + // do we need to round up? (+1) + let mut fw = Vec::with_capacity(((seq.len() * m) + 1) / n); + seq.iter().skip(start).enumerate().for_each(|(i, &base)| { + if i % n < m { + fw.push(base.to_ascii_uppercase()); + } + }); + + let len = fw.len(); + let rc = revcomp(&fw); + Ok(ReadingFrame::DNA { fw, rc, len }) + } + + // this is the only one that doesn't uppercase in here b/c more efficient to uppercase externally :/ + pub fn new_translated( + sequence: &[u8], + frame_number: usize, + dayhoff: bool, + hp: bool, + ) -> Result { + if frame_number > 2 { + return Err(SourmashError::InvalidTranslateFrame { frame_number }); + } + + // Translate sequence into amino acids + let mut fw = Vec::with_capacity(sequence.len() / 3); + // NOTE: b/c of chunks(3), we only process full codons and ignore leftover bases (e.g. 1 or 2 at end of frame) + sequence + .iter() + .skip(frame_number) // Skip the initial bases for the frame + .take(sequence.len() - frame_number) // Adjust length based on skipped bases + .chunks(3) // Group into codons (triplets) using itertools + .into_iter() + .filter_map(|chunk| { + let codon: Vec = chunk.cloned().collect(); // Collect the chunk into a Vec + to_aa(&codon, dayhoff, hp).ok() // Translate the codon + }) + .for_each(|aa| fw.extend(aa)); // Extend `fw` with amino acids + + let len = fw.len(); + + // return protein reading frame + Ok(ReadingFrame::Protein { fw, len }) + } + + /// Get the forward sequence. + #[inline] + pub fn fw(&self) -> &[u8] { + match self { + ReadingFrame::DNA { fw, .. } => fw, + ReadingFrame::Protein { fw, .. } => fw, + } + } + + /// Get the reverse complement sequence (if DNA). + #[inline] + pub fn rc(&self) -> &[u8] { + match self { + ReadingFrame::DNA { rc, .. } => rc, + _ => panic!("Reverse complement is only available for DNA frames"), + } + } + + #[inline] + pub fn length(&self) -> usize { + match self { + ReadingFrame::DNA { len, .. } => *len, + ReadingFrame::Protein { len, .. } => *len, + } + } + + /// Get the type of the frame as a string. + pub fn frame_type(&self) -> &'static str { + match self { + ReadingFrame::DNA { .. } => "DNA", + ReadingFrame::Protein { .. } => "Protein", + } + } +} + pub struct SeqToHashes { - sequence: Vec, - kmer_index: usize, k_size: usize, - max_index: usize, force: bool, - is_protein: bool, - hash_function: HashFunctions, seed: u64, - hashes_buffer: Vec, - - dna_configured: bool, - dna_rc: Vec, - dna_ksize: usize, - dna_len: usize, - dna_last_position_check: usize, - - prot_configured: bool, - aa_seq: Vec, - translate_iter_step: usize, + frames: Vec, + frame_index: usize, // Index of the current frame + kmer_index: usize, // Current k-mer index within the frame + last_position_check: usize, // Index of last base we validated } impl SeqToHashes { @@ -194,99 +328,150 @@ impl SeqToHashes { is_protein: bool, hash_function: HashFunctions, seed: u64, - ) -> SeqToHashes { + ) -> Result { let mut ksize: usize = k_size; - // Divide the kmer size by 3 if protein - if is_protein || !hash_function.dna() { + // Adjust kmer size for protein-based hash functions + if is_protein || hash_function.protein() || hash_function.dayhoff() || hash_function.hp() { ksize = k_size / 3; } - // By setting _max_index to 0, the iterator will return None and exit - let _max_index = if seq.len() >= ksize { - seq.len() - ksize + 1 + // Generate frames based on sequence type and hash function + let frames = if hash_function.dna() { + Self::dna_frames(seq) + } else if is_protein { + Self::protein_frames(seq, &hash_function) + } else if hash_function.protein() || hash_function.dayhoff() || hash_function.hp() { + Self::translated_frames(seq, &hash_function)? + } else if hash_function.skipm1n3() || hash_function.skipm2n3() { + Self::skipmer_frames(seq, &hash_function, ksize)? } else { - 0 + return Err(SourmashError::InvalidHashFunction { + function: format!("{:?}", hash_function), + }); }; - SeqToHashes { - // Here we convert the sequence to upper case - sequence: seq.to_ascii_uppercase(), + Ok(SeqToHashes { k_size: ksize, - kmer_index: 0, - max_index: _max_index, force, - is_protein, - hash_function, seed, - hashes_buffer: Vec::with_capacity(1000), - dna_configured: false, - dna_rc: Vec::with_capacity(1000), - dna_ksize: 0, - dna_len: 0, - dna_last_position_check: 0, - prot_configured: false, - aa_seq: Vec::new(), - translate_iter_step: 0, + frames, + frame_index: 0, + kmer_index: 0, + last_position_check: 0, + }) + } + + /// generate frames from DNA: 1 DNA frame (fw+rc) + fn dna_frames(seq: &[u8]) -> Vec { + vec![ReadingFrame::new_dna(seq)] + } + + /// generate frames from protein: 1 protein frame + fn protein_frames(seq: &[u8], hash_function: &HashFunctions) -> Vec { + vec![ReadingFrame::new_protein( + seq, + hash_function.dayhoff(), + hash_function.hp(), + )] + } + + /// generate translated frames: 6 protein frames + fn translated_frames( + seq: &[u8], + hash_function: &HashFunctions, + ) -> Result, SourmashError> { + // since we need to revcomp BEFORE making ReadingFrames, uppercase the sequence here + let sequence = seq.to_ascii_uppercase(); + let revcomp_sequence = revcomp(&sequence); + let frames = (0..3) + .flat_map(|frame_number| { + vec![ + ReadingFrame::new_translated( + &sequence, + frame_number, + hash_function.dayhoff(), + hash_function.hp(), + ), + ReadingFrame::new_translated( + &revcomp_sequence, + frame_number, + hash_function.dayhoff(), + hash_function.hp(), + ), + ] + }) + .collect::, _>>()?; + + Ok(frames) + } + + /// generate skipmer frames: 3 DNA frames (each with fw+rc) + fn skipmer_frames( + seq: &[u8], + hash_function: &HashFunctions, + ksize: usize, + ) -> Result, SourmashError> { + let (m, n) = if hash_function.skipm1n3() { + (1, 3) + } else { + (2, 3) + }; + if ksize < n { + return Err(SourmashError::InvalidSkipmerSize { ksize, n }); } + let frames = (0..3) + .flat_map(|frame_number| vec![ReadingFrame::new_skipmer(seq, frame_number, m, n)]) + .collect::, _>>()?; + + Ok(frames) } -} -/* -Iterator that return a kmer hash for all modes except translate. -In translate mode: - - all the frames are processed at once and converted to hashes. - - all the hashes are stored in `hashes_buffer` - - after processing all the kmers, `translate_iter_step` is incremented - per iteration to iterate over all the indeces of the `hashes_buffer`. - - the iterator will die once `translate_iter_step` == length(hashes_buffer) -More info https://github.com/sourmash-bio/sourmash/pull/1946 -*/ + fn out_of_bounds(&self, frame: &ReadingFrame) -> bool { + self.kmer_index + self.k_size > frame.length() + } +} impl Iterator for SeqToHashes { type Item = Result; fn next(&mut self) -> Option { - if (self.kmer_index < self.max_index) || !self.hashes_buffer.is_empty() { - // Processing DNA or Translated DNA - if !self.is_protein { - // Setting the parameters only in the first iteration - if !self.dna_configured { - self.dna_ksize = self.k_size; - self.dna_len = self.sequence.len(); - if self.dna_len < self.dna_ksize - || (!self.hash_function.dna() && self.dna_len < self.k_size * 3) - { - return None; - } - // pre-calculate the reverse complement for the full sequence... - self.dna_rc = revcomp(&self.sequence); - self.dna_configured = true; - } + while self.frame_index < self.frames.len() { + let frame = &self.frames[self.frame_index]; + + // Do we need to move to the next frame? + if self.out_of_bounds(frame) { + self.frame_index += 1; + self.kmer_index = 0; // Reset for the next frame + self.last_position_check = 0; + continue; + } - // Processing DNA - if self.hash_function.dna() { - let kmer = &self.sequence[self.kmer_index..self.kmer_index + self.dna_ksize]; + let result = match frame { + ReadingFrame::DNA { .. } => { + let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size]; + let rc = frame.rc(); - for j in std::cmp::max(self.kmer_index, self.dna_last_position_check) - ..self.kmer_index + self.dna_ksize + // Validate k-mer bases + for j in std::cmp::max(self.kmer_index, self.last_position_check) + ..self.kmer_index + self.k_size { - if !VALID[self.sequence[j] as usize] { + if !VALID[frame.fw()[j] as usize] { if !self.force { + // Return an error if force is false return Some(Err(Error::InvalidDNA { message: String::from_utf8(kmer.to_vec()).unwrap(), })); } else { + // Skip the invalid k-mer self.kmer_index += 1; - // Move the iterator to the next step return Some(Ok(0)); } } - self.dna_last_position_check += 1; + self.last_position_check += 1; } - // ... and then while moving the k-mer window forward for the sequence - // we move another window backwards for the RC. + // Compute canonical hash // For a ksize = 3, and a sequence AGTCGT (len = 6): // +-+---------+---------------+-------+ // seq RC |i|i + ksize|len - ksize - i|len - i| @@ -298,103 +483,21 @@ impl Iterator for SeqToHashes { // +-+---------+---------------+-------+ // (leaving this table here because I had to draw to // get the indices correctly) - - let krc = &self.dna_rc[self.dna_len - self.dna_ksize - self.kmer_index - ..self.dna_len - self.kmer_index]; + let krc = &rc[frame.length() - self.k_size - self.kmer_index + ..frame.length() - self.kmer_index]; let hash = crate::_hash_murmur(std::cmp::min(kmer, krc), self.seed); - self.kmer_index += 1; - Some(Ok(hash)) - } else if self.hashes_buffer.is_empty() && self.translate_iter_step == 0 { - // Processing protein by translating DNA - // TODO: Implement iterator over frames instead of hashes_buffer. - - for frame_number in 0..3 { - let substr: Vec = self - .sequence - .iter() - .cloned() - .skip(frame_number) - .take(self.sequence.len() - frame_number) - .collect(); - - let aa = to_aa( - &substr, - self.hash_function.dayhoff(), - self.hash_function.hp(), - ) - .unwrap(); - - aa.windows(self.k_size).for_each(|n| { - let hash = crate::_hash_murmur(n, self.seed); - self.hashes_buffer.push(hash); - }); - - let rc_substr: Vec = self - .dna_rc - .iter() - .cloned() - .skip(frame_number) - .take(self.dna_rc.len() - frame_number) - .collect(); - let aa_rc = to_aa( - &rc_substr, - self.hash_function.dayhoff(), - self.hash_function.hp(), - ) - .unwrap(); - - aa_rc.windows(self.k_size).for_each(|n| { - let hash = crate::_hash_murmur(n, self.seed); - self.hashes_buffer.push(hash); - }); - } - Some(Ok(0)) - } else { - if self.translate_iter_step == self.hashes_buffer.len() { - self.hashes_buffer.clear(); - self.kmer_index = self.max_index; - return Some(Ok(0)); - } - let curr_idx = self.translate_iter_step; - self.translate_iter_step += 1; - Some(Ok(self.hashes_buffer[curr_idx])) + Ok(hash) } - } else { - // Processing protein - // The kmer size is already divided by 3 - - if self.hash_function.protein() { - let aa_kmer = &self.sequence[self.kmer_index..self.kmer_index + self.k_size]; - let hash = crate::_hash_murmur(aa_kmer, self.seed); - self.kmer_index += 1; - Some(Ok(hash)) - } else { - if !self.prot_configured { - self.aa_seq = match &self.hash_function { - HashFunctions::Murmur64Dayhoff => { - self.sequence.iter().cloned().map(aa_to_dayhoff).collect() - } - HashFunctions::Murmur64Hp => { - self.sequence.iter().cloned().map(aa_to_hp).collect() - } - invalid => { - return Some(Err(Error::InvalidHashFunction { - function: format!("{}", invalid), - })); - } - }; - } - - let aa_kmer = &self.aa_seq[self.kmer_index..self.kmer_index + self.k_size]; - let hash = crate::_hash_murmur(aa_kmer, self.seed); - self.kmer_index += 1; - Some(Ok(hash)) + ReadingFrame::Protein { .. } => { + let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size]; + Ok(crate::_hash_murmur(kmer, self.seed)) } - } - } else { - // End the iterator - None + }; + + self.kmer_index += 1; // Advance k-mer index for valid k-mers + return Some(result); } + None // No more frames or k-mers } } @@ -920,6 +1023,7 @@ impl TryInto for Signature { #[cfg(test)] mod test { + use std::fs::File; use std::io::{BufReader, Read}; use std::path::PathBuf; @@ -927,7 +1031,8 @@ mod test { use needletail::parse_fastx_reader; use crate::cmd::ComputeParameters; - use crate::signature::SigsTrait; + use crate::encodings::HashFunctions; + use crate::signature::{ReadingFrame, SeqToHashes, SigsTrait}; use super::Signature; @@ -1049,6 +1154,98 @@ mod test { assert_eq!(sig.signatures[1].size(), 2); } + #[test] + fn signature_skipm2n3_add_sequence() { + let params = ComputeParameters::builder() + .ksizes(vec![3, 4, 5, 6]) + .num_hashes(3u32) + .dna(false) + .skipm2n3(true) + .build(); + + let mut sig = Signature::from_params(¶ms); + sig.add_sequence(b"ATGCATGA", false).unwrap(); + + assert_eq!(sig.signatures.len(), 4); + dbg!(&sig.signatures); + assert_eq!(sig.signatures[0].size(), 3); + assert_eq!(sig.signatures[1].size(), 3); + eprintln!("{:?}", sig.signatures[2]); + assert_eq!(sig.signatures[2].size(), 3); + assert_eq!(sig.signatures[3].size(), 1); + } + + #[test] + fn signature_skipm1n3_add_sequence() { + let params = ComputeParameters::builder() + .ksizes(vec![3, 4, 5, 6]) + .num_hashes(10u32) + .dna(false) + .skipm1n3(true) + .build(); + + let mut sig = Signature::from_params(¶ms); + sig.add_sequence(b"ATGCATGAATGAC", false).unwrap(); + + assert_eq!(sig.signatures.len(), 4); + dbg!(&sig.signatures); + assert_eq!(sig.signatures[0].size(), 5); + assert_eq!(sig.signatures[1].size(), 4); + assert_eq!(sig.signatures[2].size(), 1); + assert_eq!(sig.signatures[3].size(), 0); + } + + #[test] + fn signature_skipm2n3_add_sequence_too_small() { + let ksize = 2; + let params = ComputeParameters::builder() + .ksizes(vec![ksize]) + .num_hashes(10u32) + .dna(false) + .skipm2n3(true) + .build(); + + let mut sig = Signature::from_params(¶ms); + let result = sig.add_sequence(b"ATGCATGA", false); + + match result { + Err(error) => { + // Convert the error to a string and check the message + let error_message = format!("{}", error); + assert_eq!( + error_message, + "Skipmer ksize must be >= n (3), but got ksize: 2" + ); + } + _ => panic!("Expected SourmashError::InvalidSkipmerSize"), + } + } + + #[test] + fn signature_skipm1n3_add_sequence_too_small() { + let params = ComputeParameters::builder() + .ksizes(vec![2]) + .num_hashes(10u32) + .dna(false) + .skipm1n3(true) + .build(); + + let mut sig = Signature::from_params(¶ms); + let result = sig.add_sequence(b"ATGCATGA", false); + + match result { + Err(error) => { + // Convert the error to a string and check the message + let error_message = format!("{}", error); + assert_eq!( + error_message, + "Skipmer ksize must be >= n (3), but got ksize: 2" + ); + } + _ => panic!("Expected SourmashError::InvalidSkipmerSize"), + } + } + #[test] fn signature_add_sequence_cp() { let mut cp = ComputeParameters::default(); @@ -1287,4 +1484,592 @@ mod test { assert_eq!(modified_sig.size(), 0); } } + + #[test] + fn test_readingframe_dna() { + let sequence = b"AGTCGT"; + let frame = ReadingFrame::new_dna(sequence); + + assert_eq!(frame.fw(), sequence.as_slice()); + assert_eq!(frame.rc(), b"ACGACT".as_slice()); + } + + #[test] + fn test_fw_dna() { + let dna_frame = ReadingFrame::DNA { + fw: b"ATCG".to_vec(), + rc: b"CGAT".to_vec(), + len: 4, + }; + assert_eq!(dna_frame.fw(), b"ATCG"); + } + + #[test] + fn test_rc_dna() { + let dna_frame = ReadingFrame::DNA { + fw: b"ATCG".to_vec(), + rc: b"CGAT".to_vec(), + len: 4, + }; + assert_eq!(dna_frame.rc(), b"CGAT"); + } + + #[test] + fn test_length_dna() { + let dna_frame = ReadingFrame::DNA { + fw: b"ATCG".to_vec(), + rc: b"CGAT".to_vec(), + len: 4, + }; + assert_eq!(dna_frame.length(), 4); + } + + #[test] + fn test_frame_type_dna() { + let dna_frame = ReadingFrame::DNA { + fw: b"ATCG".to_vec(), + rc: b"CGAT".to_vec(), + len: 4, + }; + assert_eq!(dna_frame.frame_type(), "DNA"); + } + + #[test] + fn test_fw_protein() { + let protein_frame = ReadingFrame::Protein { + fw: b"MVHL".to_vec(), + len: 4, + }; + assert_eq!(protein_frame.fw(), b"MVHL"); + } + + #[test] + #[should_panic(expected = "Reverse complement is only available for DNA frames")] + fn test_rc_protein_panics() { + let protein_frame = ReadingFrame::Protein { + fw: b"MVHL".to_vec(), + len: 4, + }; + protein_frame.rc(); + } + + #[test] + fn test_length_protein() { + let protein_frame = ReadingFrame::Protein { + fw: b"MVHL".to_vec(), + len: 4, + }; + assert_eq!(protein_frame.length(), 4); + } + + #[test] + fn test_frame_type_protein() { + let protein_frame = ReadingFrame::Protein { + fw: b"MVHL".to_vec(), + len: 4, + }; + assert_eq!(protein_frame.frame_type(), "Protein"); + } + + #[test] + fn test_readingframe_display_protein() { + // Create a Protein ReadingFrame + let protein_frame = ReadingFrame::Protein { + fw: b"MVHLK".to_vec(), + len: 5, + }; + + let output = format!("{}", protein_frame); + // Assert the output matches the expected format + assert_eq!(output, "Type: Protein (5aa), Forward: MVHLK"); + } + + #[test] + fn test_seqtohashes_frames_dna() { + let sequence = b"AGTCGT"; + let hash_function = HashFunctions::Murmur64Dna; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = false; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + let frames = sth.frames.clone(); + + assert_eq!(frames.len(), 1); + assert_eq!(frames[0].fw(), sequence.as_slice()); + assert_eq!(frames[0].rc(), b"ACGACT".as_slice()); + } + + #[test] + fn test_seqtohashes_frames_is_protein() { + let sequence = b"MVLSPADKTNVKAAW"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = true; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + let frames = sth.frames.clone(); + + assert_eq!(frames.len(), 1); + assert_eq!(frames[0].fw(), sequence.as_slice()); + } + + #[test] + fn test_readingframe_protein() { + let sequence = b"MVLSPADKTNVKAAW"; + let hash_function = HashFunctions::Murmur64Protein; + let frame = + ReadingFrame::new_protein(sequence, hash_function.dayhoff(), hash_function.hp()); + + assert_eq!(frame.fw(), sequence.as_slice()); + } + + #[test] + #[should_panic] + fn test_seqtohashes_frames_is_protein_try_access_rc() { + // test panic if trying to access rc + let sequence = b"MVLSPADKTNVKAAW"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = true; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + let frames = sth.frames.clone(); + + // protein frame doesn't have rc; this should panic + eprintln!("{:?}", frames[0].rc()); + } + + #[test] + fn test_seqtohashes_frames_is_protein_dayhoff() { + let sequence = b"MVLSPADKTNVKAAW"; + let dayhoff_seq = b"eeebbbcdbcedbbf"; + let hash_function = HashFunctions::Murmur64Dayhoff; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = true; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + let frames = sth.frames.clone(); + + assert_eq!(frames.len(), 1); + assert_eq!(frames[0].fw(), dayhoff_seq.as_slice()); + } + + #[test] + fn test_seqtohashes_frames_is_protein_hp() { + let sequence = b"MVLSPADKTNVKAAW"; + let hp_seq = b"hhhphhpppphphhh"; + let hash_function = HashFunctions::Murmur64Hp; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = true; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + let frames = sth.frames.clone(); + + assert_eq!(frames.len(), 1); + assert_eq!(frames[0].fw(), hp_seq.as_slice()); + } + + #[test] + fn test_seqtohashes_frames_translate_protein() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = false; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + let frames = sth.frames.clone(); + + assert_eq!(frames[0].fw(), b"SRRA".as_slice()); + assert_eq!(frames[1].fw(), b"SSTT".as_slice()); + assert_eq!(frames[2].fw(), b"VVE".as_slice()); + assert_eq!(frames[3].fw(), b"ARR".as_slice()); + assert_eq!(frames[4].fw(), b"SSS".as_slice()); + assert_eq!(frames[5].fw(), b"LDD".as_slice()); + } + + #[test] + fn test_readingframe_translate() { + let sequence = b"AGTCGT"; + let frame_start = 3; // four frames but translate can only + + let result = ReadingFrame::new_translated(sequence, frame_start, false, false); + + match result { + Err(error) => { + // Convert the error to a string and check the message + let error_message = format!("{}", error); + assert_eq!(error_message, "Frame number must be 0, 1, or 2, but got 3"); + } + _ => panic!("Expected SourmashError::InvalidTranslateFrame"), + } + } + + #[test] + fn test_readingframe_skipmer() { + let sequence = b"AGTCGT"; + let m = 2; + let n = 3; + let num_frames = 4; // four frames but n is only 3 + + let result = ReadingFrame::new_skipmer(sequence, num_frames, m, n); + + match result { + Err(error) => { + // Convert the error to a string and check the message + let error_message = format!("{}", error); + assert_eq!( + error_message, + "Skipmer frame number must be < n (3), but got start: 4" + ); + } + _ => panic!("Expected SourmashError::InvalidSkipmerFrame"), + } + } + + #[test] + fn test_seqtohashes_frames_skipmer_m1n3() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Skipm1n3; // Represents m=1, n=3 + let k_size = 3; // K-mer size is not directly relevant for skipmer frame validation + let seed = 42; // Seed is also irrelevant for frame structure + let force = false; + let is_protein = false; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + let frames = sth.frames.clone(); + + eprintln!("Frames: {:?}", frames); + + assert_eq!(frames.len(), 3); // Three skipmer frames + + // Expected skipmer sequences for m=1, n=3 (keep-1, skip-2) + assert_eq!(frames[0].fw(), b"ACCG".as_slice()); + assert_eq!(frames[0].rc(), b"CGGT".as_slice()); + + assert_eq!(frames[1].fw(), b"GGGC".as_slice()); + assert_eq!(frames[1].rc(), b"GCCC".as_slice()); + + assert_eq!(frames[2].fw(), b"TTAT".as_slice()); + assert_eq!(frames[2].rc(), b"ATAA".as_slice()); + } + + #[test] + fn test_seqtohashes_frames_skipmer_m2n3() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Skipm2n3; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = false; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + let frames = sth.frames; + eprintln!("Frames: {:?}", frames); + + assert_eq!(frames.len(), 3); // Three skipmer frames + + // Expected skipmer sequences for m=1, n=3 (keep-1, skip-2) + assert_eq!(frames[0].fw(), b"AGCGCGGC".as_slice()); + assert_eq!(frames[0].rc(), b"GCCGCGCT".as_slice()); + + assert_eq!(frames[1].fw(), b"GTGTGACT".as_slice()); + assert_eq!(frames[1].rc(), b"AGTCACAC".as_slice()); + + assert_eq!(frames[2].fw(), b"TCTCAGT".as_slice()); + assert_eq!(frames[2].rc(), b"ACTGAGA".as_slice()); + } + + #[test] + fn test_seqtohashes_dna() { + let sequence = b"AGTCGT"; + let hash_function = HashFunctions::Murmur64Dna; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = false; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + + // Expected k-mers from the forward and reverse complement sequence + let expected_kmers = vec![ + (b"AGT".to_vec(), b"ACT".to_vec()), + (b"GTC".to_vec(), b"GAC".to_vec()), + (b"TCG".to_vec(), b"CGA".to_vec()), + (b"CGT".to_vec(), b"ACG".to_vec()), + ]; + + // Compute expected hashes from expected kmers + let expected_hashes: Vec = expected_kmers + .iter() + .map(|(fw_kmer, rc_kmer)| crate::_hash_murmur(std::cmp::min(fw_kmer, rc_kmer), seed)) + .collect(); + + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" + ); + } + + #[test] + fn test_seqtohashes_dna_2() { + let sequence = b"AGTCGTCA"; + let k_size = 7; + let seed = 42; + let force = true; // Force skip over invalid bases if needed + let is_protein = false; + // Initialize SeqToHashes iterator using the new constructor + let mut seq_to_hashes = SeqToHashes::new( + sequence, + k_size, + force, + is_protein, + HashFunctions::Murmur64Dna, + seed, + ) + .unwrap(); + + // Define expected hashes for the kmer configuration. + let expected_kmers = ["AGTCGTC", "GTCGTCA"]; + let expected_krc = ["GACGACT", "TGACGAC"]; + + // Compute expected hashes by hashing each k-mer with its reverse complement + let expected_hashes: Vec = expected_kmers + .iter() + .zip(expected_krc.iter()) + .map(|(kmer, krc)| { + // Convert both kmer and krc to byte slices and pass to _hash_murmur + crate::_hash_murmur(std::cmp::min(kmer.as_bytes(), krc.as_bytes()), seed) + }) + .collect(); + + // Compare each produced hash from the iterator with the expected hash + for expected_hash in expected_hashes { + let hash = seq_to_hashes.next().unwrap().ok().unwrap(); + assert_eq!(hash, expected_hash, "Mismatch in DNA hash"); + } + } + + #[test] + fn test_seqtohashes_is_protein() { + let sequence = b"MVLSPADKTNVKAAW"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 3; + let seed = 42; + let force = false; + let is_protein = true; + + let sth = + SeqToHashes::new(sequence, k_size * 3, force, is_protein, hash_function, seed).unwrap(); + + // Expected k-mers for protein sequence + let expected_kmers = vec![ + b"MVL".to_vec(), + b"VLS".to_vec(), + b"LSP".to_vec(), + b"SPA".to_vec(), + b"PAD".to_vec(), + b"ADK".to_vec(), + b"DKT".to_vec(), + b"KTN".to_vec(), + b"TNV".to_vec(), + b"NVK".to_vec(), + b"VKA".to_vec(), + b"KAA".to_vec(), + b"AAW".to_vec(), + ]; + + // Compute hashes for expected k-mers + let expected_hashes: Vec = expected_kmers + .iter() + .map(|fw_kmer| crate::_hash_murmur(fw_kmer, 42)) + .collect(); + + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!(sth_hashes, expected_hashes, "Hashes do not match in order"); + } + + #[test] + fn test_seqtohashes_translate() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Protein; + let k_size = 9; // needs to be *3 for protein + let seed = 42; + let force = false; + let is_protein = false; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + + let expected_kmers = vec![ + b"SRR".as_slice(), + b"RRA".as_slice(), + b"SST".as_slice(), + b"STT".as_slice(), + b"VVE".as_slice(), + b"ARR".as_slice(), + b"SSS".as_slice(), + b"LDD".as_slice(), + ]; + + // Compute expected hashes + let expected_hashes: Vec = expected_kmers + .iter() + .map(|fw_kmer| crate::_hash_murmur(fw_kmer, seed)) + .collect(); + + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" + ); + } + + #[test] + fn test_seqtohashes_skipm1n3() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Skipm1n3; + let k_size = 3; + let is_protein = false; + let seed = 42; + let force = false; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + // Expected k-mers for skipmer (m=1, n=3) across all frames + let expected_kmers = vec![ + (b"ACC".as_slice(), b"GGT".as_slice()), + (b"CCG".as_slice(), b"CGG".as_slice()), + (b"GGG".as_slice(), b"CCC".as_slice()), + (b"GGC".as_slice(), b"GCC".as_slice()), + (b"TTA".as_slice(), b"TAA".as_slice()), + (b"TAT".as_slice(), b"ATA".as_slice()), + ]; + + // Compute expected hashes + let expected_hashes: Vec = expected_kmers + .iter() + .map(|(fw_kmer, rc_kmer)| crate::_hash_murmur(std::cmp::min(fw_kmer, rc_kmer), seed)) + .collect(); + + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" + ); + } + + #[test] + fn test_seq2hashes_skipm2n3() { + let sequence = b"AGTCGTCGAGCT"; + let hash_function = HashFunctions::Murmur64Skipm2n3; + let k_size = 7; + let is_protein = false; + let seed = 42; + let force = false; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + + // Expected k-mers for skipmer (m=2, n=3) + let expected_kmers = vec![ + (b"AGCGCGG".as_slice(), b"CCGCGCT".as_slice()), + (b"GCGCGGC".as_slice(), b"GCCGCGC".as_slice()), + (b"GTGTGAC".as_slice(), b"GTCACAC".as_slice()), + (b"TGTGACT".as_slice(), b"AGTCACA".as_slice()), + (b"TCTCAGT".as_slice(), b"ACTGAGA".as_slice()), + ]; + + // Compute expected hashes + let expected_hashes: Vec = expected_kmers + .iter() + .map(|(fw_kmer, rc_kmer)| crate::_hash_murmur(std::cmp::min(fw_kmer, rc_kmer), seed)) + .collect(); + + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" + ); + } + + #[test] + fn test_seqtohashes_skipm2n3_2() { + let sequence = b"AGTCGTCA"; + let hash_function = HashFunctions::Murmur64Skipm2n3; + let k_size = 5; + let seed = 42; + let force = true; + let is_protein = false; + + let sth = + SeqToHashes::new(sequence, k_size, force, is_protein, hash_function, seed).unwrap(); + let frames = sth.frames.clone(); + for fr in frames { + eprintln!("{}", fr); + } + + let expected_kmers = vec![ + (b"AGCGC".as_slice(), b"GCGCT".as_slice()), + (b"GCGCA".as_slice(), b"TGCGC".as_slice()), + (b"GTGTA".as_slice(), b"TACAC".as_slice()), + ]; + + // Compute expected hashes + let expected_hashes: Vec = expected_kmers + .iter() + .map(|(fw_kmer, rc_kmer)| crate::_hash_murmur(std::cmp::min(fw_kmer, rc_kmer), seed)) + .collect(); + + // Collect hashes from SeqToHashes + let sth_hashes: Vec = sth.map(|result| result.unwrap()).collect(); + eprintln!("SeqToHashes hashes: {:?}", sth_hashes); + + // Check that SeqToHashes matches expected hashes in order + assert_eq!( + sth_hashes, expected_hashes, + "Hashes do not match in order for SeqToHashes" + ); + } } diff --git a/src/core/src/sketch/minhash.rs b/src/core/src/sketch/minhash.rs index f8db721465..43f8b602ea 100644 --- a/src/core/src/sketch/minhash.rs +++ b/src/core/src/sketch/minhash.rs @@ -155,6 +155,8 @@ impl<'de> Deserialize<'de> for KmerMinHash { "dayhoff" => HashFunctions::Murmur64Dayhoff, "hp" => HashFunctions::Murmur64Hp, "dna" => HashFunctions::Murmur64Dna, + "skipm1n3" => HashFunctions::Murmur64Skipm1n3, + "skipm2n3" => HashFunctions::Murmur64Skipm2n3, _ => unimplemented!(), // TODO: throw error here }; diff --git a/src/core/tests/minhash.rs b/src/core/tests/minhash.rs index 59eddeff4a..82fd361acb 100644 --- a/src/core/tests/minhash.rs +++ b/src/core/tests/minhash.rs @@ -754,7 +754,9 @@ fn seq_to_hashes(seq in "ACGTGTAGCTAGACACTGACTGACTGAC") { let mut hashes: Vec = Vec::new(); - for hash_value in SeqToHashes::new(seq.as_bytes(), mh.ksize(), false, false, mh.hash_function(), mh.seed()){ + let ready_hashes = SeqToHashes::new(seq.as_bytes(), mh.ksize(), false, false, mh.hash_function(), mh.seed())?; + + for hash_value in ready_hashes{ match hash_value{ Ok(0) => continue, Ok(x) => hashes.push(x), @@ -777,7 +779,9 @@ fn seq_to_hashes_2(seq in "QRMTHINK") { let mut hashes: Vec = Vec::new(); - for hash_value in SeqToHashes::new(seq.as_bytes(), mh.ksize(), false, true, mh.hash_function(), mh.seed()){ + let ready_hashes = SeqToHashes::new(seq.as_bytes(), mh.ksize(), false, true, mh.hash_function(), mh.seed())?; + + for hash_value in ready_hashes { match hash_value{ Ok(0) => continue, Ok(x) => hashes.push(x), diff --git a/tests/test_sourmash_sketch.py b/tests/test_sourmash_sketch.py index 88332b1945..a9af0fdd24 100644 --- a/tests/test_sourmash_sketch.py +++ b/tests/test_sourmash_sketch.py @@ -345,6 +345,7 @@ def test_protein_override_bad_rust_foo(): siglist = factory() assert len(siglist) == 1 sig = siglist[0] + print(sig.minhash.ksize) # try adding something testdata1 = utils.get_test_data("ecoli.faa") @@ -354,7 +355,7 @@ def test_protein_override_bad_rust_foo(): with pytest.raises(ValueError) as exc: sig.add_protein(record.sequence) - assert 'Invalid hash function: "DNA"' in str(exc) + assert "invalid DNA character in input k-mer: MRVLKFGGTS" in str(exc) def test_dayhoff_defaults():