diff --git a/README.md b/README.md index 1f0a783..7b24325 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,11 @@ conda activate directsketch pip install sourmash_plugin_directsketch ``` +## Usage Considerations + +If you're building large databases (over 20k files), we highly recommend you use batched zipfiles (v0.4+) to facilitate restart. If you encounter unexpected failures and are using a single zipfile output (default), `gbsketch`/`urlsketch` will have to re-download and re-sketch all files. If you instead set a batch size using `--batch-size`, e.g. 10000, then `gbsketch`/`urlsketch` can load any batched zips that finished writing, and avoid re-generating those signatures. For `gbsketch`, the batch size represents the number of accessions included in each zip, with all signatures associated with an accession grouped within a single `zip`. For `urlsketch`, the batch size represents the number of total signatures included in each zip. Note that batches will use the `--output` file to build batched filenames, so if you provided `output.zip`, your batches will be `output.1.zip`, `output.2.zip`, etc. + + ## Running the commands ## `gbsketch` @@ -76,7 +81,7 @@ For reference: To test `gbsketch`, you can download a csv file and run: ``` curl -JLO https://raw.githubusercontent.com/sourmash-bio/sourmash_plugin_directsketch/main/tests/test-data/acc.csv -sourmash scripts gbsketch acc.csv -o test-gbsketch.zip -f out_fastas -k --failed test.failed.csv -p dna,k=21,k=31,scaled=1000,abund -p protein,k=10,scaled=100,abund -r 1 +sourmash scripts gbsketch acc.csv -o test-gbsketch.zip -f out_fastas -k --failed test.failed.csv --checksum-fail test.checksum-failed.csv -p dna,k=21,k=31,scaled=1000,abund -p protein,k=10,scaled=100,abund -r 1 ``` To check that the `zip` was created properly, you can run: ``` @@ -102,7 +107,9 @@ summary of sketches: Full Usage: ``` -usage: gbsketch [-h] [-q] [-d] [-o OUTPUT] [-f FASTAS] [-k] [--download-only] [--failed FAILED] [-p PARAM_STRING] [-c CORES] [-r RETRY_TIMES] [-g | -m] input_csv +usage: gbsketch [-h] [-q] [-d] [-o OUTPUT] [-f FASTAS] [--batch-size BATCH_SIZE] [-k] [--download-only] --failed FAILED --checksum-fail CHECKSUM_FAIL [-p PARAM_STRING] [-c CORES] + [-r RETRY_TIMES] [-g | -m] + input_csv download and sketch GenBank assembly datasets @@ -117,9 +124,14 @@ options: output zip file for the signatures -f FASTAS, --fastas FASTAS Write fastas here + --batch-size BATCH_SIZE + Write smaller zipfiles, each containing sigs associated with this number of accessions. This allows gbsketch to recover after unexpected failures, rather than needing to + restart sketching from scratch. Default: write all sigs to single zipfile. -k, --keep-fasta write FASTA files in addition to sketching. Default: do not write FASTA files --download-only just download genomes; do not sketch --failed FAILED csv of failed accessions and download links (should be mostly protein). 
+ --checksum-fail CHECKSUM_FAIL + csv of accessions where the md5sum check failed or the md5sum file was improperly formatted or could not be downloaded -p PARAM_STRING, --param-string PARAM_STRING parameter string for sketching (default: k=31,scaled=1000) -c CORES, --cores CORES @@ -158,7 +170,9 @@ sourmash scripts urlsketch tests/test-data/acc-url.csv -o test-urlsketch.zip -f Full Usage: ``` -usage: urlsketch [-h] [-q] [-d] [-o OUTPUT] [-f FASTAS] [-k] [--download-only] [--failed FAILED] [-p PARAM_STRING] [-c CORES] [-r RETRY_TIMES] input_csv +usage: urlsketch [-h] [-q] [-d] [-o OUTPUT] [--batch-size BATCH_SIZE] [-f FASTAS] [-k] [--download-only] --failed FAILED [--checksum-fail CHECKSUM_FAIL] [-p PARAM_STRING] [-c CORES] + [-r RETRY_TIMES] + input_csv download and sketch GenBank assembly datasets @@ -171,12 +185,17 @@ options: -d, --debug provide debugging output -o OUTPUT, --output OUTPUT output zip file for the signatures + --batch-size BATCH_SIZE + Write smaller zipfiles, each containing sigs associated with this number of accessions. This allows urlsketch to recover after unexpected failures, rather than needing to + restart sketching from scratch. Default: write all sigs to single zipfile. -f FASTAS, --fastas FASTAS Write fastas here -k, --keep-fasta, --keep-fastq write FASTA/Q files in addition to sketching. Default: do not write FASTA files --download-only just download genomes; do not sketch - --failed FAILED csv of failed accessions and download links (should be mostly protein). + --failed FAILED csv of failed accessions and download links. + --checksum-fail CHECKSUM_FAIL + csv of accessions where the md5sum check failed. If not provided, md5sum failures will be written to the download failures file (no additional md5sum information). -p PARAM_STRING, --param-string PARAM_STRING parameter string for sketching (default: k=31,scaled=1000) -c CORES, --cores CORES diff --git a/src/directsketch.rs b/src/directsketch.rs index 0ac0449..13b958b 100644 --- a/src/directsketch.rs +++ b/src/directsketch.rs @@ -3,9 +3,11 @@ use async_zip::base::write::ZipFileWriter; use camino::Utf8PathBuf as PathBuf; use regex::Regex; use reqwest::Client; -use std::collections::HashMap; +use sourmash::collection::Collection; +use std::cmp::max; +use std::collections::{HashMap, HashSet}; use std::fs::{self, create_dir_all}; -use std::path::Path; +use std::panic; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::fs::File; @@ -18,6 +20,7 @@ use pyo3::prelude::*; use crate::utils::{ load_accession_info, load_gbassembly_info, parse_params_str, AccessionData, BuildCollection, BuildManifest, GBAssemblyData, GenBankFileType, InputMolType, MultiBuildCollection, + MultiCollection, }; use reqwest::Url; @@ -483,13 +486,91 @@ async fn dl_sketch_url( Ok((built_sigs, download_failures, checksum_failures)) } +// Load existing batch files into MultiCollection, skipping corrupt files +async fn load_existing_zip_batches(outpath: &PathBuf) -> Result<(MultiCollection, usize)> { + // Remove the .zip extension to get the base name + let outpath_base = outpath.with_extension(""); + + // Regex to match the exact zip filename and its batches (e.g., "outpath.zip", "outpath.1.zip", "outpath.2.zip", etc.) 
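+    // e.g. with `--output output.zip`, the escaped base name is "output", so the pattern is
+    // `^output(?:\.(\d+))?\.zip$`, matching "output.zip", "output.1.zip", "output.2.zip", ...
+    // (but not, say, "output.old.zip"; the batch suffix must be digits)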
+    let zip_file_pattern = Regex::new(&format!(
+        r"^{}(?:\.(\d+))?\.zip$",
+        regex::escape(outpath_base.file_name().unwrap())
+    ))
+    .unwrap();
+
+    // Initialize a vector to store valid collections
+    let mut collections = Vec::new();
+    let mut highest_batch = 0; // Track the highest batch number
+
+    // Read the directory containing the outpath
+    let dir = outpath_base
+        .parent()
+        .ok_or_else(|| anyhow::anyhow!("Could not get parent directory"))?;
+    let mut dir_entries = tokio::fs::read_dir(dir).await?;
+
+    // Scan through all files in the directory
+    while let Some(entry) = dir_entries.next_entry().await? {
+        let entry_path: PathBuf = entry.path().try_into()?;
+
+        if let Some(file_name) = entry_path.file_name() {
+            // Check if the file matches the base zip file or any batched zip file (outpath.zip, outpath.1.zip, etc.)
+            if let Some(captures) = zip_file_pattern.captures(file_name) {
+                // Wrap the `from_zipfile` call in `catch_unwind` to prevent panic propagation
+                let result = panic::catch_unwind(|| Collection::from_zipfile(&entry_path));
+                match result {
+                    Ok(Ok(collection)) => {
+                        // Successfully loaded the collection, push to `collections`
+                        collections.push(collection);
+
+                        // Extract the batch number (if it exists) and update the highest_batch
+                        if let Some(batch_str) = captures.get(1) {
+                            if let Ok(batch_num) = batch_str.as_str().parse::<usize>() {
+                                highest_batch = max(highest_batch, batch_num);
+                            }
+                        }
+                    }
+                    Ok(Err(e)) => {
+                        // Handle the case where `from_zipfile` returned an error
+                        eprintln!(
+                            "Warning: Failed to load zip file '{}'. Error: {:?}",
+                            entry_path, e
+                        );
+                        continue; // Skip the file and continue
+                    }
+                    Err(_) => {
+                        // The code inside `from_zipfile` panicked
+                        eprintln!("Warning: Invalid zip file '{}'; skipping.", entry_path);
+                        continue; // Skip the file and continue
+                    }
+                }
+            }
+        }
+    }
+    // Return the loaded MultiCollection and the max batch index, even if no collections were found
+    Ok((MultiCollection::new(collections), highest_batch))
+}
+
+/// create zip file depending on batch size and index.
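+/// e.g. with `--output out.zip` (illustrative name): batch_size == 0 writes a single `out.zip`,
+/// while a non-zero batch_size writes `out.1.zip`, `out.2.zip`, ... as each batch fills.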
async fn create_or_get_zip_file( outpath: &PathBuf, + batch_size: usize, + batch_index: usize, ) -> Result>, anyhow::Error> { - let file = File::create(&outpath) + let batch_outpath = if batch_size == 0 { + // If batch size is zero, use provided outpath (contains .zip extension) + outpath.clone() + } else { + // Otherwise, modify outpath to include the batch index + let outpath_base = outpath.with_extension(""); // remove .zip extension + outpath_base.with_file_name(format!( + "{}.{}.zip", + outpath_base.file_stem().unwrap(), + batch_index + )) + }; + let file = File::create(&batch_outpath) .await - .with_context(|| format!("Failed to create file: {:?}", outpath))?; + .with_context(|| format!("Failed to create file: {:?}", batch_outpath))?; Ok(ZipFileWriter::with_tokio(file)) } @@ -497,6 +578,8 @@ async fn create_or_get_zip_file( pub fn zipwriter_handle( mut recv_sigs: tokio::sync::mpsc::Receiver, output_sigs: Option, + batch_size: usize, // Tunable batch size + mut batch_index: usize, // starting batch index error_sender: tokio::sync::mpsc::Sender, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { @@ -504,57 +587,87 @@ pub fn zipwriter_handle( let mut zip_manifest = BuildManifest::new(); let mut wrote_sigs = false; let mut file_count = 0; // Count of files in the current batch + let mut zip_writer = None; if let Some(outpath) = output_sigs { let outpath: PathBuf = outpath.into(); - // Create the initial zip file - let mut zip_writer = match create_or_get_zip_file(&outpath).await { - Ok(writer) => writer, - Err(e) => { - let _ = error_sender.send(e).await; - return; - } - }; - while let Some(mut multibuildcoll) = recv_sigs.recv().await { - // write all sigs from sigcoll. Note that this method updates each record's internal location - for sigcoll in &mut multibuildcoll.collections { - match sigcoll - .async_write_sigs_to_zip(&mut zip_writer, &mut md5sum_occurrences) - .await - { - Ok(_) => { - file_count += sigcoll.size(); - wrote_sigs = true; - } - Err(e) => { - let error = e.context("Error processing signature"); - if error_sender.send(error).await.is_err() { + if zip_writer.is_none() { + // create zip file if needed + zip_writer = + match create_or_get_zip_file(&outpath, batch_size, batch_index).await { + Ok(writer) => Some(writer), + Err(e) => { + let _ = error_sender.send(e).await; return; } + }; + } + + if let Some(zip_writer) = zip_writer.as_mut() { + // write all sigs from sigcoll. 
Note that this method updates each record's internal location + for sigcoll in &mut multibuildcoll.collections { + match sigcoll + .async_write_sigs_to_zip(zip_writer, &mut md5sum_occurrences) + .await + { + Ok(_) => { + file_count += sigcoll.size(); + wrote_sigs = true; + } + Err(e) => { + let error = e.context("Error processing signature"); + if error_sender.send(error).await.is_err() { + return; + } + } } + // Add all records from sigcoll manifest + zip_manifest.extend_from_manifest(&sigcoll.manifest); } - // add all records from sigcoll manifest - zip_manifest.extend_from_manifest(&sigcoll.manifest); - file_count += sigcoll.size(); + } + + // if batch size is non-zero and is reached, close the current zip + if batch_size > 0 && file_count >= batch_size { + eprintln!("writing batch {}", batch_index); + if let Some(mut zip_writer) = zip_writer.take() { + if let Err(e) = zip_manifest + .async_write_manifest_to_zip(&mut zip_writer) + .await + { + let _ = error_sender.send(e).await; + } + if let Err(e) = zip_writer.close().await { + let error = anyhow::Error::new(e).context("Failed to close ZIP file"); + let _ = error_sender.send(error).await; + return; + } + } + // Start a new batch + batch_index += 1; + file_count = 0; + zip_manifest.clear(); + zip_writer = None; // reset zip_writer so a new zip will be created when needed } } if file_count > 0 { - // Write the final manifest - if let Err(e) = zip_manifest - .async_write_manifest_to_zip(&mut zip_writer) - .await - { - let _ = error_sender.send(e).await; - } + // write the final manifest + if let Some(mut zip_writer) = zip_writer.take() { + if let Err(e) = zip_manifest + .async_write_manifest_to_zip(&mut zip_writer) + .await + { + let _ = error_sender.send(e).await; + } - // Close the zip file for the final batch - if let Err(e) = zip_writer.close().await { - let error = anyhow::Error::new(e).context("Failed to close ZIP file"); - let _ = error_sender.send(error).await; - return; + // close final zip file + if let Err(e) = zip_writer.close().await { + let error = anyhow::Error::new(e).context("Failed to close ZIP file"); + let _ = error_sender.send(error).await; + return; + } } } if !wrote_sigs { @@ -572,7 +685,7 @@ pub fn zipwriter_handle( pub fn failures_handle( failed_csv: String, mut recv_failed: tokio::sync::mpsc::Receiver, - error_sender: tokio::sync::mpsc::Sender, // Additional parameter for error channel + error_sender: tokio::sync::mpsc::Sender, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { match File::create(&failed_csv).await { @@ -724,17 +837,39 @@ pub async fn gbsketch( genomes_only: bool, proteomes_only: bool, download_only: bool, + batch_size: u32, output_sigs: Option, ) -> Result<(), anyhow::Error> { - // if sig output provided but doesn't end in zip, bail + let batch_size = batch_size as usize; + let mut batch_index = 1; + let mut name_params_map: HashMap> = HashMap::new(); + let mut filter = false; if let Some(ref output_sigs) = output_sigs { - if Path::new(&output_sigs) - .extension() - .map_or(true, |ext| ext != "zip") - { + // Create outpath from output_sigs + let outpath = PathBuf::from(output_sigs); + + // Check if the extension is "zip" + if outpath.extension().map_or(true, |ext| ext != "zip") { bail!("Output must be a zip file."); } + // find and read any existing sigs + let (existing_sigs, max_existing_batch_index) = load_existing_zip_batches(&outpath).await?; + // Check if there are any existing batches to process + if !existing_sigs.is_empty() { + name_params_map = 
existing_sigs.buildparams_hashmap(); + + batch_index = max_existing_batch_index + 1; + eprintln!( + "Found {} existing valid zip batch(es). Starting new sig writing at batch {}", + max_existing_batch_index, batch_index + ); + filter = true; + } else { + // No existing batches, skipping signature filtering + eprintln!("No valid existing signature batches found; building all signatures."); + } } + // set up fasta download path let download_path = PathBuf::from(fasta_location); if !download_path.exists() { @@ -751,7 +886,13 @@ pub async fn gbsketch( // Set up collector/writing tasks let mut handles = Vec::new(); - let sig_handle = zipwriter_handle(recv_sigs, output_sigs, error_sender.clone()); + let sig_handle = zipwriter_handle( + recv_sigs, + output_sigs, + batch_size, + batch_index, + error_sender.clone(), + ); let failures_handle = failures_handle(failed_csv, recv_failed, error_sender.clone()); let checksum_failures_handle = checksum_failures_handle( @@ -784,10 +925,9 @@ pub async fn gbsketch( bail!("Failed to parse params string: {}", e); } }; - // let dna_sig_templates = build_siginfo(¶ms_vec, "DNA"); - let dna_template_collection = BuildCollection::from_params(¶ms_vec, "DNA"); + let dna_template_collection = BuildCollection::from_buildparams(¶ms_vec, "DNA"); // prot will build protein, dayhoff, hp - let prot_template_collection = BuildCollection::from_params(¶ms_vec, "protein"); + let prot_template_collection = BuildCollection::from_buildparams(¶ms_vec, "protein"); let mut genomes_only = genomes_only; let mut proteomes_only = proteomes_only; @@ -821,6 +961,20 @@ pub async fn gbsketch( for (i, accinfo) in accession_info.into_iter().enumerate() { py.check_signals()?; // If interrupted, return an Err automatically + + let mut dna_sigs = dna_template_collection.clone(); + let mut prot_sigs = prot_template_collection.clone(); + + // filter template sigs based on existing sigs + if filter { + if let Some(existing_paramset) = name_params_map.get(&accinfo.name) { + // If the key exists, filter template sigs + dna_sigs.filter(existing_paramset); + prot_sigs.filter(existing_paramset); + } + } + + // clone remaining utilities let semaphore_clone = Arc::clone(&semaphore); let client_clone = Arc::clone(&client); let send_sigs = send_sigs.clone(); @@ -828,8 +982,6 @@ pub async fn gbsketch( let checksum_send_failed = send_failed_checksums.clone(); let download_path_clone = download_path.clone(); // Clone the path for each task let send_errors = error_sender.clone(); - let mut dna_sigs = dna_template_collection.clone(); - let mut prot_sigs = prot_template_collection.clone(); tokio::spawn(async move { let _permit = semaphore_clone.acquire().await; @@ -916,18 +1068,40 @@ pub async fn urlsketch( fasta_location: String, keep_fastas: bool, download_only: bool, + batch_size: u32, output_sigs: Option, failed_checksums_csv: Option, ) -> Result<(), anyhow::Error> { - // if sig output provided but doesn't end in zip, bail + let batch_size = batch_size as usize; + let mut batch_index = 1; + let mut name_params_map: HashMap> = HashMap::new(); + let mut filter = false; if let Some(ref output_sigs) = output_sigs { - if Path::new(&output_sigs) - .extension() - .map_or(true, |ext| ext != "zip") - { + // Create outpath from output_sigs + let outpath = PathBuf::from(output_sigs); + + // Check if the extension is "zip" + if outpath.extension().map_or(true, |ext| ext != "zip") { bail!("Output must be a zip file."); } + // find and read any existing sigs + let (existing_sigs, max_existing_batch_index) = 
load_existing_zip_batches(&outpath).await?; + // Check if there are any existing batches to process + if !existing_sigs.is_empty() { + name_params_map = existing_sigs.buildparams_hashmap(); + + batch_index = max_existing_batch_index + 1; + eprintln!( + "Found {} existing zip batches. Starting new sig writing at batch {}", + max_existing_batch_index, batch_index + ); + filter = true; + } else { + // No existing batches, skipping signature filtering + eprintln!("No existing signature batches found; building all signatures."); + } } + // set up fasta download path let download_path = PathBuf::from(fasta_location); if !download_path.exists() { @@ -945,7 +1119,13 @@ pub async fn urlsketch( // Set up collector/writing tasks let mut handles = Vec::new(); - let sig_handle = zipwriter_handle(recv_sigs, output_sigs, error_sender.clone()); + let sig_handle = zipwriter_handle( + recv_sigs, + output_sigs, + batch_size, + batch_index, + error_sender.clone(), + ); let failures_handle = failures_handle(failed_csv, recv_failed, error_sender.clone()); @@ -984,8 +1164,8 @@ pub async fn urlsketch( bail!("Failed to parse params string: {}", e); } }; - let dna_template_collection = BuildCollection::from_params(¶ms_vec, "DNA"); - let prot_template_collection = BuildCollection::from_params(¶ms_vec, "protein"); + let dna_template_collection = BuildCollection::from_buildparams(¶ms_vec, "DNA"); + let prot_template_collection = BuildCollection::from_buildparams(¶ms_vec, "protein"); let mut genomes_only = false; let mut proteomes_only = false; @@ -1019,6 +1199,18 @@ pub async fn urlsketch( for (i, accinfo) in accession_info.into_iter().enumerate() { py.check_signals()?; // If interrupted, return an Err automatically + let mut dna_sigs = dna_template_collection.clone(); + let mut prot_sigs = prot_template_collection.clone(); + + // filter template sigs based on existing sigs + if filter { + if let Some(existing_paramset) = name_params_map.get(&accinfo.name) { + // If the key exists, filter template sigs + dna_sigs.filter(existing_paramset); + prot_sigs.filter(existing_paramset); + } + } + let semaphore_clone = Arc::clone(&semaphore); let client_clone = Arc::clone(&client); let send_sigs = send_sigs.clone(); @@ -1027,9 +1219,6 @@ pub async fn urlsketch( let download_path_clone = download_path.clone(); // Clone the path for each task let send_errors = error_sender.clone(); - let mut dna_sigs = dna_template_collection.clone(); - let mut prot_sigs = prot_template_collection.clone(); - tokio::spawn(async move { let _permit = semaphore_clone.acquire().await; // progress report when the permit is available and processing begins diff --git a/src/lib.rs b/src/lib.rs index 387b3fc..648ae69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,7 @@ fn set_tokio_thread_pool(num_threads: usize) -> PyResult { #[pyfunction] #[allow(clippy::too_many_arguments)] -#[pyo3(signature = (input_csv, param_str, failed_csv, failed_checksums, retry_times, fasta_location, keep_fastas, genomes_only, proteomes_only, download_only, output_sigs=None))] +#[pyo3(signature = (input_csv, param_str, failed_csv, failed_checksums, retry_times, fasta_location, keep_fastas, genomes_only, proteomes_only, download_only, batch_size, output_sigs=None))] fn do_gbsketch( py: Python, input_csv: String, @@ -62,6 +62,7 @@ fn do_gbsketch( genomes_only: bool, proteomes_only: bool, download_only: bool, + batch_size: u32, output_sigs: Option, ) -> anyhow::Result { match directsketch::gbsketch( @@ -76,6 +77,7 @@ fn do_gbsketch( genomes_only, proteomes_only, download_only, 
+ batch_size, output_sigs, ) { Ok(_) => Ok(0), @@ -88,7 +90,7 @@ fn do_gbsketch( #[pyfunction] #[allow(clippy::too_many_arguments)] -#[pyo3(signature = (input_csv, param_str, failed_csv, retry_times, fasta_location, keep_fastas, download_only, output_sigs=None, failed_checksums=None))] +#[pyo3(signature = (input_csv, param_str, failed_csv, retry_times, fasta_location, keep_fastas, download_only, batch_size, output_sigs=None, failed_checksums=None))] fn do_urlsketch( py: Python, input_csv: String, @@ -98,6 +100,7 @@ fn do_urlsketch( fasta_location: String, keep_fastas: bool, download_only: bool, + batch_size: u32, output_sigs: Option, failed_checksums: Option, ) -> anyhow::Result { @@ -110,6 +113,7 @@ fn do_urlsketch( fasta_location, keep_fastas, download_only, + batch_size, output_sigs, failed_checksums, ) { diff --git a/src/python/sourmash_plugin_directsketch/__init__.py b/src/python/sourmash_plugin_directsketch/__init__.py index 4254d6b..da88548 100644 --- a/src/python/sourmash_plugin_directsketch/__init__.py +++ b/src/python/sourmash_plugin_directsketch/__init__.py @@ -4,6 +4,7 @@ from sourmash.logging import notify from sourmash.plugins import CommandLinePlugin import importlib.metadata +import argparse from . import sourmash_plugin_directsketch @@ -32,6 +33,12 @@ def set_thread_pool(user_cores): actual_tokio_cores = sourmash_plugin_directsketch.set_tokio_thread_pool(num_threads) return actual_tokio_cores +def non_negative_int(value): + ivalue = int(value) + if ivalue < 0: + raise argparse.ArgumentTypeError(f"Batch size cannot be negative (input value: {value})") + return ivalue + class Download_and_Sketch_Assemblies(CommandLinePlugin): command = 'gbsketch' description = 'download and sketch GenBank assembly datasets' @@ -40,9 +47,13 @@ def __init__(self, p): super().__init__(p) p.add_argument('input_csv', help="a txt file or csv file containing accessions in the first column") p.add_argument('-o', '--output', default=None, - help='output zip file for the signatures') + help="output zip file for the signatures. Must end with '.zip'") p.add_argument('-f', '--fastas', help='Write fastas here', default = '.') + p.add_argument('--batch-size', type=non_negative_int, default = 0, + help='Write smaller zipfiles, each containing sigs associated with this number of accessions. \ + This allows gbsketch to recover after unexpected failures, rather than needing to \ + restart sketching from scratch. Default: write all sigs to single zipfile.') p.add_argument('-k', '--keep-fasta', action='store_true', help="write FASTA files in addition to sketching. Default: do not write FASTA files") p.add_argument('--download-only', help='just download genomes; do not sketch', action='store_true') @@ -92,6 +103,7 @@ def main(self, args): args.genomes_only, args.proteomes_only, args.download_only, + args.batch_size, args.output) if status == 0: @@ -113,6 +125,10 @@ def __init__(self, p): p.add_argument('input_csv', help="a txt file or csv file containing accessions in the first column") p.add_argument('-o', '--output', default=None, help='output zip file for the signatures') + p.add_argument('--batch-size', type=non_negative_int, default = 0, + help='Write smaller zipfiles, each containing sigs associated with this number of accessions. \ + This allows urlsketch to recover after unexpected failures, rather than needing to \ + restart sketching from scratch. 
Default: write all sigs to single zipfile.') p.add_argument('-f', '--fastas', help='Write fastas here', default = '.') p.add_argument('-k', '--keep-fasta', '--keep-fastq', action='store_true', @@ -159,6 +175,7 @@ def main(self, args): args.fastas, args.keep_fasta, args.download_only, + args.batch_size, args.output, args.checksum_fail) diff --git a/src/utils.rs b/src/utils.rs index a4ff045..3c84f1a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -9,6 +9,7 @@ use needletail::{parse_fastx_file, parse_fastx_reader}; use reqwest::Url; use serde::Serialize; use sourmash::cmd::ComputeParameters; +use sourmash::collection::Collection; use sourmash::manifest::Record; use sourmash::signature::Signature; use std::collections::hash_map::DefaultHasher; @@ -287,7 +288,7 @@ pub fn load_accession_info( } #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Params { +pub struct BuildParams { pub ksize: u32, pub track_abundance: bool, pub num: u32, @@ -299,7 +300,7 @@ pub struct Params { pub is_dna: bool, } -impl Hash for Params { +impl Hash for BuildParams { fn hash(&self, state: &mut H) { self.ksize.hash(state); self.track_abundance.hash(state); @@ -313,15 +314,21 @@ impl Hash for Params { } } -impl Params { +impl BuildParams { + pub fn calculate_hash(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + self.hash(&mut hasher); // Use the Hash trait implementation + hasher.finish() // Return the final u64 hash value + } + pub fn from_record(record: &Record) -> Self { let moltype = record.moltype(); // Get the moltype (HashFunctions enum) - Params { + BuildParams { ksize: record.ksize(), track_abundance: record.with_abundance(), - num: record.num().clone(), - scaled: record.scaled().clone(), + num: *record.num(), + scaled: *record.scaled(), seed: 42, is_protein: moltype.protein(), is_dayhoff: moltype.dayhoff(), @@ -385,7 +392,7 @@ where } impl BuildRecord { - pub fn from_params(param: &Params, input_moltype: &str) -> Self { + pub fn from_buildparams(param: &BuildParams, input_moltype: &str) -> Self { // Calculate the hash of Params let mut hasher = DefaultHasher::new(); param.hash(&mut hasher); @@ -484,7 +491,25 @@ impl BuildManifest { } } -#[derive(Debug, Clone)] +impl<'a> IntoIterator for &'a BuildManifest { + type Item = &'a BuildRecord; + type IntoIter = std::slice::Iter<'a, BuildRecord>; + + fn into_iter(self) -> Self::IntoIter { + self.records.iter() + } +} + +impl<'a> IntoIterator for &'a mut BuildManifest { + type Item = &'a mut BuildRecord; + type IntoIter = std::slice::IterMut<'a, BuildRecord>; + + fn into_iter(self) -> Self::IntoIter { + self.records.iter_mut() + } +} + +#[derive(Debug, Default, Clone)] pub struct BuildCollection { pub manifest: BuildManifest, pub sigs: Vec, @@ -506,7 +531,7 @@ impl BuildCollection { self.manifest.size() } - pub fn from_params(params: &[Params], input_moltype: &str) -> Self { + pub fn from_buildparams(params: &[BuildParams], input_moltype: &str) -> Self { let mut collection = BuildCollection::new(); for param in params.iter().cloned() { @@ -516,7 +541,7 @@ impl BuildCollection { collection } - pub fn add_template_sig(&mut self, param: Params, input_moltype: &str) { + pub fn add_template_sig(&mut self, param: BuildParams, input_moltype: &str) { // Check the input_moltype against Params to decide if this should be added match input_moltype { "dna" | "DNA" if !param.is_dna => return, // Skip if it's not the correct moltype @@ -546,7 +571,7 @@ impl BuildCollection { let sig = Signature::from_params(&cp); // Create the BuildRecord using from_param - let 
template_record = BuildRecord::from_params(¶m, input_moltype); + let template_record = BuildRecord::from_buildparams(¶m, input_moltype); // Add the record and signature to the collection self.manifest.records.push(template_record); @@ -667,12 +692,12 @@ impl BuildCollection { for (record, sig) in self.iter_mut() { // update signature name, filename sig.set_name(name.as_str()); - sig.set_filename(&filename.as_str()); + sig.set_filename(filename.as_str()); // update record: set name, filename, md5sum, n_hashes record.set_name(Some(name.clone())); record.set_filename(Some(filename.clone())); - record.set_md5(Some(sig.md5sum().into())); + record.set_md5(Some(sig.md5sum())); record.set_md5short(Some(sig.md5sum()[0..8].into())); record.set_n_hashes(Some(sig.size())); @@ -768,8 +793,9 @@ impl MultiBuildCollection { } } -pub fn parse_params_str(params_strs: String) -> Result, String> { - let mut unique_params: std::collections::HashSet = std::collections::HashSet::new(); +pub fn parse_params_str(params_strs: String) -> Result, String> { + let mut unique_params: std::collections::HashSet = + std::collections::HashSet::new(); // split params_strs by _ and iterate over each param for p_str in params_strs.split('_').collect::>().iter() { @@ -827,7 +853,7 @@ pub fn parse_params_str(params_strs: String) -> Result, String> { } for &k in &ksizes { - let param = Params { + let param = BuildParams { ksize: k, track_abundance, num, @@ -844,3 +870,319 @@ pub fn parse_params_str(params_strs: String) -> Result, String> { Ok(unique_params.into_iter().collect()) } + +// this should be replaced with branchwater's MultiCollection when it's ready +#[derive(Clone)] +pub struct MultiCollection { + collections: Vec, +} + +impl MultiCollection { + pub fn new(collections: Vec) -> Self { + Self { collections } + } + + pub fn is_empty(&self) -> bool { + self.collections.is_empty() + } + + pub fn buildparams_hashmap(&self) -> HashMap> { + let mut name_params_map = HashMap::new(); + + // Iterate over all collections in MultiCollection + for collection in &self.collections { + // Iterate over all records in the current collection + for (_, record) in collection.iter() { + // Get the record's name or fasta filename + let record_name = record.name().clone(); + + // Calculate the hash of the Params for the current record + let params_hash = BuildParams::from_record(record).calculate_hash(); + + // If the name is already in the HashMap, extend the existing HashSet + // Otherwise, create a new HashSet and insert the hashed Params + name_params_map + .entry(record_name) + .or_insert_with(HashSet::new) // Create a new HashSet if the key doesn't exist + .insert(params_hash); // Insert the hashed Params into the HashSet + } + } + + name_params_map + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_buildparams_consistent_hashing() { + let params1 = BuildParams { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let params2 = BuildParams { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let hash1 = params1.calculate_hash(); + let hash2 = params2.calculate_hash(); + let hash3 = params2.calculate_hash(); + + // Check that the hash for two identical Params is the same + assert_eq!(hash1, hash2, "Hashes for identical Params should be equal"); + + assert_eq!( + hash2, hash3, + "Hashes for the same 
Params should be consistent across multiple calls" + ); + } + + #[test] + fn test_buildparams_hashing_different() { + let params1 = BuildParams { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let params2 = BuildParams { + ksize: 21, // Changed ksize + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let hash1 = params1.calculate_hash(); + let hash2 = params2.calculate_hash(); + + // Check that the hash for different Params is different + assert_ne!( + hash1, hash2, + "Hashes for different Params should not be equal" + ); + } + + #[test] + fn test_buildparams_generated_from_record() { + // load signature + build record + let mut filename = Utf8PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/GCA_000175535.1.sig.gz"); + let path = filename.clone(); + + let file = std::fs::File::open(filename).unwrap(); + let mut reader = std::io::BufReader::new(file); + let sigs = Signature::load_signatures( + &mut reader, + Some(31), + Some("DNA".try_into().unwrap()), + None, + ) + .unwrap(); + + assert_eq!(sigs.len(), 1); + + let sig = sigs.get(0).unwrap(); + let record = Record::from_sig(sig, path.as_str()); + + // create the expected Params based on the Record data + let expected_params = BuildParams { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + // // Generate the Params from the Record using the from_record method + let generated_params = BuildParams::from_record(&record[0]); + + // // Assert that the generated Params match the expected Params + assert_eq!( + generated_params, expected_params, + "Generated Params did not match the expected Params" + ); + + // // Calculate the hash for the expected Params + let expected_hash = expected_params.calculate_hash(); + + // // Calculate the hash for the generated Params + let generated_hash = generated_params.calculate_hash(); + + // // Assert that the hash for the generated Params matches the expected Params hash + assert_eq!( + generated_hash, expected_hash, + "Hash of generated Params did not match the hash of expected Params" + ); + } + + #[test] + fn test_filter_removes_matching_buildparams() { + let params1 = BuildParams { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let params2 = BuildParams { + ksize: 21, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let params3 = BuildParams { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 2000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let params_list = [params1.clone(), params2.clone(), params3.clone()]; + let mut build_collection = BuildCollection::from_buildparams(¶ms_list, "DNA"); + + let mut params_set = HashSet::new(); + params_set.insert(params1.calculate_hash()); + params_set.insert(params3.calculate_hash()); + + // Call the filter method + build_collection.filter(¶ms_set); + + // Check that the records and signatures with matching params are removed + assert_eq!( + build_collection.manifest.records.len(), + 1, + "Only one record should remain after filtering" + ); + assert_eq!( + 
build_collection.sigs.len(), + 1, + "Only one signature should remain after filtering" + ); + + // Check that the remaining record is the one with hashed_params = 456 + let h2 = params2.calculate_hash(); + assert_eq!( + build_collection.manifest.records[0].hashed_params, h2, + "The remaining record should have hashed_params {}", + h2 + ); + } + + #[test] + fn test_buildparams_hashmap() { + // read in zipfiles to build a MultiCollection + // load signature + build record + let mut filename = Utf8PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/GCA_000961135.2.sig.zip"); + let path = filename.clone(); + + let mut collections = Vec::new(); + let coll = Collection::from_zipfile(&path).unwrap(); + collections.push(coll); + let mc = MultiCollection::new(collections); + + // Call build_params_hashmap + let name_params_map = mc.buildparams_hashmap(); + + // Check that the HashMap contains the correct names + assert_eq!( + name_params_map.len(), + 1, + "There should be 1 unique names in the map" + ); + + let mut hashed_params = Vec::new(); + for (name, params_set) in name_params_map.iter() { + eprintln!("Name: {}", name); + for param_hash in params_set { + eprintln!(" Param Hash: {}", param_hash); + hashed_params.push(param_hash); + } + } + + let expected_params1 = BuildParams { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let expected_params2 = BuildParams { + ksize: 21, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let expected_hash1 = expected_params1.calculate_hash(); + let expected_hash2 = expected_params2.calculate_hash(); + + assert!( + hashed_params.contains(&&expected_hash1), + "Expected hash1 should be in the hashed_params" + ); + assert!( + hashed_params.contains(&&expected_hash2), + "Expected hash2 should be in the hashed_params" + ); + } +} diff --git a/tests/test_gbsketch.py b/tests/test_gbsketch.py index 84d5e9e..55a037b 100644 --- a/tests/test_gbsketch.py +++ b/tests/test_gbsketch.py @@ -2,6 +2,7 @@ Tests for gbsketch """ import os +import csv import pytest import sourmash @@ -577,3 +578,243 @@ def test_gbsketch_protein_dayhoff_hp(runtmp): assert download_filename == "GCA_000175535.1_protein.faa.gz" assert url == "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/175/535/GCA_000175535.1_ASM17553v1/GCA_000175535.1_ASM17553v1_protein.faa.gz" + +def test_gbsketch_simple_batched_single(runtmp, capfd): + # make sure both sigs associated with same acc end up in same zip + acc_csv = get_test_data('acc.csv') + acc1 = runtmp.output('acc1.csv') + # open acc.csv with csv dictreader and keep accession= GCA_000961135.2 line + with open(acc_csv, 'r') as inF, open(acc1, 'w', newline='') as outF: + r = csv.DictReader(inF) + w = csv.DictWriter(outF, fieldnames=r.fieldnames) + w.writeheader() + for row in r: + if row['accession'] == "GCA_000961135.2": + w.writerow(row) + + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + + sig1 = get_test_data('GCA_000961135.2.sig.gz') + sig2 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=30, select_moltype='protein') + + runtmp.sourmash('scripts', 'gbsketch', acc1, '-o', output, + '--failed', failed, '-r', 
'1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000", '-p', "protein,k=10,scaled=200", + '--batch-size', '1') + + assert os.path.exists(out1) + assert not os.path.exists(output) # for now, orig output file should be empty. + captured = capfd.readouterr() + print(captured.err) + + expected_siginfo = { + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + (ss1.name, ss2.md5sum(), ss2.minhash.moltype), # ss1 name b/c of how it's written in acc.csv + } + + # Collect the actual signature information from all the output files + all_siginfo = set() + idx = sourmash.load_file_as_index(out1) + sigs = list(idx.signatures()) + for sig in sigs: + all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype)) + + # Assert that all expected signatures are found + assert all_siginfo == expected_siginfo + + +def test_gbsketch_simple_batched_multiple(runtmp, capfd): + acc_csv = get_test_data('acc.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + out2 = runtmp.output('simple.2.zip') + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + # why does this need ksize =30 and not ksize = 10!??? + ss3 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + runtmp.sourmash('scripts', 'gbsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000", '-p', "protein,k=10,scaled=200", + '--batch-size', '1') + + assert os.path.exists(out1) + assert os.path.exists(out2) + assert not os.path.exists(output) # for now, orig output file should be empty. + captured = capfd.readouterr() + print(captured.err) + + expected_siginfo = { + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + (ss2.name, ss2.md5sum(), ss2.minhash.moltype), + (ss2.name, ss3.md5sum(), ss3.minhash.moltype), # ss2 name b/c of how it's written in acc.csv + } + + # Collect the actual signature information from all the output files + all_siginfo = set() + for out_file in [out1, out2]: + idx = sourmash.load_file_as_index(out_file) + sigs = list(idx.signatures()) + for sig in sigs: + all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype)) + + # Assert that all expected signatures are found (ignoring order) + assert all_siginfo == expected_siginfo + + +def test_gbsketch_simple_batch_restart(runtmp, capfd): + acc_csv = get_test_data('acc.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + out2 = runtmp.output('simple.2.zip') + out3 = runtmp.output('simple.3.zip') + + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss3 = sourmash.load_one_signature(sig2, ksize=21) + # why does this need ksize =30 and not ksize = 10!??? 
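+    # (note: sourmash stores protein k-mer sizes multiplied by 3, i.e. the nucleotide-equivalent k,
+    # so a k=10 protein sketch is selected with ksize=30)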
+ ss4 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + # first, cat sig2 into an output file that will trick gbsketch into thinking it's a prior batch + runtmp.sourmash('sig', 'cat', sig2, '-o', out1) + assert os.path.exists(out1) + + runtmp.sourmash('scripts', 'gbsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000,abund", '-p', "protein,k=10,scaled=200", + '--batch-size', '1') + + assert os.path.exists(out1) + assert os.path.exists(out2) + assert os.path.exists(out3) + assert not os.path.exists(output) # for now, orig output file should be empty. + captured = capfd.readouterr() + print(captured.err) + + # # we created this one with sig cat + idx = sourmash.load_file_as_index(out1) + sigs = list(idx.signatures()) + assert len(sigs) == 2 + for sig in sigs: + assert sig.name == ss2.name + assert ss2.md5sum() in [ss2.md5sum(), ss3.md5sum()] + + # # these were created with gbsketch + expected_siginfo = { + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + (ss4.name, ss4.md5sum(), ss4.minhash.moltype), + } + + # Collect actual signature information from gbsketch zip batches + all_siginfo = set() + for out_file in [out2, out3]: + idx = sourmash.load_file_as_index(out_file) + sigs = list(idx.signatures()) + for sig in sigs: + all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype)) + + # Assert that all expected signatures are found (ignoring order) + assert all_siginfo == expected_siginfo + + +def test_gbsketch_negative_batch_size(runtmp): + # negative int provided for batch size + acc_csv = runtmp.output('acc.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + + with pytest.raises(utils.SourmashCommandFailed): + runtmp.sourmash('scripts', 'gbsketch', acc_csv, + '--failed', failed, '-r', '1', '--batch-size', '-2', + '--param-str', "dna,k=31,scaled=1000") + + assert "Batch size cannot be negative (input value: -2)" in runtmp.last_result.err + + +def test_gbsketch_simple_batch_restart_with_incomplete_zip(runtmp, capfd): + # test restart with complete + incomplete zipfile batches + acc_csv = get_test_data('acc.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + out2 = runtmp.output('simple.2.zip') + out3 = runtmp.output('simple.3.zip') + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss3 = sourmash.load_one_signature(sig2, ksize=21) + ss4 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + # first, cat sig2 into an output file that will trick gbsketch into thinking it's a prior batch + runtmp.sourmash('sig', 'cat', sig2, '-o', out1) + assert os.path.exists(out1) + + # Create an invalid zip file for out2 + with open(out2, 'wb') as f: + f.write(b"This is not a valid zip file!") + + assert os.path.exists(out2) + + runtmp.sourmash('scripts', 'gbsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000,abund", '-p', "protein,k=10,scaled=200", + '--batch-size', '1') + + assert os.path.exists(out1) + assert os.path.exists(out2) # Should be overwritten + assert os.path.exists(out3) + assert not os.path.exists(output) # for now, orig 
output file should be empty. + captured = capfd.readouterr() + print(captured.err) + assert f"Warning: Invalid zip file '{out2}'; skipping." in captured.err + + # we created this one with sig cat + idx = sourmash.load_file_as_index(out1) + sigs = list(idx.signatures()) + assert len(sigs) == 2 + for sig in sigs: + assert sig.name == ss2.name + assert ss2.md5sum() in [ss2.md5sum(), ss3.md5sum()] + + # these were created with gbsketch (out2 should have been overwritten) + expected_siginfo = { + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + (ss4.name, ss4.md5sum(), ss4.minhash.moltype), + } + + # Collect actual signature information from gbsketch zip batches + all_siginfo = set() + for out_file in [out2, out3]: + # this would fail if out2 were not overwritten with a valid sig zip + idx = sourmash.load_file_as_index(out_file) + sigs = list(idx.signatures()) + for sig in sigs: + all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype)) + + # Assert that all expected signatures are found (ignoring order) + assert all_siginfo == expected_siginfo diff --git a/tests/test_urlsketch.py b/tests/test_urlsketch.py index 51160d2..a1d354c 100644 --- a/tests/test_urlsketch.py +++ b/tests/test_urlsketch.py @@ -481,3 +481,176 @@ def test_urlsketch_md5sum_mismatch_no_checksum_file(runtmp, capfd): assert md5sum == "b1234567" assert download_filename == "GCA_000175535.1_genomic.urlsketch.fna.gz" assert url == "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/175/535/GCA_000175535.1_ASM17553v1/GCA_000175535.1_ASM17553v1_genomic.fna.gz" + + +def test_urlsketch_simple_batched(runtmp, capfd): + acc_csv = get_test_data('acc-url.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + out2 = runtmp.output('simple.2.zip') + out3 = runtmp.output('simple.3.zip') + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss3 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + runtmp.sourmash('scripts', 'urlsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000", '-p', "protein,k=10,scaled=200", + '--batch-size', '1') + + assert os.path.exists(out1) + assert os.path.exists(out2) + assert os.path.exists(out3) + assert not os.path.exists(output) # for now, orig output file should be empty. 
+ captured = capfd.readouterr() + print(captured.err) + + expected_siginfo = { + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + (ss2.name, ss2.md5sum(), ss2.minhash.moltype), + (ss3.name, ss3.md5sum(), ss3.minhash.moltype) + } + # Collect all signatures from the output zip files + all_sigs = [] + + for out_file in [out1, out2, out3]: + idx = sourmash.load_file_as_index(out_file) + sigs = list(idx.signatures()) + assert len(sigs) == 1 # We expect exactly 1 signature per batch + all_sigs.append(sigs[0]) + + loaded_signatures = {(sig.name, sig.md5sum(), sig.minhash.moltype) for sig in all_sigs} + assert loaded_signatures == expected_siginfo, f"Loaded sigs: {loaded_signatures}, expected: {expected_siginfo}" + + +def test_urlsketch_simple_batch_restart(runtmp, capfd): + acc_csv = get_test_data('acc-url.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + out2 = runtmp.output('simple.2.zip') + out3 = runtmp.output('simple.3.zip') + + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss3 = sourmash.load_one_signature(sig2, ksize=21) + ss4 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + # first, cat sig2 into an output file that will trick gbsketch into thinking it's a prior batch + runtmp.sourmash('sig', 'cat', sig2, '-o', out1) + assert os.path.exists(out1) + + runtmp.sourmash('scripts', 'urlsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000,abund", '-p', "protein,k=10,scaled=200", + '--batch-size', '1') + + assert os.path.exists(out1) + assert os.path.exists(out2) + assert os.path.exists(out3) + assert not os.path.exists(output) # for now, orig output file should be empty. 
+ captured = capfd.readouterr() + print(captured.err) + + expected_siginfo = { + (ss2.name, ss2.md5sum(), ss2.minhash.moltype), + (ss2.name, ss3.md5sum(), ss3.minhash.moltype), # ss2 name b/c thats how it is in acc-url.csv + (ss4.name, ss4.md5sum(), ss4.minhash.moltype), + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + } + + all_siginfo = set() + for out_file in [out1, out2, out3]: + idx = sourmash.load_file_as_index(out_file) + sigs = list(idx.signatures()) + for sig in sigs: + all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype)) + + # Verify that the loaded signatures match the expected signatures, order-independent + assert all_siginfo == expected_siginfo, f"Loaded sigs: {all_siginfo}, expected: {expected_siginfo}" + + +def test_urlsketch_negative_batch_size(runtmp): + # negative int provided for batch size + acc_csv = runtmp.output('acc1.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + + with pytest.raises(utils.SourmashCommandFailed): + runtmp.sourmash('scripts', 'urlsketch', acc_csv, + '--failed', failed, '-r', '1', '--batch-size', '-2', + '--param-str', "dna,k=31,scaled=1000") + + assert "Batch size cannot be negative (input value: -2)" in runtmp.last_result.err + + +def test_urlsketch_simple_batch_restart_with_incomplete_zip(runtmp, capfd): + # test restart with complete + incomplete zipfile batches + acc_csv = get_test_data('acc-url.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + out2 = runtmp.output('simple.2.zip') + out3 = runtmp.output('simple.3.zip') + + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss3 = sourmash.load_one_signature(sig2, ksize=21) + ss4 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + # first, cat sig2 into an output file that will trick gbsketch into thinking it's a prior batch + runtmp.sourmash('sig', 'cat', sig2, '-o', out1) + assert os.path.exists(out1) + + # Create an invalid zip file for out2 + with open(out2, 'wb') as f: + f.write(b"This is not a valid zip file!") + + runtmp.sourmash('scripts', 'urlsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--checksum-fail', ch_fail, + '--param-str', "dna,k=31,scaled=1000,abund", '-p', "protein,k=10,scaled=200", + '--batch-size', '1') + + assert os.path.exists(out1) + assert os.path.exists(out2) + assert os.path.exists(out3) + assert not os.path.exists(output) # for now, orig output file should be empty. + captured = capfd.readouterr() + print(captured.err) + assert f"Warning: Invalid zip file '{out2}'; skipping." 
in captured.err + + expected_siginfo = { + (ss2.name, ss2.md5sum(), ss2.minhash.moltype), + (ss2.name, ss3.md5sum(), ss3.minhash.moltype), # ss2 name b/c thats how it is in acc-url.csv + (ss4.name, ss4.md5sum(), ss4.minhash.moltype), + (ss1.name, ss1.md5sum(), ss1.minhash.moltype), + } + + all_siginfo = set() + for out_file in [out1, out2, out3]: + idx = sourmash.load_file_as_index(out_file) + sigs = list(idx.signatures()) + for sig in sigs: + all_siginfo.add((sig.name, sig.md5sum(), sig.minhash.moltype)) + + # Verify that the loaded signatures match the expected signatures, order-independent + assert all_siginfo == expected_siginfo, f"Loaded sigs: {all_siginfo}, expected: {expected_siginfo}"