Skip to content

Commit

Permalink
use MultiBuildCollection instead of allowing extend of BuildCollection
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed Sep 30, 2024
1 parent 64fbc29 commit e1e9506
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 47 deletions.
64 changes: 33 additions & 31 deletions src/directsketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use sourmash::signature::Signature;

use crate::utils::{
load_accession_info, load_gbassembly_info, parse_params_str, AccessionData, BuildCollection,
BuildManifest, GBAssemblyData, GenBankFileType, InputMolType,
BuildManifest, GBAssemblyData, GenBankFileType, InputMolType, MultiBuildCollection,
};
use reqwest::Url;

Expand Down Expand Up @@ -218,9 +218,9 @@ async fn dl_sketch_assembly_accession(
genomes_only: bool,
proteomes_only: bool,
download_only: bool,
) -> Result<(BuildCollection, Vec<FailedDownload>)> {
) -> Result<(MultiBuildCollection, Vec<FailedDownload>)> {
let retry_count = retry.unwrap_or(3); // Default retry count
let mut sig_collection = BuildCollection::new();
let mut built_sigs = MultiBuildCollection::new();
let mut failed = Vec::<FailedDownload>::new();

let name = accinfo.name;
Expand Down Expand Up @@ -255,7 +255,7 @@ async fn dl_sketch_assembly_accession(
failed.push(failed_download_protein);
}

return Ok((sig_collection, failed));
return Ok((built_sigs, failed));
}
};
let md5sum_url = GenBankFileType::Checksum.url(&base_url, &full_name);
Expand Down Expand Up @@ -311,7 +311,7 @@ async fn dl_sketch_assembly_accession(
match file_type {
GenBankFileType::Genomic => {
dna_sigs.build_sigs_from_data(data, "dna", name.clone(), file_name.clone())?;
sig_collection.extend_by_drain(dna_sigs);
built_sigs.add_collection(dna_sigs);
}
GenBankFileType::Protein => {
prot_sigs.build_sigs_from_data(
Expand All @@ -320,14 +320,14 @@ async fn dl_sketch_assembly_accession(
name.clone(),
file_name.clone(),
)?;
sig_collection.extend_by_drain(prot_sigs);
built_sigs.add_collection(prot_sigs);
}
_ => {} // Do nothing for other file types
};
}
}

Ok((sig_collection, failed))
Ok((built_sigs, failed))
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -342,9 +342,9 @@ async fn dl_sketch_url(
_genomes_only: bool,
_proteomes_only: bool,
download_only: bool,
) -> Result<(BuildCollection, Vec<FailedDownload>)> {
) -> Result<(MultiBuildCollection, Vec<FailedDownload>)> {
let retry_count = retry.unwrap_or(3); // Default retry count
let mut sig_collection = BuildCollection::new();
let mut built_sigs = MultiBuildCollection::new();
let mut failed = Vec::<FailedDownload>::new();

let name = accinfo.name;
Expand Down Expand Up @@ -375,7 +375,7 @@ async fn dl_sketch_url(
name.clone(),
filename.clone(),
)?;
sig_collection.extend_by_drain(dna_sigs);
built_sigs.add_collection(dna_sigs);
}
InputMolType::Protein => {
prot_sigs.build_sigs_from_data(
Expand All @@ -384,7 +384,7 @@ async fn dl_sketch_url(
name.clone(),
filename.clone(),
)?;
sig_collection.extend_by_drain(prot_sigs);
built_sigs.add_collection(prot_sigs);
}
};
}
Expand All @@ -402,7 +402,7 @@ async fn dl_sketch_url(
}
}

Ok((sig_collection, failed))
Ok((built_sigs, failed))
}

/// create zip file depending on batch size and index.
Expand All @@ -417,7 +417,7 @@ async fn create_or_get_zip_file(
}

pub fn zipwriter_handle(
mut recv_sigs: tokio::sync::mpsc::Receiver<BuildCollection>,
mut recv_sigs: tokio::sync::mpsc::Receiver<MultiBuildCollection>,
output_sigs: Option<String>,
error_sender: tokio::sync::mpsc::Sender<anyhow::Error>,
) -> tokio::task::JoinHandle<()> {
Expand All @@ -439,26 +439,28 @@ pub fn zipwriter_handle(
}
};

