diff --git a/Cargo.lock b/Cargo.lock index 71111003..4d3fa303 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1596,7 +1596,7 @@ dependencies = [ [[package]] name = "sourmash_plugin_branchwater" -version = "0.9.8" +version = "0.9.9-dev" dependencies = [ "anyhow", "assert_cmd", diff --git a/Cargo.toml b/Cargo.toml index 0f86560e..85e3413e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sourmash_plugin_branchwater" -version = "0.9.8" +version = "0.9.9-dev" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/fastgather.rs b/src/fastgather.rs index 4feff7cf..025ffc16 100644 --- a/src/fastgather.rs +++ b/src/fastgather.rs @@ -14,14 +14,14 @@ pub fn fastgather( against_filepath: String, threshold_bp: usize, scaled: usize, - selection: &Selection, + selection: Selection, gather_output: Option, prefetch_output: Option, allow_failed_sigpaths: bool, ) -> Result<()> { let query_collection = load_collection( &query_filepath, - selection, + &selection, ReportType::Query, allow_failed_sigpaths, )?; @@ -40,7 +40,7 @@ pub fn fastgather( let query_md5 = query_sig.md5sum(); // clone here is necessary b/c we use full query_sig in consume_query_by_gather - let query_sig_ds = query_sig.select(selection)?; // downsample + let query_sig_ds = query_sig.select(&selection)?; // downsample let query_mh = match query_sig_ds.try_into() { Ok(query_mh) => query_mh, Err(_) => { @@ -50,7 +50,7 @@ pub fn fastgather( // load collection to match against. let against_collection = load_collection( &against_filepath, - selection, + &selection, ReportType::Against, allow_failed_sigpaths, )?; diff --git a/src/fastmultigather.rs b/src/fastmultigather.rs index 1fa37215..fa5295f8 100644 --- a/src/fastmultigather.rs +++ b/src/fastmultigather.rs @@ -29,8 +29,8 @@ pub fn fastmultigather( query_filepath: String, against_filepath: String, threshold_bp: usize, - scaled: usize, - selection: &Selection, + scaled: Option, + selection: Selection, allow_failed_sigpaths: bool, save_matches: bool, create_empty_results: bool, @@ -40,11 +40,26 @@ pub fn fastmultigather( // load query collection let query_collection = load_collection( &query_filepath, - selection, + &selection, ReportType::Query, allow_failed_sigpaths, )?; + let scaled = match scaled { + Some(s) => s, + None => { + let scaled = query_collection.max_scaled().expect("no records!?").clone() as usize; + eprintln!( + "Setting scaled={} based on max scaled in query collection", + scaled + ); + scaled + } + }; + + let mut against_selection = selection; + against_selection.set_scaled(scaled as u32); + let threshold_hashes: u64 = { let x = threshold_bp / scaled; if x > 0 { @@ -60,12 +75,12 @@ pub fn fastmultigather( // load against collection let against_collection = load_collection( &against_filepath, - selection, + &against_selection, ReportType::Against, allow_failed_sigpaths, )?; // load against sketches into memory, downsampling on the way - let against = against_collection.load_sketches(selection)?; + let against = against_collection.load_sketches(&against_selection)?; // Iterate over all queries => do prefetch and gather! let processed_queries = AtomicUsize::new(0); diff --git a/src/index.rs b/src/index.rs index c303b09b..102892fd 100644 --- a/src/index.rs +++ b/src/index.rs @@ -8,7 +8,7 @@ use sourmash::collection::{Collection, CollectionSet}; pub fn index>( siglist: String, - selection: &Selection, + selection: Selection, output: P, colors: bool, allow_failed_sigpaths: bool, @@ -18,7 +18,7 @@ pub fn index>( let multi = match load_collection( &siglist, - selection, + &selection, ReportType::General, allow_failed_sigpaths, ) { @@ -31,7 +31,7 @@ pub fn index>( let collection = match Collection::try_from(multi.clone()) { // conversion worked! Ok(c) => { - let cs: CollectionSet = c.select(selection)?.try_into()?; + let cs: CollectionSet = c.select(&selection)?.try_into()?; Ok(cs) } // conversion failed; can we/should we load it into memory? @@ -39,7 +39,7 @@ pub fn index>( if use_internal_storage { eprintln!("WARNING: loading all sketches into memory in order to index."); eprintln!("See 'index' documentation for details."); - let c: Collection = multi.load_all_sigs(selection)?; + let c: Collection = multi.load_all_sigs(&selection)?; let cs: CollectionSet = c.try_into()?; Ok(cs) } else { diff --git a/src/lib.rs b/src/lib.rs index d4bf33f5..68f869af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,7 +39,7 @@ fn do_manysearch( ignore_abundance: Option, ) -> anyhow::Result { let againstfile_path: PathBuf = siglist_path.clone().into(); - let selection = build_selection(ksize, scaled, &moltype); + let selection = build_selection(ksize, Some(scaled), &moltype); eprintln!("selection scaled: {:?}", selection.scaled()); let allow_failed_sigpaths = true; @@ -51,7 +51,7 @@ fn do_manysearch( match mastiff_manysearch::mastiff_manysearch( querylist_path, againstfile_path, - &selection, + selection, threshold, output_path, allow_failed_sigpaths, @@ -66,7 +66,7 @@ fn do_manysearch( match manysearch::manysearch( querylist_path, siglist_path, - &selection, + selection, threshold, output_path, allow_failed_sigpaths, @@ -94,7 +94,7 @@ fn do_fastgather( output_path_prefetch: Option, output_path_gather: Option, ) -> anyhow::Result { - let selection = build_selection(ksize, scaled, &moltype); + let selection = build_selection(ksize, Some(scaled), &moltype); let allow_failed_sigpaths = true; match fastgather::fastgather( @@ -102,7 +102,7 @@ fn do_fastgather( siglist_path, threshold_bp, scaled, - &selection, + selection, output_path_prefetch, output_path_gather, allow_failed_sigpaths, @@ -123,7 +123,7 @@ fn do_fastmultigather( siglist_path: String, threshold_bp: usize, ksize: u8, - scaled: usize, + scaled: Option, moltype: String, output_path: Option, save_matches: bool, @@ -138,7 +138,7 @@ fn do_fastmultigather( match mastiff_manygather::mastiff_manygather( query_filenames, againstfile_path, - &selection, + selection.clone(), threshold_bp, output_path, allow_failed_sigpaths, @@ -158,7 +158,7 @@ fn do_fastmultigather( siglist_path, threshold_bp, scaled, - &selection, + selection, allow_failed_sigpaths, save_matches, create_empty_results, @@ -199,11 +199,11 @@ fn do_index( colors: bool, use_internal_storage: bool, ) -> anyhow::Result { - let selection = build_selection(ksize, scaled, &moltype); + let selection = build_selection(ksize, Some(scaled), &moltype); let allow_failed_sigpaths = false; match index::index( siglist, - &selection, + selection, output, colors, allow_failed_sigpaths, @@ -237,7 +237,7 @@ fn do_multisearch( siglist_path: String, threshold: f64, ksize: u8, - scaled: usize, + scaled: Option, moltype: String, estimate_ani: bool, output_path: Option, @@ -251,7 +251,7 @@ fn do_multisearch( querylist_path, siglist_path, threshold, - &selection, + selection, allow_failed_sigpaths, estimate_ani, output_path, @@ -277,12 +277,12 @@ fn do_pairwise( write_all: bool, output_path: Option, ) -> anyhow::Result { - let selection = build_selection(ksize, scaled, &moltype); + let selection = build_selection(ksize, Some(scaled), &moltype); let allow_failed_sigpaths = true; match pairwise::pairwise( siglist_path, threshold, - &selection, + selection, allow_failed_sigpaths, estimate_ani, write_all, diff --git a/src/manysearch.rs b/src/manysearch.rs index 1b77cb16..4c539267 100644 --- a/src/manysearch.rs +++ b/src/manysearch.rs @@ -19,7 +19,7 @@ use sourmash::sketch::minhash::KmerMinHash; pub fn manysearch( query_filepath: String, against_filepath: String, - selection: &Selection, + selection: Selection, threshold: f64, output: Option, allow_failed_sigpaths: bool, @@ -28,18 +28,18 @@ pub fn manysearch( // Load query collection let query_collection = load_collection( &query_filepath, - selection, + &selection, ReportType::Query, allow_failed_sigpaths, )?; // load all query sketches into memory, downsampling on the way - let query_sketchlist = query_collection.load_sketches(selection)?; + let query_sketchlist = query_collection.load_sketches(&selection)?; // Against: Load collection, potentially off disk & not into memory. let against_collection = load_collection( &against_filepath, - selection, + &selection, ReportType::Against, allow_failed_sigpaths, )?; diff --git a/src/mastiff_manygather.rs b/src/mastiff_manygather.rs index 6839d250..4d15f696 100644 --- a/src/mastiff_manygather.rs +++ b/src/mastiff_manygather.rs @@ -15,7 +15,7 @@ use crate::utils::{ pub fn mastiff_manygather( queries_file: String, index: PathBuf, - selection: &Selection, + selection: Selection, threshold_bp: usize, output: Option, allow_failed_sigpaths: bool, @@ -29,7 +29,7 @@ pub fn mastiff_manygather( let query_collection = load_collection( &queries_file, - selection, + &selection, ReportType::Query, allow_failed_sigpaths, )?; diff --git a/src/mastiff_manysearch.rs b/src/mastiff_manysearch.rs index 158dded1..ba0d7559 100644 --- a/src/mastiff_manysearch.rs +++ b/src/mastiff_manysearch.rs @@ -18,7 +18,7 @@ use crate::utils::{ pub fn mastiff_manysearch( queries_path: String, index: PathBuf, - selection: &Selection, + selection: Selection, minimum_containment: f64, output: Option, allow_failed_sigpaths: bool, @@ -35,7 +35,7 @@ pub fn mastiff_manysearch( // Load query paths let query_collection = load_collection( &queries_path, - selection, + &selection, ReportType::Query, allow_failed_sigpaths, )?; diff --git a/src/multisearch.rs b/src/multisearch.rs index cf4bacb8..f668dca0 100644 --- a/src/multisearch.rs +++ b/src/multisearch.rs @@ -18,7 +18,7 @@ pub fn multisearch( query_filepath: String, against_filepath: String, threshold: f64, - selection: &Selection, + selection: Selection, allow_failed_sigpaths: bool, estimate_ani: bool, output: Option, @@ -26,22 +26,39 @@ pub fn multisearch( // Load all queries into memory at once. let query_collection = load_collection( &query_filepath, - selection, + &selection, ReportType::Query, allow_failed_sigpaths, )?; - let queries = query_collection.load_sketches(selection)?; + let scaled = match selection.scaled() { + Some(s) => s, + None => { + let scaled = query_collection.max_scaled().expect("no records!?").clone() as u32; + eprintln!( + "Setting scaled={} based on max scaled in query collection", + scaled + ); + scaled + } + }; + + let ksize = selection.ksize().unwrap() as f64; + + let mut new_selection = selection; + new_selection.set_scaled(scaled as u32); + + let queries = query_collection.load_sketches(&new_selection)?; // Load all against sketches into memory at once. let against_collection = load_collection( &against_filepath, - selection, + &new_selection, ReportType::Against, allow_failed_sigpaths, )?; - let against = against_collection.load_sketches(selection)?; + let against = against_collection.load_sketches(&new_selection)?; // set up a multi-producer, single-consumer channel. let (send, recv) = @@ -57,7 +74,6 @@ pub fn multisearch( // let processed_cmp = AtomicUsize::new(0); - let ksize = selection.ksize().unwrap() as f64; let send = against .par_iter() @@ -70,7 +86,11 @@ pub fn multisearch( eprintln!("Processed {} comparisons", i); } - let overlap = query.minhash.count_common(&against.minhash, false).unwrap() as f64; + let overlap = query + .minhash + .count_common(&against.minhash, false) + .expect("cannot compare query and against!?") + as f64; // use downsampled sizes let query_size = query.minhash.size() as f64; let target_size = against.minhash.size() as f64; diff --git a/src/pairwise.rs b/src/pairwise.rs index 914c44f3..f3e447f1 100644 --- a/src/pairwise.rs +++ b/src/pairwise.rs @@ -16,7 +16,7 @@ use sourmash::signature::SigsTrait; pub fn pairwise( siglist: String, threshold: f64, - selection: &Selection, + selection: Selection, allow_failed_sigpaths: bool, estimate_ani: bool, write_all: bool, @@ -25,7 +25,7 @@ pub fn pairwise( // Load all sigs into memory at once. let collection = load_collection( &siglist, - selection, + &selection, ReportType::General, allow_failed_sigpaths, )?; @@ -37,7 +37,7 @@ pub fn pairwise( ) } - let sketches = collection.load_sketches(selection)?; + let sketches = collection.load_sketches(&selection)?; // set up a multi-producer, single-consumer channel. let (send, recv) = diff --git a/src/python/sourmash_plugin_branchwater/__init__.py b/src/python/sourmash_plugin_branchwater/__init__.py index 116d2072..58b80788 100755 --- a/src/python/sourmash_plugin_branchwater/__init__.py +++ b/src/python/sourmash_plugin_branchwater/__init__.py @@ -249,9 +249,9 @@ def __init__(self, p): p.add_argument( "-s", "--scaled", - default=1000, + default=None, type=int, - help="scaled factor at which to do comparisons (default: 1000)", + help="scaled factor at which to do comparisons (default: determined from query collection)", ) p.add_argument( "-m", @@ -435,9 +435,9 @@ def __init__(self, p): p.add_argument( "-s", "--scaled", - default=1000, + default=None, type=int, - help="scaled factor at which to do comparisons", + help="scaled factor at which to do comparisons (default: determined from query collection)", ) p.add_argument( "-m", diff --git a/src/python/tests/test_fastmultigather.py b/src/python/tests/test_fastmultigather.py index a21de63e..522cf17e 100644 --- a/src/python/tests/test_fastmultigather.py +++ b/src/python/tests/test_fastmultigather.py @@ -1987,3 +1987,100 @@ def test_create_empty_results(runtmp): g_output = runtmp.output("CP001071.1.gather.csv") p_output = runtmp.output("CP001071.1.prefetch.csv") assert os.path.exists(p_output) + + +def test_simple_against_scaled(runtmp, zip_against): + # we shouldn't automatically downsample query + query = get_test_data("SRR606249.sig.gz") + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + + downsampled_sigs = runtmp.output("ds.sig.zip") + runtmp.sourmash( + "sig", + "downsample", + "--scaled", + "120_000", + sig2, + sig47, + sig63, + "-o", + downsampled_sigs, + ) + + query_list = runtmp.output("query.txt") + make_file_list(query_list, [query]) + + with pytest.raises(utils.SourmashCommandFailed): + runtmp.sourmash( + "scripts", + "fastmultigather", + query_list, + downsampled_sigs, + "-t", + "0", + in_directory=runtmp.output(""), + ) + + +def test_simple_query_scaled(runtmp): + # test basic execution w/automatic scaled selection based on query + query = get_test_data("SRR606249.sig.gz") + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + + query_list = runtmp.output("query.txt") + against_list = runtmp.output("against.txt") + + make_file_list(query_list, [query]) + make_file_list(against_list, [sig2, sig47, sig63]) + + runtmp.sourmash( + "scripts", + "fastmultigather", + query_list, + against_list, + "-t", + "0", + in_directory=runtmp.output(""), + ) + + print(os.listdir(runtmp.output(""))) + + g_output = runtmp.output("SRR606249.gather.csv") + assert os.path.exists(g_output) + + +def test_simple_query_scaled_indexed(runtmp): + # test basic execution w/automatic scaled selection based on query + # (on a rocksdb) + query = get_test_data("SRR606249.sig.gz") + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + + query_list = runtmp.output("query.txt") + against_list = runtmp.output("against.txt") + + make_file_list(query_list, [query]) + make_file_list(against_list, [sig2, sig47, sig63]) + against_list = index_siglist(runtmp, against_list, runtmp.output("against.rocksdb")) + + runtmp.sourmash( + "scripts", + "fastmultigather", + query_list, + against_list, + "-o", + "foo.csv", + "-t", + "0", + in_directory=runtmp.output(""), + ) + + print(os.listdir(runtmp.output(""))) + + g_output = runtmp.output("foo.csv") + assert os.path.exists(g_output) diff --git a/src/python/tests/test_multisearch.py b/src/python/tests/test_multisearch.py index 43b6715a..ef018a96 100644 --- a/src/python/tests/test_multisearch.py +++ b/src/python/tests/test_multisearch.py @@ -1176,3 +1176,48 @@ def test_simple_below_threshold(runtmp): assert float(row["match_containment_ani"]) == 1.0 assert float(row["average_containment_ani"]) == 1.0 assert float(row["max_containment_ani"]) == 1.0 + + +def test_mismatched_scaled_query(runtmp): + # test what happens if query scaled is too high + query_list = runtmp.output("query.txt") + against_list = runtmp.output("against.txt") + + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + + query_list = runtmp.output("downsample.sig.zip") + runtmp.sourmash( + "sig", "downsample", "--scaled=10_000", sig2, sig47, sig63, "-o", query_list + ) + make_file_list(against_list, [sig2, sig47, sig63]) + + output = runtmp.output("out.csv") + + runtmp.sourmash("scripts", "multisearch", query_list, against_list, "-o", output) + assert os.path.exists(output) + + +def test_mismatched_scaled_against(runtmp): + # test what happens if against scaled is too high + query_list = runtmp.output("query.txt") + against_list = runtmp.output("against.txt") + + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + + make_file_list(query_list, [sig2, sig47, sig63]) + + against_list = runtmp.output("downsample.sig.zip") + runtmp.sourmash( + "sig", "downsample", "--scaled=10_000", sig2, sig47, sig63, "-o", against_list + ) + + output = runtmp.output("out.csv") + + with pytest.raises(utils.SourmashCommandFailed): + runtmp.sourmash( + "scripts", "multisearch", query_list, against_list, "-o", output + ) diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d5285820..a29b794e 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -580,6 +580,7 @@ pub fn load_collection( match collection { Some((coll, n_failed)) => { let n_total = coll.len(); + let selected = coll.select(selection)?; let n_skipped = n_total - selected.len(); report_on_collection_loading( @@ -955,7 +956,7 @@ pub fn consume_query_by_gather( Ok(()) } -pub fn build_selection(ksize: u8, scaled: usize, moltype: &str) -> Selection { +pub fn build_selection(ksize: u8, scaled: Option, moltype: &str) -> Selection { let hash_function = match moltype { "DNA" => HashFunctions::Murmur64Dna, "protein" => HashFunctions::Murmur64Protein, @@ -967,11 +968,18 @@ pub fn build_selection(ksize: u8, scaled: usize, moltype: &str) -> Selection { // .map_err(|_| panic!("Unknown molecule type: {}", moltype)) // .unwrap(); - Selection::builder() - .ksize(ksize.into()) - .scaled(scaled as u32) - .moltype(hash_function) - .build() + if let Some(scaled) = scaled { + Selection::builder() + .ksize(ksize.into()) + .scaled(scaled as u32) + .moltype(hash_function) + .build() + } else { + Selection::builder() + .ksize(ksize.into()) + .moltype(hash_function) + .build() + } } pub fn is_revindex_database(path: &camino::Utf8PathBuf) -> bool { diff --git a/src/utils/multicollection.rs b/src/utils/multicollection.rs index 82a351ed..03de24cb 100644 --- a/src/utils/multicollection.rs +++ b/src/utils/multicollection.rs @@ -275,6 +275,10 @@ impl MultiCollection { val == 0 } + pub fn max_scaled(&self) -> Option<&u64> { + self.item_iter().map(|(_, _, record)| record.scaled()).max() + } + // iterate over tuples pub fn item_iter(&self) -> impl Iterator { let s: Vec<_> = self