From e1e9506dfc101c1dbe637d8b83b0f3e73f814a16 Mon Sep 17 00:00:00 2001
From: "N. Tessa Pierce-Ward"
Date: Mon, 30 Sep 2024 10:54:38 -0700
Subject: [PATCH] use MultiBuildCollection instead of allowing extend of
 BuildCollection

---
 src/directsketch.rs | 64 +++++++++++++++++++++++----------------------
 src/utils.rs        | 39 +++++++++++++++------------
 2 files changed, 56 insertions(+), 47 deletions(-)

diff --git a/src/directsketch.rs b/src/directsketch.rs
index a9c3e5b..bb4eb1a 100644
--- a/src/directsketch.rs
+++ b/src/directsketch.rs
@@ -19,7 +19,7 @@ use sourmash::signature::Signature;
 
 use crate::utils::{
     load_accession_info, load_gbassembly_info, parse_params_str, AccessionData, BuildCollection,
-    BuildManifest, GBAssemblyData, GenBankFileType, InputMolType,
+    BuildManifest, GBAssemblyData, GenBankFileType, InputMolType, MultiBuildCollection,
 };
 use reqwest::Url;
 
@@ -218,9 +218,9 @@ async fn dl_sketch_assembly_accession(
     genomes_only: bool,
     proteomes_only: bool,
     download_only: bool,
-) -> Result<(BuildCollection, Vec<FailedDownload>)> {
+) -> Result<(MultiBuildCollection, Vec<FailedDownload>)> {
     let retry_count = retry.unwrap_or(3); // Default retry count
-    let mut sig_collection = BuildCollection::new();
+    let mut built_sigs = MultiBuildCollection::new();
     let mut failed = Vec::<FailedDownload>::new();
 
     let name = accinfo.name;
@@ -255,7 +255,7 @@ async fn dl_sketch_assembly_accession(
                 failed.push(failed_download_protein);
             }
 
-            return Ok((sig_collection, failed));
+            return Ok((built_sigs, failed));
         }
     };
     let md5sum_url = GenBankFileType::Checksum.url(&base_url, &full_name);
@@ -311,7 +311,7 @@ async fn dl_sketch_assembly_accession(
            match file_type {
                GenBankFileType::Genomic => {
                    dna_sigs.build_sigs_from_data(data, "dna", name.clone(), file_name.clone())?;
-                    sig_collection.extend_by_drain(dna_sigs);
+                    built_sigs.add_collection(dna_sigs);
                }
                GenBankFileType::Protein => {
                    prot_sigs.build_sigs_from_data(
@@ -320,14 +320,14 @@ async fn dl_sketch_assembly_accession(
                        name.clone(),
                        file_name.clone(),
                    )?;
-                    sig_collection.extend_by_drain(prot_sigs);
+                    built_sigs.add_collection(prot_sigs);
                }
                _ => {} // Do nothing for other file types
            };
        }
    }
 
-    Ok((sig_collection, failed))
+    Ok((built_sigs, failed))
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -342,9 +342,9 @@ async fn dl_sketch_url(
     _genomes_only: bool,
     _proteomes_only: bool,
     download_only: bool,
-) -> Result<(BuildCollection, Vec<FailedDownload>)> {
+) -> Result<(MultiBuildCollection, Vec<FailedDownload>)> {
     let retry_count = retry.unwrap_or(3); // Default retry count
-    let mut sig_collection = BuildCollection::new();
+    let mut built_sigs = MultiBuildCollection::new();
     let mut failed = Vec::<FailedDownload>::new();
 
     let name = accinfo.name;
@@ -375,7 +375,7 @@ async fn dl_sketch_url(
                        name.clone(),
                        filename.clone(),
                    )?;
-                    sig_collection.extend_by_drain(dna_sigs);
+                    built_sigs.add_collection(dna_sigs);
                }
                InputMolType::Protein => {
                    prot_sigs.build_sigs_from_data(
@@ -384,7 +384,7 @@ async fn dl_sketch_url(
                        name.clone(),
                        filename.clone(),
                    )?;
-                    sig_collection.extend_by_drain(prot_sigs);
+                    built_sigs.add_collection(prot_sigs);
                }
            };
        }
@@ -402,7 +402,7 @@ async fn dl_sketch_url(
        }
    }
 
-    Ok((sig_collection, failed))
+    Ok((built_sigs, failed))
 }
 
 /// create zip file depending on batch size and index.
@@ -417,7 +417,7 @@ async fn create_or_get_zip_file(
 }
 
 pub fn zipwriter_handle(
-    mut recv_sigs: tokio::sync::mpsc::Receiver<BuildCollection>,
+    mut recv_sigs: tokio::sync::mpsc::Receiver<MultiBuildCollection>,
     output_sigs: Option<String>,
     error_sender: tokio::sync::mpsc::Sender<anyhow::Error>,
 ) -> tokio::task::JoinHandle<()> {
@@ -439,26 +439,28 @@ pub fn zipwriter_handle(
            }
        };
 
-        while let Some(mut sigcoll) = recv_sigs.recv().await {
+        while let Some(mut multibuildcoll) = recv_sigs.recv().await {
            // write all sigs from sigcoll. Note that this method updates each record's internal location
-            match sigcoll
-                .async_write_sigs_to_zip(&mut zip_writer, &mut md5sum_occurrences)
-                .await
-            {
-                Ok(_) => {
-                    file_count += sigcoll.size();
-                    wrote_sigs = true;
-                }
-                Err(e) => {
-                    let error = e.context("Error processing signature");
-                    if error_sender.send(error).await.is_err() {
-                        return;
+            for sigcoll in &mut multibuildcoll.collections {
+                match sigcoll
+                    .async_write_sigs_to_zip(&mut zip_writer, &mut md5sum_occurrences)
+                    .await
+                {
+                    Ok(_) => {
+                        file_count += sigcoll.size();
+                        wrote_sigs = true;
+                    }
+                    Err(e) => {
+                        let error = e.context("Error processing signature");
+                        if error_sender.send(error).await.is_err() {
+                            return;
+                        }
                    }
                }
+                // add all records from sigcoll manifest
+                zip_manifest.extend_from_manifest(&sigcoll.manifest);
+                file_count += sigcoll.size();
            }
-            // add all records from sigcoll manifest
-            zip_manifest.extend_from_manifest(&sigcoll.manifest);
-            file_count += sigcoll.size();
        }
 
        if file_count > 0 {
@@ -595,7 +597,7 @@ pub async fn gbsketch(
        create_dir_all(&download_path)?;
    }
    // create channels. buffer size here is 4 b/c we can do 3 downloads simultaneously
-    let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::<BuildCollection>(4);
+    let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::<MultiBuildCollection>(4);
    let (send_failed, recv_failed) = tokio::sync::mpsc::channel::<FailedDownload>(4);
    // Error channel for handling task errors
    let (error_sender, error_receiver) = tokio::sync::mpsc::channel::<anyhow::Error>(1);
@@ -772,7 +774,7 @@ pub async fn urlsketch(
    }
 
    // create channels. buffer size here is 4 b/c we can do 3 downloads simultaneously
-    let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::<BuildCollection>(4);
+    let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::<MultiBuildCollection>(4);
    let (send_failed, recv_failed) = tokio::sync::mpsc::channel::<FailedDownload>(4);
    // Error channel for handling task errors
    let (error_sender, error_receiver) = tokio::sync::mpsc::channel::<anyhow::Error>(1);
diff --git a/src/utils.rs b/src/utils.rs
index 5482054..a4ff045 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -484,7 +484,7 @@ impl BuildManifest {
    }
 }
 
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct BuildCollection {
    pub manifest: BuildManifest,
    pub sigs: Vec<Signature>,
@@ -553,20 +553,6 @@ impl BuildCollection {
        self.sigs.push(sig);
    }
 
-    pub fn extend(&mut self, other: BuildCollection) {
-        // Extend the manifest and signatures from another BuildCollection
-        self.manifest.records.extend(other.manifest.records);
-        self.sigs.extend(other.sigs);
-    }
-
-    pub fn extend_by_drain(&mut self, other: &mut BuildCollection) {
-        // Extend the manifest and signatures by draining from another BuildCollection
-        self.manifest
-            .records
-            .extend(other.manifest.records.drain(..));
-        self.sigs.extend(other.sigs.drain(..));
-    }
-
    pub fn filter(&mut self, params_set: &HashSet<u64>) {
        let mut index = 0;
        while index < self.manifest.records.len() {
@@ -690,7 +676,7 @@ impl BuildCollection {
            record.set_md5short(Some(sig.md5sum()[0..8].into()));
            record.set_n_hashes(Some(sig.size()));
 
-            // what to set this to?
+            // note, this needs to be set when writing sigs
            // record.set_internal_location("")
        }
    }
@@ -761,6 +747,27 @@ impl<'a> IntoIterator for &'a mut BuildCollection {
    }
 }
 
+#[derive(Debug, Clone)]
+pub struct MultiBuildCollection {
+    pub collections: Vec<BuildCollection>,
+}
+
+impl MultiBuildCollection {
+    pub fn new() -> Self {
+        MultiBuildCollection {
+            collections: Vec::new(),
+        }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.collections.is_empty()
+    }
+
+    pub fn add_collection(&mut self, collection: &mut BuildCollection) {
+        self.collections.push(collection.clone())
+    }
+}
+
 pub fn parse_params_str(params_strs: String) -> Result<Vec<Params>, String> {
    let mut unique_params: std::collections::HashSet<Params> = std::collections::HashSet::new();
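
For readers following the diff, below is a minimal, self-contained sketch of the producer/consumer pattern this patch sets up: per-moltype BuildCollections grouped into one MultiBuildCollection, sent as a single channel message per accession, then iterated on the writer side. The struct definitions are trimmed stand-ins for the plugin's real types, and the accession name is hypothetical; only tokio's mpsc API is used as-is, assuming tokio = { version = "1", features = ["full"] }.

// Stand-ins for the plugin's types, trimmed to what the pattern needs (NOT the
// real definitions from src/utils.rs).
#[derive(Debug, Clone)]
struct BuildRecord {
    name: String,
    moltype: String,
}

#[derive(Debug, Clone, Default)]
struct BuildCollection {
    records: Vec<BuildRecord>,
}

impl BuildCollection {
    fn size(&self) -> usize {
        self.records.len()
    }
}

#[derive(Debug, Clone, Default)]
struct MultiBuildCollection {
    collections: Vec<BuildCollection>,
}

impl MultiBuildCollection {
    // As in the patch: takes `&mut` (matching the old extend_by_drain call
    // sites) but pushes a clone, so each per-moltype collection stays intact.
    fn add_collection(&mut self, collection: &mut BuildCollection) {
        self.collections.push(collection.clone())
    }
}

#[tokio::main]
async fn main() {
    // One channel message per accession, as in gbsketch/urlsketch.
    let (send_sigs, mut recv_sigs) =
        tokio::sync::mpsc::channel::<MultiBuildCollection>(4);

    // Producer: DNA and protein sketches stay in separate BuildCollections,
    // grouped into one MultiBuildCollection (cf. dl_sketch_assembly_accession).
    tokio::spawn(async move {
        let mut dna_sigs = BuildCollection {
            records: vec![BuildRecord {
                name: "GCA_000000000.1".to_string(), // hypothetical accession
                moltype: "DNA".to_string(),
            }],
        };
        let mut prot_sigs = BuildCollection {
            records: vec![BuildRecord {
                name: "GCA_000000000.1".to_string(),
                moltype: "protein".to_string(),
            }],
        };
        let mut built_sigs = MultiBuildCollection::default();
        built_sigs.add_collection(&mut dna_sigs);
        built_sigs.add_collection(&mut prot_sigs);
        let _ = send_sigs.send(built_sigs).await;
    });

    // Consumer: iterate the inner collections, like the new loop in
    // zipwriter_handle, counting files per BuildCollection.
    let mut file_count = 0;
    while let Some(mut multibuildcoll) = recv_sigs.recv().await {
        for sigcoll in &mut multibuildcoll.collections {
            for rec in &sigcoll.records {
                println!("would write sig for {} ({})", rec.name, rec.moltype);
            }
            file_count += sigcoll.size();
        }
    }
    println!("file_count: {file_count}");
}

One observation on the design: since add_collection clones rather than drains, the &mut parameter is not strictly required; keeping it does leave the old extend_by_drain call sites unchanged, and keeping the collections separate is what lets zipwriter_handle extend the zip manifest once per BuildCollection.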