diff --git a/Cargo.lock b/Cargo.lock index 0c7b0fe..8ff714e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1815,7 +1815,7 @@ dependencies = [ [[package]] name = "sourmash_plugin_directsketch" -version = "0.2.1" +version = "0.2.2" dependencies = [ "anyhow", "async_zip", diff --git a/Cargo.toml b/Cargo.toml index 835f0ec..e061267 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "sourmash_plugin_directsketch" -version = "0.2.1" +version = "0.2.2" edition = "2021" [lib] diff --git a/src/directsketch.rs b/src/directsketch.rs index 743d316..9ec1f82 100644 --- a/src/directsketch.rs +++ b/src/directsketch.rs @@ -10,6 +10,7 @@ use std::collections::HashMap; use std::fs::{self, create_dir_all}; use std::io::Cursor; use std::path::Path; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::fs::File; use tokio::io::{AsyncWriteExt, BufWriter}; @@ -581,6 +582,21 @@ pub fn failures_handle( }) } +pub fn error_handler( + mut recv_errors: tokio::sync::mpsc::Receiver, + error_flag: Arc, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + while let Some(error) = recv_errors.recv().await { + eprintln!("Error: {}", error); + if error.to_string().contains("No signatures written") { + error_flag.store(true, Ordering::SeqCst); + break; + } + } + }) +} + #[tokio::main] #[allow(clippy::too_many_arguments)] pub async fn download_and_sketch( @@ -611,17 +627,20 @@ pub async fn download_and_sketch( // // create channels. buffer size can be changed - here it is 4 b/c we can do 3 downloads simultaneously // // to do: see whether increasing buffer size speeds things up - let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::>(4); - let (send_failed, recv_failed) = tokio::sync::mpsc::channel::(4); + let (send_sigs, recv_sigs) = tokio::sync::mpsc::channel::>(1000); + let (send_failed, recv_failed) = tokio::sync::mpsc::channel::(100); // // Error channel for handling task errors - let (error_sender, mut error_receiver) = tokio::sync::mpsc::channel::(1); + let (error_sender, error_receiver) = tokio::sync::mpsc::channel::(1); // // // Set up collector/writing tasks let mut handles = Vec::new(); let sig_handle = sigwriter_handle(recv_sigs, output_sigs, error_sender.clone()); let failures_handle = failures_handle(failed_csv, recv_failed, error_sender.clone()); + let critical_error_flag = Arc::new(AtomicBool::new(false)); + let error_handle = error_handler(error_receiver, critical_error_flag.clone()); handles.push(sig_handle); handles.push(failures_handle); + handles.push(error_handle); // // Worker tasks let semaphore = Arc::new(Semaphore::new(3)); // Limiting concurrent downloads @@ -713,16 +732,14 @@ pub async fn download_and_sketch( // Wait for all tasks to complete for handle in handles { if let Err(e) = handle.await { - eprintln!("A task encountered an error: {}", e); + eprintln!("Handle join error: {}.", e); } } - // // Handle errors received from the error channel - while let Some(error) = error_receiver.recv().await { - eprintln!("Error: {}", error); - // Check if the error message contains "No signatures written" - if error.to_string().contains("No signatures written") & !download_only { - bail!("{}.", error); - } + // since the only critical error is not having written any sigs + // check this here at end. Bail if we wrote expected sigs but wrote none. + if critical_error_flag.load(Ordering::SeqCst) & !download_only { + bail!("No signatures written, exiting."); } + Ok(()) } diff --git a/tests/test_gbsketch.py b/tests/test_gbsketch.py index 5b398e6..0e12224 100644 --- a/tests/test_gbsketch.py +++ b/tests/test_gbsketch.py @@ -292,4 +292,4 @@ def test_gbsketch_bad_acc_fail(runtmp, capfd): captured = capfd.readouterr() print(captured.out) print(captured.err) - assert "Error: No signatures written." in captured.err \ No newline at end of file + assert "Error: No signatures written, exiting." in captured.err \ No newline at end of file