while let Some(mut sigcoll) = recv_sigs.recv().await {
while let Some(mut multibuildcoll) = recv_sigs.recv().await {
// write all sigs from sigcoll. Note that this method updates each record's internal location
match sigcoll
.async_write_sigs_to_zip(&mut zip_writer, &mut md5sum_occurrences)
.await
{
Ok(_) => {
file_count += sigcoll.size();
wrote_sigs = true;
}
Err(e) => {
let error = e.context("Error processing signature");
if error_sender.send(error).await.is_err() {
return;
for sigcoll in &mut multibuildcoll.collections {
match sigcoll
.async_write_sigs_to_zip(&mut zip_writer, &mut md5sum_occurrences)
.await
{
Ok(_) => {
file_count += sigcoll.size();
wrote_sigs = true;
}
Err(e) => {
let error = e.context("Error processing signature");
if error_sender.send(error).await.is_err() {
return;
}
}
}
// add all records from sigcoll manifest
zip_manifest.extend_from_manifest(&sigcoll.manifest);
file_count += sigcoll.size();
}
// add all records from sigcoll manifest
zip_manifest.extend_from_manifest(&sigcoll.manifest);
file_count += sigcoll.size();
}

if file_count > 0 {
Expand Down Expand Up @@ -595,7 +597,7 @@ pub async fn gbsketch(
create_dir_all(&download_path)?;
}
// create channels. buffer size here is 4 b/c we can do 3 downloads simultaneously
let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::<BuildCollection>(4);
let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::<MultiBuildCollection>(4);
let (send_failed, recv_failed) = tokio::sync::mpsc::channel::<FailedDownload>(4);
// Error channel for handling task errors
let (error_sender, error_receiver) = tokio::sync::mpsc::channel::<anyhow::Error>(1);
Expand Down Expand Up @@ -772,7 +774,7 @@ pub async fn urlsketch(
}

// create channels. buffer size here is 4 b/c we can do 3 downloads simultaneously
let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::<BuildCollection>(4);
let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::<MultiBuildCollection>(4);
let (send_failed, recv_failed) = tokio::sync::mpsc::channel::<FailedDownload>(4);
// Error channel for handling task errors
let (error_sender, error_receiver) = tokio::sync::mpsc::channel::<anyhow::Error>(1);
Expand Down
39 changes: 23 additions & 16 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ impl BuildManifest {
}
}

#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct BuildCollection {
pub manifest: BuildManifest,
pub sigs: Vec<Signature>,
Expand Down Expand Up @@ -553,20 +553,6 @@ impl BuildCollection {
self.sigs.push(sig);
}

pub fn extend(&mut self, other: BuildCollection) {
// Extend the manifest and signatures from another BuildCollection
self.manifest.records.extend(other.manifest.records);
self.sigs.extend(other.sigs);
}

pub fn extend_by_drain(&mut self, other: &mut BuildCollection) {
// Extend the manifest and signatures by draining from another BuildCollection
self.manifest
.records
.extend(other.manifest.records.drain(..));
self.sigs.extend(other.sigs.drain(..));
}

pub fn filter(&mut self, params_set: &HashSet<u64>) {
let mut index = 0;
while index < self.manifest.records.len() {
Expand Down Expand Up @@ -690,7 +676,7 @@ impl BuildCollection {
record.set_md5short(Some(sig.md5sum()[0..8].into()));
record.set_n_hashes(Some(sig.size()));

// what to set this to?
// note, this needs to be set when writing sigs
// record.set_internal_location("")
}
}
Expand Down Expand Up @@ -761,6 +747,27 @@ impl<'a> IntoIterator for &'a mut BuildCollection {
}
}

/// A group of [`BuildCollection`]s belonging to one download unit
/// (e.g. the DNA and protein sketch collections for a single accession),
/// so they can be sent through a channel and written out together.
#[derive(Debug, Clone)]
pub struct MultiBuildCollection {
    /// Inner collections, in the order they were added.
    pub collections: Vec<BuildCollection>,
}

impl MultiBuildCollection {
pub fn new() -> Self {
MultiBuildCollection {
collections: Vec::new(),
}
}

pub fn is_empty(&self) -> bool {
self.collections.is_empty()
}

pub fn add_collection(&mut self, collection: &mut BuildCollection) {
self.collections.push(collection.clone())
}
}

pub fn parse_params_str(params_strs: String) -> Result<Vec<Params>, String> {
let mut unique_params: std::collections::HashSet<Params> = std::collections::HashSet::new();

Expand Down

0 comments on commit e1e9506

Please sign in to comment.