From ed9ec2eb28685c2cfb414ba7947d01023a0e59f7 Mon Sep 17 00:00:00 2001 From: "N. Tessa Pierce-Ward" Date: Tue, 30 Apr 2024 14:37:01 -0700 Subject: [PATCH] add --genomes-only and --proteomes-only --- src/directsketch.rs | 43 ++++++++++----- src/lib.rs | 4 ++ .../sourmash_plugin_directsketch/__init__.py | 7 ++- tests/test_gbsketch.py | 54 +++++++++++++++++++ 4 files changed, 93 insertions(+), 15 deletions(-) diff --git a/src/directsketch.rs b/src/directsketch.rs index 1f92946..c994cad 100644 --- a/src/directsketch.rs +++ b/src/directsketch.rs @@ -192,6 +192,8 @@ async fn dl_sketch_accession( keep_fastas: bool, dna_sigs: Vec, prot_sigs: Vec, + genomes_only: bool, + proteomes_only: bool, ) -> Result<(Vec, Vec)> { let retry_count = retry.unwrap_or(3); // Default retry count let mut sigs = Vec::::new(); @@ -202,29 +204,38 @@ async fn dl_sketch_accession( Ok(result) => result, Err(_err) => { // Add accession to failed downloads with each moltype - let failed_download_dna = FailedDownload { - accession: accession.clone(), - url: "".to_string(), - moltype: "dna".to_string(), - }; - let failed_download_protein = FailedDownload { - accession: accession.clone(), - url: "".to_string(), - moltype: "protein".to_string(), - }; - failed.push(failed_download_dna); - failed.push(failed_download_protein); + if !proteomes_only { + let failed_download_dna = FailedDownload { + accession: accession.clone(), + url: "".to_string(), + moltype: "dna".to_string(), + }; + failed.push(failed_download_dna); + } + if !genomes_only { + let failed_download_protein = FailedDownload { + accession: accession.clone(), + url: "".to_string(), + moltype: "protein".to_string(), + }; + failed.push(failed_download_protein); + } + return Ok((sigs, failed)); } }; - // Combine all file types into a single vector - let file_types = vec![ + let mut file_types = vec![ GenBankFileType::Genomic, GenBankFileType::Protein, // GenBankFileType::AssemblyReport, // GenBankFileType::Checksum, // Including standalone files like checksums here ]; + if genomes_only { + file_types = vec![GenBankFileType::Genomic]; + } else if proteomes_only { + file_types = vec![GenBankFileType::Protein]; + } for file_type in &file_types { let url = file_type.url(&base_url, &full_name); @@ -332,6 +343,8 @@ pub async fn download_and_sketch( retry_times: u32, fasta_location: String, keep_fastas: bool, + genomes_only: bool, + proteomes_only: bool, ) -> Result<(), anyhow::Error> { let download_path = PathBuf::from(fasta_location); if !download_path.exists() { @@ -406,6 +419,8 @@ pub async fn download_and_sketch( keep_fastas, dna_sig_templates.clone(), prot_sig_templates.clone(), + genomes_only, + proteomes_only, ) .await; diff --git a/src/lib.rs b/src/lib.rs index 1e4316b..7318c6d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,6 +33,8 @@ fn do_gbsketch( retry_times: u32, fasta_location: String, keep_fastas: bool, + genomes_only: bool, + proteomes_only: bool, ) -> anyhow::Result { // let runtime = tokio::runtime::Runtime::new().unwrap(); @@ -48,6 +50,8 @@ fn do_gbsketch( retry_times, fasta_location, keep_fastas, + genomes_only, + proteomes_only, ) { Ok(_) => Ok(0), Err(e) => { diff --git a/src/python/sourmash_plugin_directsketch/__init__.py b/src/python/sourmash_plugin_directsketch/__init__.py index 5b72501..6ee00ba 100644 --- a/src/python/sourmash_plugin_directsketch/__init__.py +++ b/src/python/sourmash_plugin_directsketch/__init__.py @@ -51,6 +51,9 @@ def __init__(self, p): help='number of cores to use (default is all available)') p.add_argument('-r', '--retry-times', default=1, type=int, help='number of times to retry failed downloads') + group = p.add_mutually_exclusive_group() + group.add_argument('-g', '--genomes-only', action='store_true', help='just download and sketch genome (DNA) files') + group.add_argument('-m', '--proteomes-only', action='store_true', help='just download and sketch proteome (protein) files') def main(self, args): print_version() @@ -74,7 +77,9 @@ def main(self, args): args.output, args.retry_times, args.fastas, - args.keep_fastas) + args.keep_fastas, + args.genomes_only, + args.proteomes_only) if status == 0: notify(f"...gbsketch is done! Sigs in '{args.output}'. Fastas in '{args.fastas}'.") diff --git a/tests/test_gbsketch.py b/tests/test_gbsketch.py index 3554dd8..dcd624f 100644 --- a/tests/test_gbsketch.py +++ b/tests/test_gbsketch.py @@ -56,6 +56,60 @@ def test_gbsketch_simple(runtmp): else: assert sig.md5sum() == ss3.md5sum() +def test_gbsketch_genomes_only(runtmp): + acc_csv = get_test_data('acc.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + + sig1 = get_test_data('GCA_000175555.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + + runtmp.sourmash('scripts', 'gbsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--genomes-only', + '--param-str', "dna,k=31,scaled=1000", '-p', "protein,k=10,scaled=200") + + assert os.path.exists(output) + assert not runtmp.last_result.out # stdout should be empty + + idx = sourmash.load_file_as_index(output) + sigs = list(idx.signatures()) + + assert len(sigs) == 2 + for sig in sigs: + if 'GCA_000175555.1' in sig.name: + assert sig.name == ss1.name + assert sig.md5sum() == ss1.md5sum() + elif 'GCA_000961135.2' in sig.name: + assert sig.name == ss2.name + assert sig.md5sum() == ss2.md5sum() + + +def test_gbsketch_proteomes_only(runtmp): + acc_csv = get_test_data('acc.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + # why does this need ksize =30 and not ksize = 10!??? + ss3 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + runtmp.sourmash('scripts', 'gbsketch', acc_csv, '-o', output, + '--failed', failed, '-r', '1', '--proteomes-only', + '--param-str', "dna,k=31,scaled=1000", '-p', "protein,k=10,scaled=200") + + assert os.path.exists(output) + assert not runtmp.last_result.out # stdout should be empty + + idx = sourmash.load_file_as_index(output) + sigs = list(idx.signatures()) + + assert len(sigs) == 1 + for sig in sigs: + assert 'GCA_000961135.2' in sig.name + assert sig.md5sum() == ss3.md5sum() + def test_gbsketch_save_fastas(runtmp): acc_csv = get_test_data('acc.csv')