From eeeabfab1def58ec945334527da0616092d672b6 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Mon, 4 Nov 2024 10:51:05 -0800 Subject: [PATCH] WIP: debug/fix fastgather scaled problem (#499) * debug gather scaled problem * cleanup & rationalize * more cleanup --- src/fastgather.rs | 15 ++++++---- src/lib.rs | 5 ++-- .../sourmash_plugin_branchwater/__init__.py | 4 +-- src/python/tests/test_fastgather.py | 29 +++++++++++++++++++ src/utils/mod.rs | 3 +- 5 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/fastgather.rs b/src/fastgather.rs index 025ffc16..2d1dee56 100644 --- a/src/fastgather.rs +++ b/src/fastgather.rs @@ -2,6 +2,7 @@ use anyhow::Result; use sourmash::prelude::Select; use sourmash::selection::Selection; +use sourmash::sketch::minhash::KmerMinHash; use crate::utils::{ consume_query_by_gather, load_collection, load_sketches_above_threshold, write_prefetch, @@ -13,7 +14,6 @@ pub fn fastgather( query_filepath: String, against_filepath: String, threshold_bp: usize, - scaled: usize, selection: Selection, gather_output: Option, prefetch_output: Option, @@ -40,24 +40,29 @@ pub fn fastgather( let query_md5 = query_sig.md5sum(); // clone here is necessary b/c we use full query_sig in consume_query_by_gather - let query_sig_ds = query_sig.select(&selection)?; // downsample - let query_mh = match query_sig_ds.try_into() { + let query_sig_ds = query_sig.select(&selection)?; // downsample as needed. + let query_mh: KmerMinHash = match query_sig_ds.try_into() { Ok(query_mh) => query_mh, Err(_) => { bail!("No query sketch matching selection parameters."); } }; + + let mut against_selection = selection; + let scaled = query_mh.scaled(); + against_selection.set_scaled(scaled as u32); + // load collection to match against. let against_collection = load_collection( &against_filepath, - &selection, + &against_selection, ReportType::Against, allow_failed_sigpaths, )?; // calculate the minimum number of hashes based on desired threshold let threshold_hashes: u64 = { - let x = threshold_bp / scaled; + let x = threshold_bp / scaled as usize; if x > 0 { x } else { diff --git a/src/lib.rs b/src/lib.rs index 35f887ad..dbe4850f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,19 +90,18 @@ fn do_fastgather( siglist_path: String, threshold_bp: usize, ksize: u8, - scaled: usize, + scaled: Option, moltype: String, output_path_prefetch: Option, output_path_gather: Option, ) -> anyhow::Result { - let selection = build_selection(ksize, Some(scaled), &moltype); + let selection = build_selection(ksize, scaled, &moltype); let allow_failed_sigpaths = true; match fastgather::fastgather( query_filename, siglist_path, threshold_bp, - scaled, selection, output_path_prefetch, output_path_gather, diff --git a/src/python/sourmash_plugin_branchwater/__init__.py b/src/python/sourmash_plugin_branchwater/__init__.py index c88f10b3..88011a7f 100755 --- a/src/python/sourmash_plugin_branchwater/__init__.py +++ b/src/python/sourmash_plugin_branchwater/__init__.py @@ -172,9 +172,9 @@ def __init__(self, p): p.add_argument( "-s", "--scaled", - default=1000, + default=None, type=int, - help="scaled factor at which to do comparisons (default: 1000)", + help="scaled factor at which to do comparisons (default: determined from query)", ) p.add_argument( "-m", diff --git a/src/python/tests/test_fastgather.py b/src/python/tests/test_fastgather.py index 99cacf82..97c59808 100644 --- a/src/python/tests/test_fastgather.py +++ b/src/python/tests/test_fastgather.py @@ -1325,3 +1325,32 @@ def test_fullres_vs_sourmash_gather(runtmp): fg_total_weighted_hashes = set(gather_df["total_weighted_hashes"]) g_total_weighted_hashes = set(sourmash_gather_df["total_weighted_hashes"]) assert fg_total_weighted_hashes == g_total_weighted_hashes == set([73489]) + + +def test_equal_matches(runtmp): + # check that equal matches get returned from fastgather + base = sourmash.MinHash(scaled=1, ksize=31, n=0) + + a = base.copy_and_clear() + b = base.copy_and_clear() + c = base.copy_and_clear() + + a.add_many(range(0, 1000)) + b.add_many(range(1000, 2000)) + c.add_many(range(0, 2000)) + + ss = sourmash.SourmashSignature(a, name='g_a') + sourmash.save_signatures([ss], open(runtmp.output('a.sig'), 'wb')) + ss = sourmash.SourmashSignature(b, name='g_b') + sourmash.save_signatures([ss], open(runtmp.output('b.sig'), 'wb')) + ss = sourmash.SourmashSignature(c, name='g_mg') + sourmash.save_signatures([ss], open(runtmp.output('mg.sig'), 'wb')) + + runtmp.sourmash('sig', 'cat', 'a.sig', 'b.sig', '-o', 'combined.sig.zip') + + runtmp.sourmash('scripts', 'fastgather', 'mg.sig', 'combined.sig.zip', + '-o', 'out.csv', '--threshold-bp', '0') + + df = pandas.read_csv(runtmp.output('out.csv')) + assert len(df) == 2 + assert set(df['intersect_bp']) == { 1000 } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 9791ae6e..4a9618d5 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -676,11 +676,10 @@ pub fn branchwater_calculate_gather_stats( calc_ani_ci: bool, confidence: Option, ) -> Result { - //bp remaining in subtracted query + // bp remaining in subtracted query let remaining_bp = (query.size() - match_size) * query.scaled() as usize; // stats for this match vs original query - let (intersect_orig, _) = match_mh.intersection_size(orig_query).unwrap(); let intersect_bp = (match_mh.scaled() * intersect_orig) as usize; let f_orig_query = intersect_orig as f64 / orig_query.size() as f64; let f_match_orig = intersect_orig as f64 / match_mh.size() as f64;