From 2c898d793a48fb01394242f8d27387066c81ac2b Mon Sep 17 00:00:00 2001 From: Tessa Pierce Ward Date: Thu, 24 Oct 2024 11:48:57 -0700 Subject: [PATCH] MRG: remove `BuildParams`, filter via manifest / `Select` approaches (#127) This PR replaces the `BuildParams` manual hashing + filtering with manifest filtering and selection (e.g. of moltypes) via a `Select` framework. This is more in tune with the approaches in sourmash core. I did add `MultiSelect` here, since we want to keep all templates that match sets of selection parameters. Internal improvements to the framework introduced in 0.4.0: - All of the parameter parsing + checks from `BuildParams`/`BuildParamsSet` are now in `BuildRecord`/`BuildCollection`. We now parse the parameter string directly into a `BuildCollection`, rather than going through `BuildParams` as an intermediary. - To manage finding and filtering existing signatures, we now select on the `BuildCollection` directly via `MultiSelect`, e.g. for moltype filtering. We can also filter a manifest with another manifest by using `PartialEq` for `BuildRecord`s. This replaces the prior approach of keeping a `BuildParamsSet` and hashing `BuildParams` manually for comparisons. - We continue to use sourmash core `ComputeParameters` to actually build template signatures. - Instead of handling DNA and protein as separate collections, we now manage them jointly by allowing selection on moltype, adding to DNA sigs when we have DNA and to protein sigs when we have proteins. This reduces complexity of use and has the added benefit of making it easier to add translated sigs if we want to do that in the future. 
- Fixes #113 --- src/directsketch.rs | 194 ++++-- src/utils/buildutils.rs | 1467 +++++++++++++++++++++------------------ src/utils/mod.rs | 84 +-- tests/test_gbsketch.py | 4 +- 4 files changed, 933 insertions(+), 816 deletions(-) diff --git a/src/directsketch.rs b/src/directsketch.rs index 80ad356..cbac395 100644 --- a/src/directsketch.rs +++ b/src/directsketch.rs @@ -5,7 +5,7 @@ use regex::Regex; use reqwest::Client; use sourmash::collection::Collection; use std::cmp::max; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fs::{self, create_dir_all}; use std::panic; use std::sync::atomic::{AtomicBool, Ordering}; @@ -23,7 +23,7 @@ use crate::utils::{ }; use crate::utils::buildutils::{ - BuildCollection, BuildManifest, BuildParamsSet, MultiBuildCollection, + BuildCollection, BuildManifest, MultiBuildCollection, MultiSelect, MultiSelection, }; use reqwest::Url; @@ -237,8 +237,7 @@ async fn dl_sketch_assembly_accession( location: &PathBuf, retry: Option, keep_fastas: bool, - dna_sigs: &mut BuildCollection, - prot_sigs: &mut BuildCollection, + sigs: &mut BuildCollection, genomes_only: bool, proteomes_only: bool, download_only: bool, @@ -375,22 +374,23 @@ async fn dl_sketch_assembly_accession( // sketch data match file_type { GenBankFileType::Genomic => { - dna_sigs.build_sigs_from_data(data, "dna", name.clone(), file_name.clone())?; - built_sigs.add_collection(dna_sigs); + sigs.build_sigs_from_data(data, "DNA", name.clone(), file_name.clone())?; } GenBankFileType::Protein => { - prot_sigs.build_sigs_from_data( - data, - "protein", - name.clone(), - file_name.clone(), - )?; - built_sigs.add_collection(prot_sigs); + sigs.build_sigs_from_data(data, "protein", name.clone(), file_name.clone())?; } _ => {} // Do nothing for other file types }; } } + if !download_only { + // remove any template sigs that were not populated + sigs.filter_empty(); + // to do: can we use sigs directly rather than adding to a multibuildcollection, now? 
+ if !sigs.is_empty() { + built_sigs.add_collection(sigs); + } + } Ok((built_sigs, download_failures, checksum_failures)) } @@ -402,8 +402,7 @@ async fn dl_sketch_url( location: &PathBuf, retry: Option, _keep_fastas: bool, - dna_sigs: &mut BuildCollection, - prot_sigs: &mut BuildCollection, + sigs: &mut BuildCollection, _genomes_only: bool, _proteomes_only: bool, download_only: bool, @@ -434,26 +433,21 @@ async fn dl_sketch_url( if !download_only { let filename = download_filename.clone().unwrap_or("".to_string()); // sketch data + match moltype { InputMolType::Dna => { - dna_sigs.build_sigs_from_data( - data, - "dna", - name.clone(), - filename.clone(), - )?; - built_sigs.add_collection(dna_sigs); + sigs.build_sigs_from_data(data, "DNA", name.clone(), filename.clone())?; } InputMolType::Protein => { - prot_sigs.build_sigs_from_data( - data, - "protein", - name.clone(), - filename.clone(), - )?; - built_sigs.add_collection(prot_sigs); + sigs.build_sigs_from_data(data, "protein", name.clone(), filename.clone())?; } }; + // remove any template sigs that were not populated + sigs.filter_empty(); + // to do: can we use sigs directly rather than adding to a collection, now? 
+ if !sigs.is_empty() { + built_sigs.add_collection(sigs); + } } } Err(err) => { @@ -512,6 +506,7 @@ async fn load_existing_zip_batches(outpath: &PathBuf) -> Result<(MultiCollection let mut collections = Vec::new(); let mut highest_batch = 0; // Track the highest batch number + // find parent dir (or use current dir) let dir = outpath_base .parent() .filter(|p| !p.as_os_str().is_empty()) // Ensure the parent is not empty @@ -893,7 +888,7 @@ pub async fn gbsketch( ) -> Result<(), anyhow::Error> { let batch_size = batch_size as usize; let mut batch_index = 1; - let mut name_params_map: HashMap> = HashMap::new(); + let mut existing_records_map: HashMap = HashMap::new(); let mut filter = false; if let Some(ref output_sigs) = output_sigs { // Create outpath from output_sigs @@ -907,7 +902,7 @@ pub async fn gbsketch( let (existing_sigs, max_existing_batch_index) = load_existing_zip_batches(&outpath).await?; // Check if there are any existing batches to process if !existing_sigs.is_empty() { - name_params_map = existing_sigs.buildparams_hashmap(); + existing_records_map = existing_sigs.build_recordsmap(); batch_index = max_existing_batch_index + 1; eprintln!( @@ -968,39 +963,41 @@ pub async fn gbsketch( bail!("No accessions to download and sketch.") } - // parse param string into params_vec, print error if fail - let param_result = BuildParamsSet::from_params_str(param_str); - let params_set = match param_result { - Ok(params) => params, + let sig_template_result = BuildCollection::from_param_str(param_str.as_str()); + let mut sig_templates = match sig_template_result { + Ok(sig_templates) => sig_templates, Err(e) => { bail!("Failed to parse params string: {}", e); } }; - // Use the BuildParamsSet to create template collections for DNA and protein - let dna_template_collection = BuildCollection::from_buildparams_set(¶ms_set, "DNA"); - // // prot will build protein, dayhoff, hp - let prot_template_collection = BuildCollection::from_buildparams_set(¶ms_set, "protein"); let 
mut genomes_only = genomes_only; let mut proteomes_only = proteomes_only; - // Check if dna_sig_templates is empty and not keep_fastas - if dna_template_collection.manifest.is_empty() && !keep_fastas { + // Check if we have dna signature templates and not keep_fastas + if sig_templates.dna_size()? == 0 && !keep_fastas { eprintln!("No DNA signature templates provided, and --keep-fasta is not set."); proteomes_only = true; } - // Check if protein_sig_templates is empty and not keep_fastas - if prot_template_collection.manifest.is_empty() && !keep_fastas { + // Check if we have protein signature templates not keep_fastas + if sig_templates.anyprotein_size()? == 0 && !keep_fastas { eprintln!("No protein signature templates provided, and --keep-fasta is not set."); genomes_only = true; } if genomes_only { + // select only DNA templates + let multiselection = MultiSelection::from_moltypes(vec!["dna"])?; + sig_templates = sig_templates.select(&multiselection)?; + if !download_only { eprintln!("Downloading and sketching genomes only."); } else { eprintln!("Downloading genomes only."); } } else if proteomes_only { + // select only protein templates + let multiselection = MultiSelection::from_moltypes(vec!["protein", "dayhoff", "hp"])?; + sig_templates = sig_templates.select(&multiselection)?; if !download_only { eprintln!("Downloading and sketching proteomes only."); } else { @@ -1008,21 +1005,23 @@ pub async fn gbsketch( } } + if sig_templates.is_empty() && !download_only { + bail!("No signatures to build.") + } + // report every 1 percent (or every 1, whichever is larger) let reporting_threshold = std::cmp::max(n_accs / 100, 1); for (i, accinfo) in accession_info.into_iter().enumerate() { py.check_signals()?; // If interrupted, return an Err automatically - let mut dna_sigs = dna_template_collection.clone(); - let mut prot_sigs = prot_template_collection.clone(); + let mut sigs = sig_templates.clone(); // filter template sigs based on existing sigs if filter { - if let 
Some(existing_paramset) = name_params_map.get(&accinfo.name) { + if let Some(existing_manifest) = existing_records_map.get(&accinfo.name) { // If the key exists, filter template sigs - dna_sigs.filter(existing_paramset); - prot_sigs.filter(existing_paramset); + sigs.filter_by_manifest(existing_manifest); } } @@ -1054,8 +1053,7 @@ pub async fn gbsketch( &download_path_clone, Some(retry_times), keep_fastas, - &mut dna_sigs, - &mut prot_sigs, + &mut sigs, genomes_only, proteomes_only, download_only, @@ -1126,7 +1124,7 @@ pub async fn urlsketch( ) -> Result<(), anyhow::Error> { let batch_size = batch_size as usize; let mut batch_index = 1; - let mut name_params_map: HashMap> = HashMap::new(); + let mut existing_recordsmap: HashMap = HashMap::new(); let mut filter = false; if let Some(ref output_sigs) = output_sigs { // Create outpath from output_sigs @@ -1140,7 +1138,7 @@ pub async fn urlsketch( let (existing_sigs, max_existing_batch_index) = load_existing_zip_batches(&outpath).await?; // Check if there are any existing batches to process if !existing_sigs.is_empty() { - name_params_map = existing_sigs.buildparams_hashmap(); + existing_recordsmap = existing_sigs.build_recordsmap(); batch_index = max_existing_batch_index + 1; eprintln!( @@ -1208,38 +1206,41 @@ pub async fn urlsketch( bail!("No accessions to download and sketch.") } - // parse param string into params_vec, print error if fail - let param_result = BuildParamsSet::from_params_str(param_str); - let params_set = match param_result { - Ok(params) => params, + let sig_template_result = BuildCollection::from_param_str(param_str.as_str()); + let mut sig_templates = match sig_template_result { + Ok(sig_templates) => sig_templates, Err(e) => { bail!("Failed to parse params string: {}", e); } }; - // Use the BuildParamsSet to create template collections for DNA and protein - let dna_template_collection = BuildCollection::from_buildparams_set(¶ms_set, "DNA"); - let prot_template_collection = 
BuildCollection::from_buildparams_set(¶ms_set, "protein"); let mut genomes_only = false; let mut proteomes_only = false; - // Check if dna_sig_templates is empty and not keep_fastas - if dna_template_collection.manifest.is_empty() && !keep_fastas { - eprintln!("No DNA signature templates provided, and --keep-fastas is not set."); + // Check if we have dna signature templates and not keep_fastas + if sig_templates.dna_size()? == 0 && !keep_fastas { + eprintln!("No DNA signature templates provided, and --keep-fasta is not set."); proteomes_only = true; } - // Check if protein_sig_templates is empty and not keep_fastas - if prot_template_collection.manifest.is_empty() && !keep_fastas { - eprintln!("No protein signature templates provided, and --keep-fastas is not set."); + // Check if we have protein signature templates not keep_fastas + if sig_templates.anyprotein_size()? == 0 && !keep_fastas { + eprintln!("No protein signature templates provided, and --keep-fasta is not set."); genomes_only = true; } if genomes_only { + // select only DNA templates + let multiselection = MultiSelection::from_moltypes(vec!["dna"])?; + sig_templates = sig_templates.select(&multiselection)?; + if !download_only { eprintln!("Downloading and sketching genomes only."); } else { eprintln!("Downloading genomes only."); } } else if proteomes_only { + // select only protein templates + let multiselection = MultiSelection::from_moltypes(vec!["protein", "dayhoff", "hp"])?; + sig_templates = sig_templates.select(&multiselection)?; if !download_only { eprintln!("Downloading and sketching proteomes only."); } else { @@ -1247,20 +1248,22 @@ pub async fn urlsketch( } } + if sig_templates.is_empty() && !download_only { + bail!("No signatures to build.") + } + // report every 1 percent (or every 1, whichever is larger) let reporting_threshold = std::cmp::max(n_accs / 100, 1); for (i, accinfo) in accession_info.into_iter().enumerate() { py.check_signals()?; // If interrupted, return an Err 
automatically - let mut dna_sigs = dna_template_collection.clone(); - let mut prot_sigs = prot_template_collection.clone(); + let mut sigs = sig_templates.clone(); // filter template sigs based on existing sigs if filter { - if let Some(existing_paramset) = name_params_map.get(&accinfo.name) { + if let Some(existing_manifest) = existing_recordsmap.get(&accinfo.name) { // If the key exists, filter template sigs - dna_sigs.filter(existing_paramset); - prot_sigs.filter(existing_paramset); + sigs.filter_by_manifest(existing_manifest); } } @@ -1291,8 +1294,7 @@ pub async fn urlsketch( &download_path_clone, Some(retry_times), keep_fastas, - &mut dna_sigs, - &mut prot_sigs, + &mut sigs, genomes_only, proteomes_only, download_only, @@ -1363,3 +1365,51 @@ pub async fn urlsketch( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::buildutils::BuildRecord; + use camino::Utf8PathBuf; + + #[test] + fn test_buildrecordsmap() { + // read in zipfiles to build a MultiCollection + let mut filename = Utf8PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/GCA_000961135.2.sig.zip"); + let path = filename.clone(); + + let mut collections = Vec::new(); + let coll: Collection = Collection::from_zipfile(&path).unwrap(); + collections.push(coll); + let mc = MultiCollection::new(collections); + + // build expected buildmanifest + let mut refbmf = BuildManifest::new(); + let mut rec1 = BuildRecord::default_dna(); + rec1.set_with_abundance(true); + refbmf.add_record(rec1); + + // Call build_recordsmap + let name_params_map = mc.build_recordsmap(); + + // Check that the recordsmap contains the correct names + assert_eq!( + name_params_map.len(), + 1, + "There should be 1 unique names in the map" + ); + + for (name, buildmanifest) in name_params_map.iter() { + eprintln!("Name: {}", name); + assert_eq!( + "GCA_000961135.2 Candidatus Aramenus sulfurataquae isolate AZ1-454", + name + ); + assert_eq!(buildmanifest.size(), 2); // should be two records + 
// check that we can filter out a record (k=31, abund) + let filtered = buildmanifest.filter_manifest(&refbmf); + assert_eq!(filtered.size(), 1) + } + } +} diff --git a/src/utils/buildutils.rs b/src/utils/buildutils.rs index a3953ab..07acd8c 100644 --- a/src/utils/buildutils.rs +++ b/src/utils/buildutils.rs @@ -10,278 +10,59 @@ use needletail::parser::SequenceRecord; use needletail::{parse_fastx_file, parse_fastx_reader}; use serde::Serialize; use sourmash::cmd::ComputeParameters; -use sourmash::encodings::HashFunctions; +use sourmash::encodings::{HashFunctions, Idx}; +use sourmash::errors::SourmashError; use sourmash::manifest::Record; +use sourmash::selection::Selection; use sourmash::signature::Signature; -use std::collections::hash_map::DefaultHasher; use std::collections::HashMap; use std::collections::HashSet; use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::io::{Cursor, Write}; use std::num::ParseIntError; +use std::ops::Index; use std::str::FromStr; use tokio::fs::File; use tokio_util::compat::Compat; -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct BuildParams { - pub ksize: u32, - pub track_abundance: bool, - pub num: u32, - pub scaled: u64, - pub seed: u32, - pub moltype: HashFunctions, -} - -impl Default for BuildParams { - fn default() -> Self { - BuildParams { - ksize: 31, - track_abundance: false, - num: 0, - scaled: 1000, - seed: 42, - moltype: HashFunctions::Murmur64Dna, - } - } -} - -impl Hash for BuildParams { - fn hash(&self, state: &mut H) { - self.ksize.hash(state); - self.track_abundance.hash(state); - self.num.hash(state); - self.scaled.hash(state); - self.seed.hash(state); - self.moltype.hash(state); - } +#[derive(Default, Debug, Clone)] +pub struct MultiSelection { + pub selections: Vec, } -impl BuildParams { - pub fn calculate_hash(&self) -> u64 { - let mut hasher = DefaultHasher::new(); - self.hash(&mut hasher); - hasher.finish() - } - - pub fn from_record(record: &Record) -> Self { - let moltype = record.moltype(); // 
Get the moltype (HashFunctions enum) - - BuildParams { - ksize: record.ksize(), - track_abundance: record.with_abundance(), - num: *record.num(), - scaled: *record.scaled(), - seed: 42, - moltype, - } - } - - pub fn parse_ksize(value: &str) -> Result { - value - .parse::() - .map_err(|_| format!("cannot parse k='{}' as a valid integer", value)) - } - - // disallow repeated values for scaled, num, seed - pub fn parse_int_once( - value: &str, - field: &str, - current: &mut Option, - ) -> Result<(), String> - where - T: FromStr + Display + Copy, - { - let parsed_value = value - .parse::() - .map_err(|_| format!("cannot parse {}='{}' as a valid integer", field, value))?; - - // Check for conflicts; we don't allow multiple values for the same field. - if let Some(old_value) = *current { - return Err(format!( - "Conflicting values for '{}': {} and {}", - field, old_value, parsed_value - )); - } - - // Set the new value. - *current = Some(parsed_value); - - Ok(()) - } - - pub fn parse_moltype( - item: &str, - current: &mut Option, - ) -> Result { - let new_moltype = match item { - "protein" => HashFunctions::Murmur64Protein, - "dna" => HashFunctions::Murmur64Dna, - "dayhoff" => HashFunctions::Murmur64Dayhoff, - "hp" => HashFunctions::Murmur64Hp, - _ => return Err(format!("unknown moltype '{}'", item)), - }; - - // Check for conflicts and update the moltype. - if let Some(existing) = current { - if *existing != new_moltype { - return Err(format!( - "Conflicting moltype settings in param string: '{}' and '{}'", - existing, new_moltype - )); - } - } - - // Update the current value. 
- *current = Some(new_moltype.clone()); - - Ok(new_moltype) - } - - pub fn parse_abundance(item: &str, current: &mut Option) -> Result<(), String> { - let new_abundance = item == "abund"; - - if let Some(existing) = *current { - if existing != new_abundance { - return Err(format!( - "Conflicting abundance settings in param string: '{}'", - item - )); - } - } - - *current = Some(new_abundance); - Ok(()) - } - - pub fn from_param_string(p_str: &str) -> Result<(Self, Vec), String> { - let mut base_param = BuildParams::default(); - let mut ksizes = Vec::new(); - let mut moltype: Option = None; - let mut track_abundance: Option = None; - let mut num: Option = None; - let mut scaled: Option = None; - let mut seed: Option = None; - - for item in p_str.split(',') { - match item { - _ if item.starts_with("k=") => { - ksizes.push(Self::parse_ksize(&item[2..])?); - } - "abund" | "noabund" => { - Self::parse_abundance(item, &mut track_abundance)?; - } - "protein" | "dna" | "dayhoff" | "hp" => { - Self::parse_moltype(item, &mut moltype)?; - } - _ if item.starts_with("num=") => { - Self::parse_int_once(&item[4..], "num", &mut num)?; - } - _ if item.starts_with("scaled=") => { - Self::parse_int_once(&item[7..], "scaled", &mut scaled)?; - } - _ if item.starts_with("seed=") => { - Self::parse_int_once(&item[5..], "seed", &mut seed)?; - } - _ => return Err(format!("unknown component '{}' in params string", item)), - } - } - - // Ensure that num and scaled are mutually exclusive unless num is 0. - if let (Some(n), Some(_)) = (num, scaled) { - if n != 0 { - return Err("Cannot specify both 'num' (non-zero) and 'scaled' in the same parameter string".to_string()); - } - } - - // Apply parsed values to the base_param. 
- if let Some(moltype) = moltype { - base_param.moltype = moltype; - } - if let Some(track_abund) = track_abundance { - base_param.track_abundance = track_abund; - } - if let Some(n) = num { - base_param.num = n; - } - if let Some(s) = scaled { - base_param.scaled = s; - } - if let Some(s) = seed { - base_param.seed = s; - } - - if ksizes.is_empty() { - ksizes.push(base_param.ksize); // Use the default ksize if none were specified. +impl MultiSelection { + /// Create a `MultiSelection` from a single `Selection` + pub fn new(selection: Selection) -> Self { + MultiSelection { + selections: vec![selection], } - - Ok((base_param, ksizes)) } -} -#[derive(Debug)] -pub struct BuildParamsSet { - params: HashSet, -} + pub fn from_moltypes(moltypes: Vec<&str>) -> Result { + let selections: Result, SourmashError> = moltypes + .into_iter() + .map(|moltype_str| { + let moltype = HashFunctions::try_from(moltype_str)?; + let mut new_selection = Selection::default(); // Create a default Selection + new_selection.set_moltype(moltype); // Set the moltype + Ok(new_selection) + }) + .collect(); -impl Default for BuildParamsSet { - fn default() -> Self { - let mut set = HashSet::new(); - set.insert(BuildParams::default()); - BuildParamsSet { params: set } + Ok(MultiSelection { + selections: selections?, + }) } } -impl BuildParamsSet { - pub fn new() -> Self { - Self { - params: HashSet::new(), - } - } - - pub fn size(&self) -> usize { - self.params.len() - } - - pub fn insert(&mut self, param: BuildParams) { - self.params.insert(param); - } - - pub fn iter(&self) -> impl Iterator { - self.params.iter() - } - - pub fn from_params_str(params_str: String) -> Result { - if params_str.trim().is_empty() { - return Err("Parameter string cannot be empty.".to_string()); - } - - let mut set = BuildParamsSet::new(); - - for p_str in params_str.split('_') { - let (base_param, ksizes) = BuildParams::from_param_string(p_str)?; - - for k in ksizes { - let mut param = base_param.clone(); - param.ksize 
= k; - set.insert(param); - } - } - - Ok(set) - } - - pub fn get_params(&self) -> &HashSet { - &self.params - } - - pub fn into_vec(self) -> Vec { - self.params.into_iter().collect() - } +pub trait MultiSelect { + fn select(self, multi_selection: &MultiSelection) -> Result + where + Self: Sized; } -#[derive(Debug, Default, Clone, Getters, Setters, Serialize)] +#[derive(Debug, Clone, Getters, Setters, Serialize)] pub struct BuildRecord { // fields are ordered the same as Record to allow serialization to manifest // required fields are currently immutable once set @@ -318,8 +99,15 @@ pub struct BuildRecord { #[getset(get = "pub", set = "pub")] filename: Option, + #[getset(get_copy = "pub")] + #[serde(skip)] + pub seed: u32, + #[serde(skip)] pub hashed_params: u64, + + #[serde(skip)] + pub sequence_added: bool, } // from sourmash (intbool is currently private there) @@ -334,20 +122,124 @@ where } } +impl Default for BuildRecord { + fn default() -> Self { + // Default BuildRecord is DNA default + BuildRecord { + internal_location: None, + md5: None, + md5short: None, + ksize: 31, + moltype: "DNA".to_string(), + num: 0, + scaled: 1000, + n_hashes: None, + with_abundance: false, + name: None, + filename: None, + seed: 42, + hashed_params: 0, + sequence_added: false, + } + } +} + impl BuildRecord { - pub fn from_buildparams(param: &BuildParams) -> Self { - // Calculate the hash of Params - let hashed_params = param.calculate_hash(); + pub fn default_dna() -> Self { + Self { + ..Default::default() + } + } - BuildRecord { - ksize: param.ksize, - moltype: param.moltype.to_string(), - num: param.num, - scaled: param.scaled, - with_abundance: param.track_abundance, - hashed_params, - ..Default::default() // automatically set optional fields to None + pub fn default_protein() -> Self { + Self { + moltype: "protein".to_string(), + ksize: 10, + scaled: 200, + ..Default::default() + } + } + + pub fn default_dayhoff() -> Self { + Self { + moltype: "dayhoff".to_string(), + ksize: 
10, + scaled: 200, + ..Default::default() + } + } + + pub fn default_hp() -> Self { + Self { + moltype: "hp".to_string(), + ksize: 10, + scaled: 200, + ..Default::default() + } + } + + pub fn moltype(&self) -> HashFunctions { + self.moltype.as_str().try_into().unwrap() + } + + pub fn from_record(record: &Record) -> Self { + Self { + ksize: record.ksize(), + moltype: record.moltype().to_string(), + num: *record.num(), + scaled: *record.scaled(), + with_abundance: record.with_abundance(), + ..Default::default() // ignore remaining fields + } + } + + pub fn matches_selection(&self, selection: &Selection) -> bool { + let mut valid = true; + + if let Some(ksize) = selection.ksize() { + valid = valid && self.ksize == ksize; + } + + if let Some(moltype) = selection.moltype() { + valid = valid && self.moltype() == moltype; + } + + if let Some(abund) = selection.abund() { + valid = valid && self.with_abundance == abund; + } + + if let Some(scaled) = selection.scaled() { + // num sigs have self.scaled = 0, don't include them + valid = valid && self.scaled != 0 && self.scaled <= scaled as u64; + } + + if let Some(num) = selection.num() { + valid = valid && self.num == num; } + + valid + } +} + +impl PartialEq for BuildRecord { + fn eq(&self, other: &Self) -> bool { + self.ksize == other.ksize + && self.moltype == other.moltype + && self.with_abundance == other.with_abundance + && self.num == other.num + && self.scaled == other.scaled + } +} + +impl Eq for BuildRecord {} + +impl Hash for BuildRecord { + fn hash(&self, state: &mut H) { + self.ksize.hash(state); + self.moltype.hash(state); + self.scaled.hash(state); + self.num.hash(state); + self.with_abundance.hash(state); } } @@ -371,11 +263,30 @@ impl BuildManifest { self.records.len() } + pub fn iter(&self) -> impl Iterator { + self.records.iter() + } + // clear all records pub fn clear(&mut self) { self.records.clear(); } + pub fn filter_manifest(&self, other: &BuildManifest) -> Self { + // Create a HashSet of references to 
the `BuildRecord`s in `other` + let pairs: HashSet<_> = other.records.iter().collect(); + + // Filter `self.records` to retain only those `BuildRecord`s that are NOT in `pairs` + let records = self + .records + .iter() + .filter(|&build_record| !pairs.contains(build_record)) + .cloned() + .collect(); + + Self { records } + } + pub fn add_record(&mut self, record: BuildRecord) { self.records.push(record); } @@ -432,6 +343,36 @@ impl BuildManifest { } } +impl MultiSelect for BuildManifest { + fn select(self, multi_selection: &MultiSelection) -> Result { + let rows = self.records.iter().filter(|row| { + // for each row, check if it matches any of the Selection structs in MultiSelection + multi_selection + .selections + .iter() + .any(|selection| row.matches_selection(selection)) + }); + + Ok(BuildManifest { + records: rows.cloned().collect(), + }) + } +} + +impl From> for BuildManifest { + fn from(records: Vec) -> Self { + BuildManifest { records } + } +} + +impl Index for BuildManifest { + type Output = BuildRecord; + + fn index(&self, index: usize) -> &Self::Output { + &self.records[index] + } +} + impl<'a> IntoIterator for &'a BuildManifest { type Item = &'a BuildRecord; type IntoIter = std::slice::Iter<'a, BuildRecord>; @@ -441,102 +382,307 @@ impl<'a> IntoIterator for &'a BuildManifest { } } -impl<'a> IntoIterator for &'a mut BuildManifest { - type Item = &'a mut BuildRecord; - type IntoIter = std::slice::IterMut<'a, BuildRecord>; +impl<'a> IntoIterator for &'a mut BuildManifest { + type Item = &'a mut BuildRecord; + type IntoIter = std::slice::IterMut<'a, BuildRecord>; + + fn into_iter(self) -> Self::IntoIter { + self.records.iter_mut() + } +} + +#[derive(Debug, Default, Clone)] +pub struct BuildCollection { + pub manifest: BuildManifest, + pub sigs: Vec, +} + +impl BuildCollection { + pub fn new() -> Self { + BuildCollection { + manifest: BuildManifest::new(), + sigs: Vec::new(), + } + } + + pub fn is_empty(&self) -> bool { + self.manifest.is_empty() + } + + 
pub fn size(&self) -> usize { + self.manifest.size() + } + + pub fn dna_size(&self) -> Result { + let multiselection = MultiSelection::from_moltypes(vec!["dna"])?; + let selected_manifest = self.manifest.clone().select(&multiselection)?; + + Ok(selected_manifest.records.len()) + } + + pub fn protein_size(&self) -> Result { + let multiselection = MultiSelection::from_moltypes(vec!["protein"])?; + let selected_manifest = self.manifest.clone().select(&multiselection)?; + + Ok(selected_manifest.records.len()) + } + + pub fn anyprotein_size(&self) -> Result { + let multiselection = MultiSelection::from_moltypes(vec!["protein", "dayhoff", "hp"])?; + let selected_manifest = self.manifest.clone().select(&multiselection)?; + + Ok(selected_manifest.records.len()) + } + + pub fn parse_ksize(value: &str) -> Result { + value + .parse::() + .map_err(|_| format!("cannot parse k='{}' as a valid integer", value)) + } + + pub fn parse_int_once( + value: &str, + field: &str, + current: &mut Option, + ) -> Result<(), String> + where + T: FromStr + Display + Copy, + { + let parsed_value = value + .parse::() + .map_err(|_| format!("cannot parse {}='{}' as a valid integer", field, value))?; + + // Check for conflicts; we don't allow multiple values for the same field. + if let Some(old_value) = *current { + return Err(format!( + "Conflicting values for '{}': {} and {}", + field, old_value, parsed_value + )); + } + + *current = Some(parsed_value); + Ok(()) + } + + pub fn parse_moltype(item: &str, current: &mut Option) -> Result { + let new_moltype = match item { + "protein" | "dna" | "dayhoff" | "hp" => item.to_string(), + _ => return Err(format!("unknown moltype '{}'", item)), + }; + + // Check for conflicts and update the moltype. 
+ if let Some(existing) = current { + if *existing != new_moltype { + return Err(format!( + "Conflicting moltype settings in param string: '{}' and '{}'", + existing, new_moltype + )); + } + } + + *current = Some(new_moltype.clone()); + Ok(new_moltype) + } + + pub fn parse_abundance(item: &str, current: &mut Option) -> Result<(), String> { + let new_abundance = item == "abund"; + + if let Some(existing) = *current { + if existing != new_abundance { + return Err(format!( + "Conflicting abundance settings in param string: '{}'", + item + )); + } + } + + *current = Some(new_abundance); + Ok(()) + } + + pub fn parse_params(p_str: &str) -> Result<(BuildRecord, Vec), String> { + let mut ksizes = Vec::new(); + let mut moltype: Option = None; + let mut track_abundance: Option = None; + let mut num: Option = None; + let mut scaled: Option = None; + let mut seed: Option = None; + + for item in p_str.split(',') { + match item { + _ if item.starts_with("k=") => { + ksizes.push(Self::parse_ksize(&item[2..])?); + } + "abund" | "noabund" => { + Self::parse_abundance(item, &mut track_abundance)?; + } + "protein" | "dna" | "DNA" | "dayhoff" | "hp" => { + Self::parse_moltype(item, &mut moltype)?; + } + _ if item.starts_with("num=") => { + Self::parse_int_once(&item[4..], "num", &mut num)?; + } + _ if item.starts_with("scaled=") => { + Self::parse_int_once(&item[7..], "scaled", &mut scaled)?; + } + _ if item.starts_with("seed=") => { + Self::parse_int_once(&item[5..], "seed", &mut seed)?; + } + _ => return Err(format!("unknown component '{}' in params string", item)), + } + } + + // Create a moltype-specific default BuildRecord. 
+ let mut base_record = match moltype.as_deref() { + Some("dna") => BuildRecord::default_dna(), + Some("DNA") => BuildRecord::default_dna(), + Some("protein") => BuildRecord::default_protein(), + Some("dayhoff") => BuildRecord::default_dayhoff(), + Some("hp") => BuildRecord::default_hp(), + _ => BuildRecord::default_dna(), // no moltype --> assume DNA + }; - fn into_iter(self) -> Self::IntoIter { - self.records.iter_mut() - } -} + // Apply parsed values + if let Some(track_abund) = track_abundance { + base_record.with_abundance = track_abund; + } + if let Some(n) = num { + base_record.num = n; + } + if let Some(s) = scaled { + base_record.scaled = s; + } + if let Some(s) = seed { + base_record.seed = s; + } -#[derive(Debug, Default, Clone)] -pub struct BuildCollection { - pub manifest: BuildManifest, - pub sigs: Vec, -} + // Use the default ksize if none were specified. + if ksizes.is_empty() { + ksizes.push(base_record.ksize); + } -impl BuildCollection { - pub fn new() -> Self { - BuildCollection { - manifest: BuildManifest::new(), - sigs: Vec::new(), + // Ensure that num and scaled are mutually exclusive unless num is 0. 
+ if let (Some(n), Some(_)) = (num, scaled) { + if n != 0 { + return Err("Cannot specify both 'num' (non-zero) and 'scaled' in the same parameter string".to_string()); + } } - } - pub fn is_empty(&self) -> bool { - self.manifest.is_empty() && self.sigs.is_empty() + Ok((base_record, ksizes)) } - pub fn size(&self) -> usize { - self.manifest.size() - } + pub fn from_param_str(params_str: &str) -> Result { + if params_str.trim().is_empty() { + return Err("Parameter string cannot be empty.".to_string()); + } - pub fn from_buildparams(params: &[BuildParams], input_moltype: &str) -> Self { - let mut collection = BuildCollection::new(); + let mut coll = BuildCollection::new(); + let mut seen_records = HashSet::new(); - for param in params.iter().cloned() { - collection.add_template_sig(param, input_moltype); - } + for p_str in params_str.split('_') { + // Use `parse_params` to get the base record and ksizes. + let (base_record, ksizes) = Self::parse_params(p_str)?; - collection + // Iterate over each ksize and add a signature to the collection. + for k in ksizes { + let mut record = base_record.clone(); + record.ksize = k; + + // Check if the record is already in the set. + if seen_records.insert(record.clone()) { + // Add the record and its associated signature to the collection. + // coll.add_template_sig_from_record(&record, &record.moltype); + coll.add_template_sig_from_record(&record); + } + } + } + Ok(coll) } - pub fn from_buildparams_set(params_set: &BuildParamsSet, input_moltype: &str) -> Self { + pub fn from_manifest(manifest: &BuildManifest) -> Self { let mut collection = BuildCollection::new(); - for param in params_set.iter().cloned() { - collection.add_template_sig(param, input_moltype); + // Iterate over each `BuildRecord` in the provided `BuildManifest`. + for record in &manifest.records { + // Add a signature to the collection using the `BuildRecord` and `input_moltype`. 
+ collection.add_template_sig_from_record(record); } collection } - pub fn add_template_sig(&mut self, param: BuildParams, input_moltype: &str) { - // Check the input_moltype against Params to decide if this should be added - match input_moltype.to_lowercase().as_str() { - "dna" if param.moltype != HashFunctions::Murmur64Dna => return, // Skip if it's not DNA - "protein" - if param.moltype != HashFunctions::Murmur64Protein - && param.moltype != HashFunctions::Murmur64Dayhoff - && param.moltype != HashFunctions::Murmur64Hp => - { - return - } // Skip if not a protein type - _ => (), - } - - // Adjust ksize for protein, dayhoff, or hp, which typically require tripling the k-mer size - let adjusted_ksize = match param.moltype { - HashFunctions::Murmur64Protein - | HashFunctions::Murmur64Dayhoff - | HashFunctions::Murmur64Hp => param.ksize * 3, - _ => param.ksize, + pub fn add_template_sig_from_record(&mut self, record: &BuildRecord) { + // Adjust ksize for protein, dayhoff, or hp, which require tripling the k-mer size. + let adjusted_ksize = match record.moltype.as_str() { + "protein" | "dayhoff" | "hp" => record.ksize * 3, + _ => record.ksize, }; - // Construct ComputeParameters + // Construct ComputeParameters. let cp = ComputeParameters::builder() .ksizes(vec![adjusted_ksize]) - .scaled(param.scaled) - .protein(param.moltype == HashFunctions::Murmur64Protein) - .dna(param.moltype == HashFunctions::Murmur64Dna) - .dayhoff(param.moltype == HashFunctions::Murmur64Dayhoff) - .hp(param.moltype == HashFunctions::Murmur64Hp) - .num_hashes(param.num) - .track_abundance(param.track_abundance) + .scaled(record.scaled) + .protein(record.moltype == "protein") + .dna(record.moltype == "DNA") + .dayhoff(record.moltype == "dayhoff") + .hp(record.moltype == "hp") + .num_hashes(record.num) + .track_abundance(record.with_abundance) .build(); - // Create a Signature from the ComputeParameters + // Create a Signature from the ComputeParameters. 
let sig = Signature::from_params(&cp); - // Create the BuildRecord using from_param - let template_record = BuildRecord::from_buildparams(¶m); + // Clone the `BuildRecord` and use it directly. + let template_record = record.clone(); - // Add the record and signature to the collection + // Add the record and signature to the collection. self.manifest.records.push(template_record); self.sigs.push(sig); } + pub fn filter_manifest(&mut self, other: &BuildManifest) { + self.manifest = self.manifest.filter_manifest(other) + } + + pub fn filter_by_manifest(&mut self, other: &BuildManifest) { + // Create a HashSet for efficient filtering based on the `BuildRecord`s in `other`. + let other_records: HashSet<_> = other.records.iter().collect(); + + // Retain only the records that are not in `other_records`, filtering in place. + let mut sig_index = 0; + self.manifest.records.retain(|record| { + let keep = !other_records.contains(record); + if !keep { + // Remove the corresponding signature at the same index. + self.sigs.remove(sig_index); + } else { + sig_index += 1; // Only increment if we keep the record and signature. + } + keep + }); + } + + // filter template signatures that had no sequence added + // suggested use right before writing signatures + pub fn filter_empty(&mut self) { + let mut sig_index = 0; + + self.manifest.records.retain(|record| { + // Keep only records where `sequence_added` is `true`. + let keep = record.sequence_added; + + if !keep { + // Remove the corresponding signature at the same index if the record is not kept. + self.sigs.remove(sig_index); + } else { + sig_index += 1; // Only increment if we keep the record and signature. 
+ } + + keep + }); + } + pub fn filter(&mut self, params_set: &HashSet) { let mut index = 0; while index < self.manifest.records.len() { @@ -552,6 +698,14 @@ impl BuildCollection { } } + pub fn iter(&self) -> impl Iterator { + self.manifest.iter().enumerate().map(|(i, r)| (i as Idx, r)) + } + + pub fn record_for_dataset(&self, dataset_id: Idx) -> Result<&BuildRecord> { + Ok(&self.manifest[dataset_id as usize]) + } + pub fn sigs_iter_mut(&mut self) -> impl Iterator { self.sigs.iter_mut() } @@ -564,7 +718,7 @@ impl BuildCollection { pub fn build_sigs_from_data( &mut self, data: Vec, - input_moltype: &str, // (protein/dna); todo - use hashfns? + input_moltype: &str, name: String, filename: String, ) -> Result<()> { @@ -575,14 +729,22 @@ impl BuildCollection { // Iterate over FASTA records and add sequences/proteins to sigs while let Some(record) = fastx_reader.next() { let record = record.context("Failed to read record")?; - self.sigs_iter_mut().for_each(|sig| { - if input_moltype == "protein" { + self.iter_mut().for_each(|(rec, sig)| { + if input_moltype == "protein" + && (rec.moltype == "protein" || rec.moltype == "dayhoff" || rec.moltype == "hp") + { sig.add_protein(&record.seq()) .expect("Failed to add protein"); - } else { + if !rec.sequence_added { + rec.sequence_added = true + } + } else if input_moltype == "DNA" && rec.moltype == "DNA" { sig.add_sequence(&record.seq(), true) .expect("Failed to add sequence"); // if not force, panics with 'N' in dna sequence + if !rec.sequence_added { + rec.sequence_added = true + } } }); } @@ -603,14 +765,24 @@ impl BuildCollection { // Iterate over FASTA records and add sequences/proteins to sigs while let Some(record) = fastx_reader.next() { let record = record.context("Failed to read record")?; - self.sigs_iter_mut().for_each(|sig| { - if input_moltype == "protein" { + self.iter_mut().for_each(|(rec, sig)| { + if input_moltype == "protein" + && (rec.moltype() == HashFunctions::Murmur64Protein + || rec.moltype() == 
HashFunctions::Murmur64Dayhoff + || rec.moltype() == HashFunctions::Murmur64Hp) + { sig.add_protein(&record.seq()) .expect("Failed to add protein"); + if !rec.sequence_added { + rec.sequence_added = true + } } else { sig.add_sequence(&record.seq(), true) .expect("Failed to add sequence"); // if not force, panics with 'N' in dna sequence + if !rec.sequence_added { + rec.sequence_added = true + } } }); } @@ -627,14 +799,24 @@ impl BuildCollection { input_moltype: &str, // (protein/dna); todo - use hashfns? filename: String, ) -> Result<()> { - self.sigs_iter_mut().for_each(|sig| { - if input_moltype == "protein" { + self.iter_mut().for_each(|(rec, sig)| { + if input_moltype == "protein" + && (rec.moltype() == HashFunctions::Murmur64Protein + || rec.moltype() == HashFunctions::Murmur64Dayhoff + || rec.moltype() == HashFunctions::Murmur64Hp) + { sig.add_protein(&record.seq()) .expect("Failed to add protein"); + if !rec.sequence_added { + rec.sequence_added = true + } } else { sig.add_sequence(&record.seq(), true) .expect("Failed to add sequence"); // if not force, panics with 'N' in dna sequence + if !rec.sequence_added { + rec.sequence_added = true + } } }); let record_name = std::str::from_utf8(record.id()) @@ -649,22 +831,25 @@ impl BuildCollection { pub fn update_info(&mut self, name: String, filename: String) { // update the records to reflect information the signature; for (record, sig) in self.iter_mut() { - // update signature name, filename - sig.set_name(name.as_str()); - sig.set_filename(filename.as_str()); - - // update record: set name, filename, md5sum, n_hashes - record.set_name(Some(name.clone())); - record.set_filename(Some(filename.clone())); - record.set_md5(Some(sig.md5sum())); - record.set_md5short(Some(sig.md5sum()[0..8].into())); - record.set_n_hashes(Some(sig.size())); - - // note, this needs to be set when writing sigs - // record.set_internal_location("") + if record.sequence_added { + // update signature name, filename + 
sig.set_name(name.as_str()); + sig.set_filename(filename.as_str()); + + // update record: set name, filename, md5sum, n_hashes + record.set_name(Some(name.clone())); + record.set_filename(Some(filename.clone())); + record.set_md5(Some(sig.md5sum())); + record.set_md5short(Some(sig.md5sum()[0..8].into())); + record.set_n_hashes(Some(sig.size())); + + // note, this needs to be set when writing sigs (not here) + // record.set_internal_location("") + } } } + // to do -- use filter_empty to ensure we're not writing empty template sigs?? pub async fn async_write_sigs_to_zip( &mut self, // need mutable to update records zip_writer: &mut ZipFileWriter>, @@ -731,6 +916,40 @@ impl<'a> IntoIterator for &'a mut BuildCollection { } } +impl MultiSelect for BuildCollection { + // to do --> think through the best/most efficient way to do this + // in sourmash core, we don't need to select sigs themselves. Is this due to the way that Idx/Storage work? + fn select(mut self, multi_selection: &MultiSelection) -> Result { + // Collect indices while retaining matching records + let mut selected_indices = Vec::new(); + let mut current_index = 0; + + self.manifest.records.retain(|record| { + let keep = multi_selection + .selections + .iter() + .any(|selection| record.matches_selection(selection)); + + if keep { + selected_indices.push(current_index); // Collect the index of the retained record + } + + current_index += 1; // Move to the next index + keep // Retain the record if it matches the selection + }); + + // Retain corresponding signatures using the collected indices + let mut sig_index = 0; + self.sigs.retain(|_sig| { + let keep = selected_indices.contains(&sig_index); + sig_index += 1; + keep + }); + + Ok(self) + } +} + #[derive(Debug, Clone)] pub struct MultiBuildCollection { pub collections: Vec, @@ -756,289 +975,97 @@ impl MultiBuildCollection { mod tests { use super::*; - #[test] - fn test_buildparams_consistent_hashing() { - let params1 = BuildParams { - ksize: 31, - 
track_abundance: true, - ..Default::default() - }; - - let params2 = BuildParams { - ksize: 31, - track_abundance: true, - ..Default::default() - }; - - let hash1 = params1.calculate_hash(); - let hash2 = params2.calculate_hash(); - let hash3 = params2.calculate_hash(); - - // Check that the hash for two identical Params is the same - assert_eq!(hash1, hash2, "Hashes for identical Params should be equal"); - - assert_eq!( - hash2, hash3, - "Hashes for the same Params should be consistent across multiple calls" - ); - } - - #[test] - fn test_buildparams_hashing_different() { - let params1 = BuildParams { - ksize: 31, - ..Default::default() - }; - - let params2 = BuildParams { - ksize: 21, // Changed ksize - ..Default::default() - }; - - let params3 = BuildParams { - ksize: 31, - moltype: HashFunctions::Murmur64Protein, - ..Default::default() - }; - - let hash1 = params1.calculate_hash(); - let hash2 = params2.calculate_hash(); - let hash3 = params3.calculate_hash(); - - // Check that the hash for different Params is different - assert_ne!( - hash1, hash2, - "Hashes for different Params should not be equal" - ); - assert_ne!( - hash1, hash3, - "Hashes for different Params should not be equal" - ); - } - - #[test] - fn test_buildparams_generated_from_record() { - // load signature + build record - let mut filename = Utf8PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/GCA_000175535.1.sig.gz"); - let path = filename.clone(); - - let file = std::fs::File::open(filename).unwrap(); - let mut reader = std::io::BufReader::new(file); - let sigs = Signature::load_signatures( - &mut reader, - Some(31), - Some("DNA".try_into().unwrap()), - None, - ) - .unwrap(); - - assert_eq!(sigs.len(), 1); - - let sig = sigs.get(0).unwrap(); - let record = Record::from_sig(sig, path.as_str()); - - // create the expected Params based on the Record data - let expected_params = BuildParams { - ksize: 31, - track_abundance: true, - ..Default::default() - }; - - // // 
Generate the Params from the Record using the from_record method - let generated_params = BuildParams::from_record(&record[0]); - - // // Assert that the generated Params match the expected Params - assert_eq!( - generated_params, expected_params, - "Generated Params did not match the expected Params" - ); - - // // Calculate the hash for the expected Params - let expected_hash = expected_params.calculate_hash(); - - // // Calculate the hash for the generated Params - let generated_hash = generated_params.calculate_hash(); - - // // Assert that the hash for the generated Params matches the expected Params hash - assert_eq!( - generated_hash, expected_hash, - "Hash of generated Params did not match the hash of expected Params" - ); - } - - #[test] - fn test_filter_removes_matching_buildparams() { - let params1 = BuildParams { - ksize: 31, - track_abundance: true, - ..Default::default() - }; - - let params2 = BuildParams { - ksize: 21, - track_abundance: true, - ..Default::default() - }; - - let params3 = BuildParams { - ksize: 31, - scaled: 2000, - track_abundance: true, - ..Default::default() - }; - - let params_list = [params1.clone(), params2.clone(), params3.clone()]; - let mut build_collection = BuildCollection::from_buildparams(¶ms_list, "DNA"); - - let mut params_set = HashSet::new(); - params_set.insert(params1.calculate_hash()); - params_set.insert(params3.calculate_hash()); - - // Call the filter method - build_collection.filter(¶ms_set); - - // Check that the records and signatures with matching params are removed - assert_eq!( - build_collection.manifest.records.len(), - 1, - "Only one record should remain after filtering" - ); - assert_eq!( - build_collection.sigs.len(), - 1, - "Only one signature should remain after filtering" - ); - - // Check that the remaining record is the one with hashed_params = 456 - let h2 = params2.calculate_hash(); - assert_eq!( - build_collection.manifest.records[0].hashed_params, h2, - "The remaining record should have 
hashed_params {}", - h2 - ); - } - - #[test] - fn test_build_params_set_default() { - // Create a default BuildParamsSet. - let default_set = BuildParamsSet::default(); - - // Check that the set contains exactly one item. - assert_eq!( - default_set.size(), - 1, - "Expected default BuildParamsSet to contain one default BuildParams." - ); - - // Get the default parameter from the set. - let default_param = default_set.iter().next().unwrap(); - - // Verify that the default parameter has expected default values. - assert_eq!(default_param.ksize, 31, "Expected default ksize to be 31."); - assert_eq!( - default_param.track_abundance, false, - "Expected default track_abundance to be false." - ); - assert_eq!( - default_param.moltype, - HashFunctions::Murmur64Dna, - "Expected default moltype to be DNA." - ); - assert_eq!(default_param.num, 0, "Expected default num to be 0."); - assert_eq!( - default_param.scaled, 1000, - "Expected default scaled to be 1000." - ); - assert_eq!(default_param.seed, 42, "Expected default seed to be 42."); - } - #[test] fn test_valid_params_str() { let params_str = "k=31,abund,dna"; - let result = BuildParamsSet::from_params_str(params_str.to_string()); + let result = BuildCollection::parse_params(params_str); assert!( result.is_ok(), "Expected 'k=31,abund,dna' to be valid, but got an error: {:?}", result ); - let params_set = result.unwrap(); - // Ensure the BuildParamsSet contains the expected number of parameters. - assert_eq!( - params_set.iter().count(), - 1, - "Expected 1 BuildParams in the set" - ); + let (record, ksizes) = result.unwrap(); - // Verify that the BuildParams has the correct settings. - let param = params_set.iter().next().unwrap(); - assert_eq!(param.ksize, 31); - assert_eq!(param.track_abundance, true); - assert_eq!(param.moltype, HashFunctions::Murmur64Dna); + // Verify that the Record, ksizes have the correct settings. 
+ assert_eq!(record.moltype, "DNA"); + assert_eq!(record.with_abundance, true); + assert_eq!(ksizes, vec![31]); + assert_eq!(record.scaled, 1000, "Expected default scaled value of 1000"); + assert_eq!(record.num, 0, "Expected default num value of 0"); } #[test] - fn test_multiple_valid_params_str() { + fn test_from_param_str() { let params_str = "k=31,abund,dna_k=21,k=31,k=51,abund_k=10,protein"; - let result = BuildParamsSet::from_params_str(params_str.to_string()); + let coll_result = BuildCollection::from_param_str(params_str); assert!( - result.is_ok(), - "Param str {} is valid, but got an error: {:?}", + coll_result.is_ok(), + "Param str '{}' is valid, but got an error: {:?}", params_str, - result + coll_result ); - let params_set: BuildParamsSet = result.unwrap(); - // Ensure the BuildParamsSet contains the expected number of parameters. - // note that k=31,dna,abund is in two diff param strs, should only show up once. - assert_eq!(params_set.size(), 4, "Expected 4 BuildParams in the set"); - // Define the expected BuildParams for comparison. - let expected_params = vec![ - { - let mut param = BuildParams::default(); - param.ksize = 31; - param.track_abundance = true; - param - }, - { - let mut param = BuildParams::default(); - param.ksize = 21; - param.track_abundance = true; - param + let coll = coll_result.unwrap(); + + // Ensure the BuildCollection contains the expected number of records. + // Note that "k=31,abund,dna" appears in two different parameter strings, so it should only appear once. + assert_eq!( + coll.manifest.records.len(), + 4, + "Expected 4 unique BuildRecords in the collection, but found {}", + coll.manifest.records.len() + ); + + // Define the expected BuildRecords for comparison. 
+ let expected_records = vec![ + BuildRecord { + ksize: 31, + moltype: "DNA".to_string(), + with_abundance: true, + ..Default::default() }, - { - let mut param = BuildParams::default(); - param.ksize = 51; - param.track_abundance = true; - param + BuildRecord { + ksize: 21, + moltype: "DNA".to_string(), + with_abundance: true, + ..Default::default() }, - { - let mut param = BuildParams::default(); - param.ksize = 10; - param.moltype = HashFunctions::Murmur64Protein; - param + BuildRecord { + ksize: 51, + moltype: "DNA".to_string(), + with_abundance: true, + ..Default::default() }, + BuildRecord::default_protein(), ]; - // Check that each expected BuildParams is present in the params_set. - for expected_param in expected_params { + // Verify that each expected BuildRecord is present in the collection. + for expected_record in expected_records { assert!( - params_set.get_params().contains(&expected_param), - "Expected BuildParams with ksize: {}, track_abundance: {}, moltype: {:?} not found in the set", - expected_param.ksize, - expected_param.track_abundance, - expected_param.moltype - ); + coll.manifest.records.contains(&expected_record), + "Expected BuildRecord with ksize: {}, moltype: {}, with_abundance: {} not found in the collection", + expected_record.ksize, + expected_record.moltype, + expected_record.with_abundance + ); } + + // Optionally, check that the corresponding signatures are present. + assert_eq!( + coll.sigs.len(), + 4, + "Expected 4 Signatures in the collection, but found {}", + coll.sigs.len() + ); } #[test] fn test_invalid_params_str_conflicting_moltypes() { let params_str = "k=31,abund,dna,protein"; - let result = BuildParamsSet::from_params_str(params_str.to_string()); + let result = BuildCollection::from_param_str(params_str); assert!( result.is_err(), @@ -1058,7 +1085,7 @@ mod tests { #[test] fn test_unknown_component_error() { // Test for an unknown component that should trigger an error. 
- let result = BuildParamsSet::from_params_str("k=31,notaparam".to_string()); + let result = BuildCollection::from_param_str("k=31,notaparam"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1069,7 +1096,7 @@ mod tests { #[test] fn test_unknown_component_error2() { // Test a common param string error (k=31,51 compared with valid k=31,k=51) - let result = BuildParamsSet::from_params_str("k=31,51,abund".to_string()); + let result = BuildCollection::from_param_str("k=31,51,abund"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1080,7 +1107,7 @@ mod tests { #[test] fn test_conflicting_num_and_scaled() { // Test for specifying both num and scaled, which should result in an error. - let result = BuildParamsSet::from_params_str("k=31,num=10,scaled=1000".to_string()); + let result = BuildCollection::from_param_str("k=31,num=10,scaled=1000"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1091,7 +1118,7 @@ mod tests { #[test] fn test_conflicting_abundance() { // Test for providing conflicting abundance settings, which should result in an error. - let result = BuildParamsSet::from_params_str("k=31,abund,noabund".to_string()); + let result = BuildCollection::from_param_str("k=31,abund,noabund"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1102,7 +1129,7 @@ mod tests { #[test] fn test_invalid_ksize_format() { // Test for an invalid ksize format that should trigger an error. - let result = BuildParamsSet::from_params_str("k=abc".to_string()); + let result = BuildCollection::from_param_str("k=abc"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1113,7 +1140,7 @@ mod tests { #[test] fn test_invalid_num_format() { // Test for an invalid number format that should trigger an error. 
- let result = BuildParamsSet::from_params_str("k=31,num=abc".to_string()); + let result = BuildCollection::from_param_str("k=31,num=abc"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1124,7 +1151,7 @@ mod tests { #[test] fn test_invalid_scaled_format() { // Test for an invalid scaled format that should trigger an error. - let result = BuildParamsSet::from_params_str("k=31,scaled=abc".to_string()); + let result = BuildCollection::from_param_str("k=31,scaled=abc"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1135,7 +1162,7 @@ mod tests { #[test] fn test_invalid_seed_format() { // Test for an invalid seed format that should trigger an error. - let result = BuildParamsSet::from_params_str("k=31,seed=abc".to_string()); + let result = BuildCollection::from_param_str("k=31,seed=abc"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1146,7 +1173,7 @@ mod tests { #[test] fn test_repeated_values() { // repeated scaled - let result = BuildParamsSet::from_params_str("k=31,scaled=1,scaled=1000".to_string()); + let result = BuildCollection::from_param_str("k=31,scaled=1,scaled=1000"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1154,7 +1181,7 @@ mod tests { ); // repeated num - let result = BuildParamsSet::from_params_str("k=31,num=1,num=1000".to_string()); + let result = BuildCollection::from_param_str("k=31,num=1,num=1000"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1162,7 +1189,7 @@ mod tests { ); // repeated seed - let result = BuildParamsSet::from_params_str("k=31,seed=1,seed=42".to_string()); + let result = BuildCollection::from_param_str("k=31,seed=1,seed=42"); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!( result.unwrap_err(), @@ -1173,99 +1200,201 @@ mod tests { #[test] fn 
test_missing_ksize() { // Test for a missing ksize, using default should not result in an error. - let result = BuildParamsSet::from_params_str("abund".to_string()); + let result = BuildCollection::from_param_str("abund"); assert!(result.is_ok(), "Expected Ok but got an error."); } #[test] fn test_repeated_ksize() { // Repeated ksize settings should not trigger an error since it is valid to have multiple ksizes. - let result = BuildParamsSet::from_params_str("k=31,k=21".to_string()); + let result = BuildCollection::from_param_str("k=31,k=21"); assert!(result.is_ok(), "Expected Ok but got an error."); } #[test] fn test_empty_string() { // Test for an empty parameter string, which should now result in an error. - let result = BuildParamsSet::from_params_str("".to_string()); + let result = BuildCollection::from_param_str(""); assert!(result.is_err(), "Expected an error but got Ok."); assert_eq!(result.unwrap_err(), "Parameter string cannot be empty."); } #[test] - fn test_from_buildparams_abundance() { - let mut params = BuildParams::default(); - params.track_abundance = true; - - // Create a BuildRecord using from_buildparams. - let record = BuildRecord::from_buildparams(¶ms); - - // Check thqat all fields are set correctly. - assert_eq!(record.ksize, 31, "Expected ksize to be 31."); - assert_eq!(record.moltype, "DNA", "Expected moltype to be 'DNA'."); - assert_eq!(record.scaled, 1000, "Expected scaled to be 1000."); - assert_eq!(record.num, 0, "Expected num to be 0."); - assert!(record.with_abundance, "Expected with_abundance to be true."); - assert_eq!( - record.hashed_params, - params.calculate_hash(), - "Expected the hashed_params to match the calculated hash." - ); + fn test_filter_by_manifest_with_matching_records() { + // Create a BuildCollection with some records and signatures. 
+ + let rec1 = BuildRecord::default_dna(); + let rec2 = BuildRecord { + ksize: 21, + moltype: "DNA".to_string(), + scaled: 1000, + ..Default::default() + }; + let rec3 = BuildRecord { + ksize: 31, + moltype: "DNA".to_string(), + scaled: 1000, + with_abundance: true, + ..Default::default() + }; + + let bmanifest = BuildManifest { + records: vec![rec1.clone(), rec2.clone(), rec3.clone()], + }; + // let mut dna_build_collection = BuildCollection::from_manifest(&bmanifest, "DNA"); + let mut dna_build_collection = BuildCollection::from_manifest(&bmanifest); + + // Create a BuildManifest with records to filter out. + let filter_manifest = BuildManifest { + records: vec![rec1], + }; + + // Apply the filter. + dna_build_collection.filter_by_manifest(&filter_manifest); + + // check that the default DNA sig remains + assert_eq!(dna_build_collection.manifest.size(), 2); + + let remaining_records = &dna_build_collection.manifest.records; + + assert!(remaining_records.contains(&rec2)); + assert!(remaining_records.contains(&rec3)); } #[test] - fn test_from_buildparams_protein() { - let mut params = BuildParams::default(); - params.ksize = 10; - params.scaled = 200; - params.moltype = HashFunctions::Murmur64Protein; - - let record = BuildRecord::from_buildparams(¶ms); - - // Check that all fields are set correctly. - assert_eq!(record.ksize, 10, "Expected ksize to be 10."); - assert_eq!(record.moltype, "protein", "Expected moltype to be protein."); - assert_eq!(record.scaled, 200, "Expected scaled to be 200."); - assert_eq!( - record.hashed_params, - params.calculate_hash(), - "Expected the hashed_params to match the calculated hash." - ); + fn test_add_template_sig_from_record() { + // Create a BuildCollection. + let mut build_collection = BuildCollection::new(); + + // Create a DNA BuildRecord. 
+ let dna_record = BuildRecord { + ksize: 31, + moltype: "DNA".to_string(), + scaled: 1000, + with_abundance: true, + ..Default::default() + }; + + // Add the DNA record to the collection with a matching moltype. + // build_collection.add_template_sig_from_record(&dna_record, "DNA"); + build_collection.add_template_sig_from_record(&dna_record); + + // Verify that the record was added. + assert_eq!(build_collection.manifest.records.len(), 1); + assert_eq!(build_collection.sigs.len(), 1); + + let added_record = &build_collection.manifest.records[0]; + assert_eq!(added_record.moltype, "DNA"); + assert_eq!(added_record.ksize, 31); + assert_eq!(added_record.with_abundance, true); + + // Create a protein BuildRecord. + let protein_record = BuildRecord { + ksize: 10, + moltype: "protein".to_string(), + scaled: 200, + with_abundance: false, + ..Default::default() + }; + + // Add the protein record to the collection with a matching moltype. + // build_collection.add_template_sig_from_record(&protein_record, "protein"); + build_collection.add_template_sig_from_record(&protein_record); + + // Verify that the protein record was added and ksize adjusted. + assert_eq!(build_collection.manifest.records.len(), 2); + assert_eq!(build_collection.sigs.len(), 2); + + let added_protein_record = &build_collection.manifest.records[1]; + assert_eq!(added_protein_record.moltype, "protein"); + assert_eq!(added_protein_record.ksize, 10); + assert_eq!(added_protein_record.with_abundance, false); + + // Create a BuildRecord with a non-matching moltype. + let non_matching_record = BuildRecord { + ksize: 10, + moltype: "dayhoff".to_string(), + scaled: 200, + with_abundance: true, + ..Default::default() + }; + + // Attempt to add the non-matching record with "DNA" as input moltype. + // this is because we currently don't allow translation + // build_collection.add_template_sig_from_record(&non_matching_record, "DNA"); + + // Verify that the non-matching record was not added. 
+ // assert_eq!(build_collection.manifest.records.len(), 2); + // assert_eq!(build_collection.sigs.len(), 2); + + // Add the same non-matching record with a matching input moltype. + build_collection.add_template_sig_from_record(&non_matching_record); + + // Verify that the record was added. + assert_eq!(build_collection.manifest.records.len(), 3); + assert_eq!(build_collection.sigs.len(), 3); + + let added_dayhoff_record = &build_collection.manifest.records[2]; + assert_eq!(added_dayhoff_record.moltype, "dayhoff"); + assert_eq!(added_dayhoff_record.ksize, 10); + assert_eq!(added_dayhoff_record.with_abundance, true); } #[test] - fn test_from_buildparams_dayhoff() { - let mut params = BuildParams::default(); - params.ksize = 10; - params.moltype = HashFunctions::Murmur64Dayhoff; + fn test_filter_empty() { + // Create a parameter string that generates BuildRecords with different `sequence_added` values. + let params_str = "k=31,abund,dna_k=21,protein_k=10,abund"; - let record = BuildRecord::from_buildparams(¶ms); + // Use `from_param_str` to build a `BuildCollection`. + let mut build_collection = BuildCollection::from_param_str(params_str) + .expect("Failed to build BuildCollection from params_str"); - assert_eq!(record.ksize, 10, "Expected ksize to be 10."); - assert_eq!(record.moltype, "dayhoff", "Expected moltype to be dayhoff."); - // didn't change default scaled here, so should still be 1000 - assert_eq!(record.scaled, 1000, "Expected scaled to be 1000."); + // Manually set `sequence_added` for each record to simulate different conditions. + build_collection.manifest.records[0].sequence_added = true; // Keep this record. + build_collection.manifest.records[1].sequence_added = false; // This record should be removed. + build_collection.manifest.records[2].sequence_added = true; // Keep this record. + + // Check initial sizes before filtering. 
+ assert_eq!( + build_collection.manifest.records.len(), + 3, + "Expected 3 records before filtering, but found {}", + build_collection.manifest.records.len() + ); assert_eq!( - record.hashed_params, - params.calculate_hash(), - "Expected the hashed_params to match the calculated hash." + build_collection.sigs.len(), + 3, + "Expected 3 signatures before filtering, but found {}", + build_collection.sigs.len() ); - } - #[test] - fn test_from_buildparams_hp() { - let mut params = BuildParams::default(); - params.ksize = 10; - params.moltype = HashFunctions::Murmur64Hp; + // Apply the `filter_empty` method. + build_collection.filter_empty(); - let record = BuildRecord::from_buildparams(&params); + // After filtering, only the records with `sequence_added == true` should remain. + assert_eq!( + build_collection.manifest.records.len(), + 2, + "Expected 2 records after filtering, but found {}", + build_collection.manifest.records.len() + ); - assert_eq!(record.ksize, 10, "Expected ksize to be 10."); - assert_eq!(record.moltype, "hp", "Expected moltype to be hp."); + // Check that the signatures also match the remaining records. assert_eq!( - record.hashed_params, - params.calculate_hash(), - "Expected the hashed_params to match the calculated hash." + build_collection.sigs.len(), + 2, + "Expected 2 signatures after filtering, but found {}", + build_collection.sigs.len() + ); + + // Verify that the remaining records have `sequence_added == true`.
+ assert!( + build_collection + .manifest + .records + .iter() + .all(|rec| rec.sequence_added), + "All remaining records should have `sequence_added == true`" ); } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 91b0ae5..041f8be 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -2,11 +2,11 @@ use anyhow::{anyhow, Result}; use reqwest::Url; use sourmash::collection::Collection; use std::collections::HashMap; -use std::collections::HashSet; +// use std::collections::HashSet; use std::fmt; pub mod buildutils; -use crate::utils::buildutils::BuildParams; +use crate::utils::buildutils::{BuildManifest, BuildRecord}; #[derive(Clone)] pub enum InputMolType { @@ -289,9 +289,8 @@ impl MultiCollection { self.collections.is_empty() } - pub fn buildparams_hashmap(&self) -> HashMap<String, HashSet<u64>> { - let mut name_params_map = HashMap::new(); - + pub fn build_recordsmap(&self) -> HashMap<String, BuildManifest> { + let mut records_map = HashMap::new(); // Iterate over all collections in MultiCollection for collection in &self.collections { // Iterate over all records in the current collection @@ -299,79 +298,18 @@ impl MultiCollection { // Get the record's name or fasta filename let record_name = record.name().clone(); - // Calculate the hash of the Params for the current record - let params_hash = BuildParams::from_record(record).calculate_hash(); + // Create template buildrecord from this record + let build_record = BuildRecord::from_record(record); // If the name is already in the HashMap, extend the existing HashSet - // Otherwise, create a new HashSet and insert the hashed Params - name_params_map + // Otherwise, create a new BuildManifest and insert the BuildRecord + records_map .entry(record_name) - .or_insert_with(HashSet::new) // Create a new HashSet if the key doesn't exist - .insert(params_hash); // Insert the hashed Params into the HashSet + .or_insert_with(BuildManifest::default) // Create a new HashSet if the key doesn't exist + .add_record(build_record); // add buildrecord to
buildmanifest } } - name_params_map - } -} - -#[cfg(test)] -mod tests { - use super::*; - use camino::Utf8PathBuf; - #[test] - fn test_buildparams_hashmap() { - // read in zipfiles to build a MultiCollection - // load signature + build record - let mut filename = Utf8PathBuf::from(env!("CARGO_MANIFEST_DIR")); - filename.push("tests/test-data/GCA_000961135.2.sig.zip"); - let path = filename.clone(); - - let mut collections = Vec::new(); - let coll: Collection = Collection::from_zipfile(&path).unwrap(); - collections.push(coll); - let mc = MultiCollection::new(collections); - - // Call build_params_hashmap - let name_params_map = mc.buildparams_hashmap(); - - // Check that the HashMap contains the correct names - assert_eq!( - name_params_map.len(), - 1, - "There should be 1 unique names in the map" - ); - - let mut hashed_params = Vec::new(); - for (name, params_set) in name_params_map.iter() { - eprintln!("Name: {}", name); - for param_hash in params_set { - eprintln!(" Param Hash: {}", param_hash); - hashed_params.push(param_hash); - } - } - let expected_params1 = BuildParams { - ksize: 31, - track_abundance: true, - ..Default::default() - }; - - let expected_params2 = BuildParams { - ksize: 21, - track_abundance: true, - ..Default::default() - }; - - let expected_hash1 = expected_params1.calculate_hash(); - let expected_hash2 = expected_params2.calculate_hash(); - - assert!( - hashed_params.contains(&&expected_hash1), - "Expected hash1 should be in the hashed_params" - ); - assert!( - hashed_params.contains(&&expected_hash2), - "Expected hash2 should be in the hashed_params" - ); + records_map } } diff --git a/tests/test_gbsketch.py b/tests/test_gbsketch.py index d72ed04..9a189f5 100644 --- a/tests/test_gbsketch.py +++ b/tests/test_gbsketch.py @@ -717,7 +717,7 @@ def test_gbsketch_simple_batch_restart(runtmp, capfd): assert len(sigs) == 2 for sig in sigs: assert sig.name == ss2.name - assert ss2.md5sum() in [ss2.md5sum(), ss3.md5sum()] + assert sig.md5sum() in 
[ss2.md5sum(), ss3.md5sum()] # # these were created with gbsketch expected_siginfo = { @@ -835,7 +835,7 @@ def test_gbsketch_bad_param_str(runtmp, capfd): captured = capfd.readouterr() print(captured) - assert "Failed to parse params string: Conflicting moltype settings in param string: 'DNA' and 'protein'" in captured.err + assert "Failed to parse params string: Conflicting moltype settings in param string: 'dna' and 'protein'" in captured.err def test_gbsketch_overwrite(runtmp, capfd):