Skip to content

Commit

Permalink
rm zip batching logic
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed Sep 30, 2024
1 parent 7684e30 commit 3df61df
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 98 deletions.
93 changes: 11 additions & 82 deletions src/directsketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,33 +408,17 @@ async fn dl_sketch_url(
/// Create the output zip file at `outpath` and return a `ZipFileWriter` wrapping it.
async fn create_or_get_zip_file(
outpath: &PathBuf,
batch_size: usize,
batch_index: usize,
) -> Result<ZipFileWriter<Compat<File>>, anyhow::Error> {
let batch_outpath = if batch_size == 0 {
// If batch size is zero, use provided outpath (contains .zip extension)
outpath.clone()
} else {
// Otherwise, modify outpath to include the batch index
let outpath_base = outpath.with_extension(""); // remove .zip extension
outpath_base.with_file_name(format!(
"{}.{}.zip",
outpath_base.file_stem().unwrap(),
batch_index
))
};

let file = File::create(&batch_outpath)
let file = File::create(&outpath)
.await
.with_context(|| format!("Failed to create file: {:?}", batch_outpath))?;
.with_context(|| format!("Failed to create file: {:?}", outpath))?;

Ok(ZipFileWriter::with_tokio(file))
}

pub fn zipwriter_handle(
mut recv_sigs: tokio::sync::mpsc::Receiver<BuildCollection>,
output_sigs: Option<String>,
batch_size: usize, // Tunable batch size
error_sender: tokio::sync::mpsc::Sender<anyhow::Error>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
Expand All @@ -445,17 +429,15 @@ pub fn zipwriter_handle(

if let Some(outpath) = output_sigs {
let outpath: PathBuf = outpath.into();
let mut batch_index = 1; // index to name zip files

// Create the initial zip file
let mut zip_writer =
match create_or_get_zip_file(&outpath, batch_size, batch_index).await {
Ok(writer) => writer,
Err(e) => {
let _ = error_sender.send(e).await;
return;
}
};
let mut zip_writer = match create_or_get_zip_file(&outpath).await {
Ok(writer) => writer,
Err(e) => {
let _ = error_sender.send(e).await;
return;
}
};

while let Some(mut sigcoll) = recv_sigs.recv().await {
// write all sigs from sigcoll. Note that this method updates each record's internal location
Expand All @@ -477,35 +459,6 @@ pub fn zipwriter_handle(
// add all records from sigcoll manifest
zip_manifest.extend_from_manifest(&sigcoll.manifest);
file_count += sigcoll.size();

// If batch size is non-zero and is reached, close the current ZIP and start a new one
if batch_size > 0 && file_count >= batch_size {
if let Err(e) = zip_manifest
.async_write_manifest_to_zip(&mut zip_writer)
.await
{
let _ = error_sender.send(e).await;
}

if let Err(e) = zip_writer.close().await {
let error = anyhow::Error::new(e).context("Failed to close ZIP file");
let _ = error_sender.send(error).await;
return;
}

// start a new batch
batch_index += 1;
file_count = 0;
zip_manifest.clear();
zip_writer =
match create_or_get_zip_file(&outpath, batch_size, batch_index).await {
Ok(writer) => writer,
Err(e) => {
let _ = error_sender.send(e).await;
return;
}
};
}
}

if file_count > 0 {
Expand Down Expand Up @@ -536,19 +489,6 @@ pub fn zipwriter_handle(
})
}

/// to do: instead, read from zipfiles (manifests make this easier!)
pub async fn signature_from_path(path: &PathBuf) -> Result<Vec<Signature>, Error> {
let path = path.clone(); // Clone the path to move into the blocking thread

// Use `spawn_blocking` to call `Signature::from_path` (synchronous)
let sigs = tokio::task::spawn_blocking(move || {
Signature::from_path(&path).map_err(|e| anyhow!("Error reading signatures: {}", e))
})
.await??;

Ok(sigs)
}

pub fn failures_handle(
failed_csv: String,
mut recv_failed: tokio::sync::mpsc::Receiver<FailedDownload>,
Expand Down Expand Up @@ -639,10 +579,8 @@ pub async fn gbsketch(
proteomes_only: bool,
download_only: bool,
output_sigs: Option<String>,
batch_size: u32,
) -> Result<(), anyhow::Error> {
// if sig output provided but doesn't end in zip, bail
let batch_size = batch_size as usize;
if let Some(ref output_sigs) = output_sigs {
if Path::new(&output_sigs)
.extension()
Expand All @@ -662,15 +600,10 @@ pub async fn gbsketch(
// Error channel for handling task errors
let (error_sender, error_receiver) = tokio::sync::mpsc::channel::<anyhow::Error>(1);

// Initialize an optional Manifest to hold existing signatures
// let mut existing_sigs: Option<Manifest> = None;

// to do --> read from existing sig zips, build filename: params_set hashmap

// Set up collector/writing tasks
let mut handles = Vec::new();

let sig_handle = zipwriter_handle(recv_sigs, output_sigs, batch_size, error_sender.clone());
let sig_handle = zipwriter_handle(recv_sigs, output_sigs, error_sender.clone());

let failures_handle = failures_handle(failed_csv, recv_failed, error_sender.clone());
let critical_error_flag = Arc::new(AtomicBool::new(false));
Expand Down Expand Up @@ -742,8 +675,6 @@ pub async fn gbsketch(
let send_errors = error_sender.clone();
let mut dna_sigs = dna_template_collection.clone();
let mut prot_sigs = prot_template_collection.clone();
// clone existing sig manifest
// let e_siginfo = existing_sigs.clone();

tokio::spawn(async move {
let _permit = semaphore_clone.acquire().await;
Expand Down Expand Up @@ -823,10 +754,8 @@ pub async fn urlsketch(
fasta_location: String,
keep_fastas: bool,
download_only: bool,
batch_size: u32,
output_sigs: Option<String>,
) -> Result<(), anyhow::Error> {
let batch_size = batch_size as usize;
// if sig output provided but doesn't end in zip, bail
if let Some(ref output_sigs) = output_sigs {
if Path::new(&output_sigs)
Expand All @@ -851,7 +780,7 @@ pub async fn urlsketch(
// Set up collector/writing tasks
let mut handles = Vec::new();

let sig_handle = zipwriter_handle(recv_sigs, output_sigs, batch_size, error_sender.clone());
let sig_handle = zipwriter_handle(recv_sigs, output_sigs, error_sender.clone());

let failures_handle = failures_handle(failed_csv, recv_failed, error_sender.clone());
let critical_error_flag = Arc::new(AtomicBool::new(false));
Expand Down
8 changes: 2 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn set_tokio_thread_pool(num_threads: usize) -> PyResult<usize> {

#[pyfunction]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (input_csv, param_str, failed_csv, retry_times, fasta_location, keep_fastas, genomes_only, proteomes_only, download_only, batch_size, output_sigs=None))]
#[pyo3(signature = (input_csv, param_str, failed_csv, retry_times, fasta_location, keep_fastas, genomes_only, proteomes_only, download_only, output_sigs=None))]
fn do_gbsketch(
py: Python,
input_csv: String,
Expand All @@ -61,7 +61,6 @@ fn do_gbsketch(
genomes_only: bool,
proteomes_only: bool,
download_only: bool,
batch_size: u32,
output_sigs: Option<String>,
) -> anyhow::Result<u8> {
match directsketch::gbsketch(
Expand All @@ -76,7 +75,6 @@ fn do_gbsketch(
proteomes_only,
download_only,
output_sigs,
batch_size,
) {
Ok(_) => Ok(0),
Err(e) => {
Expand All @@ -88,7 +86,7 @@ fn do_gbsketch(

#[pyfunction]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (input_csv, param_str, failed_csv, retry_times, fasta_location, keep_fastas, download_only, batch_size, output_sigs=None))]
#[pyo3(signature = (input_csv, param_str, failed_csv, retry_times, fasta_location, keep_fastas, download_only, output_sigs=None))]
fn do_urlsketch(
py: Python,
input_csv: String,
Expand All @@ -98,7 +96,6 @@ fn do_urlsketch(
fasta_location: String,
keep_fastas: bool,
download_only: bool,
batch_size: u32,
output_sigs: Option<String>,
) -> anyhow::Result<u8> {
match directsketch::urlsketch(
Expand All @@ -110,7 +107,6 @@ fn do_urlsketch(
fasta_location,
keep_fastas,
download_only,
batch_size,
output_sigs,
) {
Ok(_) => Ok(0),
Expand Down
10 changes: 0 additions & 10 deletions src/python/sourmash_plugin_directsketch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ def __init__(self, p):
help='output zip file for the signatures')
p.add_argument('-f', '--fastas',
help='Write fastas here', default = '.')
p.add_argument('--batch-size', type=int, default = 0,
help='Write smaller zipfiles, each containing approximately this number of files. \
This allows gbsketch to recover after unexpected failures, rather than needing to \
restart sketching from scratch.')
p.add_argument('-k', '--keep-fasta', action='store_true',
help="write FASTA files in addition to sketching. Default: do not write FASTA files")
p.add_argument('--download-only', help='just download genomes; do not sketch', action='store_true')
Expand Down Expand Up @@ -94,7 +90,6 @@ def main(self, args):
args.genomes_only,
args.proteomes_only,
args.download_only,
args.batch_size,
args.output)

if status == 0:
Expand All @@ -116,10 +111,6 @@ def __init__(self, p):
p.add_argument('input_csv', help="a txt file or csv file containing accessions in the first column")
p.add_argument('-o', '--output', default=None,
help='output zip file for the signatures')
p.add_argument('--batch-size', type=int, default = 0,
help='Write smaller zipfiles, each containing approximately this number of files. \
This allows urlsketch to recover after unexpected failures, rather than needing to \
restart sketching from scratch.')
p.add_argument('-f', '--fastas',
help='Write fastas here', default = '.')
p.add_argument('-k', '--keep-fasta', '--keep-fastq', action='store_true',
Expand Down Expand Up @@ -164,7 +155,6 @@ def main(self, args):
args.fastas,
args.keep_fasta,
args.download_only,
args.batch_size,
args.output)

if status == 0:
Expand Down

0 comments on commit 3df61df

Please sign in to comment.