Add probability of overlap and weighted containment for Multisearch matches (#458)

* Add probability of overlap and weighted containment to multisearch result

* Start writing prob_overlap

* Couldn't figure out how to get prob_overlap.rs to import .. putting into utils.rs for now

* Trying to get prob overlap to at least import properly

* Start writing a merge_all_minhashes function

* Write in commented code what needs to happen

* Remove mut from unused variables for now

* wrote function to merge all minhashes of a vector of signatures

* Added merge_all_minhashes to multisearch

* Add crates for stable calculation of log values

* Add dependencies for stable calculation of log values in Cargo.lock

* Add rust decimal with math feature

* Add function to get probability of overlap between specific intersection hashes of all queries and all database minhash

* Call probability of overlap between queries and database

* I'm getting too confused by rust_decimal .. let's go back to using the standard library

* Add adjusted prob_overlap to MultiSearchResult

* Getting prob_overlap to actually work

* Add failing test for test_multisearch.py

* Fix n_comparisons to be float, remove commented out pseudocode

* Remove unnecessary parens

* Added prob_overlap, prob_overlap_adjusted, containment_adjusted, containment_adjusted_log10 values to test_multisearch

* Add print statements

* Add containment_adjusted_log10

* Fix compiler errors

* Fix rounding for prob_overlap, prob_overlap_adjusted, containment_adjusted, containment_adjusted_log10

* Move probability of overlap code into separate search_significance module

* add tf_idf_score to test_multisearch.py

* Add tf_idf_score to MultiSearchResult

* Make separate "againsts" as Vec<Sig>

* Get TF-IDF running

* remove print statements and commented out code

* Remove print statements, commented out code, add todos

* Fix optional boolean types for prob_overlap and tf idf

* Add multisearch test of protein with abundance

* Remove part_001 from signature filename

* Delete old part_001 file

* Remove too big sig from test data

* Add test of probability of overlap with multisearch

* Add --prob argument

* Precompute frequencies for queries and againsts, save as HashMaps for fast lookups

* Use L2 norm for tf idf, add more print messages

* Use par_iter whenever possible

* Remove logsumexp from files

* Add failing test to make sure prob_overlap only gets computed when --prob-overlap specified

* Remove logsumexp from rust file

* Try to make prob_overlap calculation optional

* Make prob_overlap an optional column

* remove unused and commented out code

* add comment for estimate_prob_overlap

* Remove `let` keyword to stop "shadowing" the variables

* add par_bridge() after iter_mins() for parallel computation

* Remove `let` from creating precomputed HashMaps for search significance and TF-IDF

* Remove checking for non-existence of prob_overlap when it really should be there

* remove unused 'mut'

* Add float_round function

* Fix missing bracket

* Rename unused hashval variable -> _hashval

* Update protein fasta paths in test_sketch.py ... but also run black formatting

* Add comment about minhash not being defined

* remove commented out code

* Add clarification about squaring 1

* Apply `cargo fix --lib -p sourmash_plugin_branchwater`

* Remove unused import

* Just kidding, that import was used

* Fix SmallSignature import

* Fix weirdness for test_simple_ani and test_simple_prob_overlap caused by merge conflicts

* Run black and fix zip True/False in test_against_multisigfile

* whitespace

* formatting

* "syn" package appeared twice

* Trailing whitespace

* Add protein k5 signature

* Apply black formatting to everything

* Merge black-applied python test files

* Missed some merge markers

* Missed more merge markers...

* Fix black in test_multisearch.py

* Remove commented out code

* unwrap -> expect

* Modularize the probability of overlap computation into functions

* set values for prob_overlap results in the if statement

* Add longer argument name and description

* Cargo fmt

* Borrow 'selection'

* Clone selection

* Add longer argument name

* Use `new_selection` to set scaled

* Add @pytest.mark.xfail(reason="should work, bug") to `test_fastgather.py:test_against_multisigfile`

* Revert test_against_multisigfile back to main

* Remove .clone() from selection
olgabot authored Nov 12, 2024
1 parent c5f5866 commit 0dd65d6
Showing 11 changed files with 1,054 additions and 93 deletions.
371 changes: 347 additions & 24 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
@@ -27,6 +27,8 @@ camino = "1.1.9"
glob = "0.3.1"
rustworkx-core = "0.15.1"
streaming-stats = "0.2.3"
rust_decimal = { version = "1.36.0", features = ["maths"] }
rust_decimal_macros = "1.36.0"

[dev-dependencies]
assert_cmd = "2.0.16"
5 changes: 4 additions & 1 deletion src/lib.rs
@@ -22,6 +22,7 @@ mod manysearch_rocksdb;
mod manysketch;
mod multisearch;
mod pairwise;
mod search_significance;
mod singlesketch;

use camino::Utf8PathBuf as PathBuf;
@@ -231,7 +232,7 @@ fn do_check(index: String, quick: bool) -> anyhow::Result<u8> {
}

#[pyfunction]
- #[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, estimate_ani, output_path=None))]
+ #[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, estimate_ani, estimate_prob_overlap, output_path=None))]
#[allow(clippy::too_many_arguments)]
fn do_multisearch(
querylist_path: String,
@@ -241,6 +242,7 @@ fn do_multisearch(
scaled: Option<u32>,
moltype: String,
estimate_ani: bool,
estimate_prob_overlap: bool,
output_path: Option<String>,
) -> anyhow::Result<u8> {
let _ = env_logger::try_init();
@@ -255,6 +257,7 @@ fn do_multisearch(
selection,
allow_failed_sigpaths,
estimate_ani,
estimate_prob_overlap,
output_path,
) {
Ok(_) => Ok(0),
176 changes: 173 additions & 3 deletions src/multisearch.rs
@@ -3,12 +3,134 @@ use anyhow::Result;
use rayon::prelude::*;
use sourmash::selection::Selection;
use sourmash::signature::SigsTrait;
use sourmash::sketch::minhash::KmerMinHash;
use std::collections::HashMap;
use std::sync::atomic;
use std::sync::atomic::AtomicUsize;

use crate::search_significance::{
compute_inverse_document_frequency, get_hash_frequencies, get_prob_overlap,
get_term_frequency_inverse_document_frequency, merge_all_minhashes, Normalization,
};
use crate::utils::multicollection::SmallSignature;
use crate::utils::{csvwriter_thread, load_collection, MultiSearchResult, ReportType};
use sourmash::ani_utils::ani_from_containment;

#[derive(Default, Clone, Debug)]
struct ProbOverlapStats {
prob_overlap: f64,
prob_overlap_adjusted: f64,
containment_adjusted: f64,
containment_adjusted_log10: f64,
tf_idf_score: f64,
}

/// Computes probability overlap statistics for a single pair of signatures
fn compute_single_prob_overlap(
query: &SmallSignature,
against: &SmallSignature,
n_comparisons: f64,
query_merged_frequencies: &HashMap<u64, f64>,
against_merged_frequencies: &HashMap<u64, f64>,
query_term_frequencies: &HashMap<String, HashMap<u64, f64>>,
inverse_document_frequency: &HashMap<u64, f64>,
containment_query_in_target: f64,
) -> ProbOverlapStats {
let overlapping_hashvals: Vec<u64> = query
.minhash
.intersection(&against.minhash)
.expect("Intersection of query and against minhashes")
.0;

let prob_overlap = get_prob_overlap(
&overlapping_hashvals,
query_merged_frequencies,
against_merged_frequencies,
);

let prob_overlap_adjusted = prob_overlap * n_comparisons;
let containment_adjusted = containment_query_in_target / prob_overlap_adjusted;

ProbOverlapStats {
prob_overlap,
prob_overlap_adjusted,
containment_adjusted,
containment_adjusted_log10: containment_adjusted.log10(),
tf_idf_score: get_term_frequency_inverse_document_frequency(
&overlapping_hashvals,
&query_term_frequencies[&query.md5sum],
inverse_document_frequency,
),
}
}
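
The arithmetic in `compute_single_prob_overlap` can be checked in isolation. The sketch below is not part of the commit: `AdjustedStats` and `adjust_containment` are hypothetical names, and the inputs are assumed to be computed already. It only mirrors the three derived fields shown above: the raw probability is multiplied by the number of query-against comparisons, the containment is divided by that adjusted probability, and the log10 of the ratio is reported.

```rust
/// Hypothetical container mirroring three of the fields in `ProbOverlapStats`.
#[derive(Debug, Clone, PartialEq)]
pub struct AdjustedStats {
    pub prob_overlap_adjusted: f64,
    pub containment_adjusted: f64,
    pub containment_adjusted_log10: f64,
}

/// Illustrative helper (not in the commit): applies the same arithmetic as
/// `compute_single_prob_overlap` to inputs that are already known.
pub fn adjust_containment(
    prob_overlap: f64,
    n_comparisons: f64,
    containment_query_in_target: f64,
) -> AdjustedStats {
    // Scale the per-pair probability by the number of comparisons made.
    let prob_overlap_adjusted = prob_overlap * n_comparisons;
    // Weight the containment by how unlikely the overlap is.
    let containment_adjusted = containment_query_in_target / prob_overlap_adjusted;
    AdjustedStats {
        prob_overlap_adjusted,
        containment_adjusted,
        containment_adjusted_log10: containment_adjusted.log10(),
    }
}
```

Note that under IEEE 754 arithmetic a zero `prob_overlap` with a positive containment makes the division return infinity, so a caller of a helper like this may want to guard that case before writing results.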

/// Precomputes the statistics needed to estimate the probability of overlap
/// between each query sig and each against sig, using the underlying
/// distribution of hashvals across all queries and all againsts
fn compute_prob_overlap_stats(
queries: &Vec<SmallSignature>,
againsts: &Vec<SmallSignature>,
) -> (
f64,
HashMap<u64, f64>,
HashMap<u64, f64>,
HashMap<String, HashMap<u64, f64>>,
HashMap<u64, f64>,
) {
let n_comparisons = againsts.len() as f64 * queries.len() as f64;

// Combine all the queries and against into a single signature each
eprintln!("Merging queries ...");
let queries_merged_mh: KmerMinHash =
merge_all_minhashes(queries).expect("Merging query minhashes");
eprintln!("\tDone.\n");

eprintln!("Merging against ...");
let against_merged_mh: KmerMinHash =
merge_all_minhashes(againsts).expect("Merging against minhashes");
eprintln!("\tDone.\n");

// Compute IDF
eprintln!("Computing Inverse Document Frequency (IDF) of hashes in all againsts ...");
let inverse_document_frequency =
compute_inverse_document_frequency(&against_merged_mh, againsts, Some(true));
eprintln!("\tDone.\n");

// Compute frequencies
eprintln!("Computing frequency of hashvals across all againsts (L1 Norm) ...");
let against_merged_frequencies =
get_hash_frequencies(&against_merged_mh, Some(Normalization::L1));
eprintln!("\tDone.\n");

eprintln!("Computing frequency of hashvals across all queries (L1 Norm) ...");
let query_merged_frequencies =
get_hash_frequencies(&queries_merged_mh, Some(Normalization::L1));
eprintln!("\tDone.\n");

// Compute term frequencies
eprintln!("Computing hashval term frequencies within each query (L2 Norm) ...");
let query_term_frequencies = HashMap::from(
queries
.par_iter()
.map(|query| {
(
query.md5sum.clone(),
get_hash_frequencies(&query.minhash, Some(Normalization::L2)),
)
})
.collect::<HashMap<String, HashMap<u64, f64>>>(),
);
eprintln!("\tDone.\n");

(
n_comparisons,
query_merged_frequencies,
against_merged_frequencies,
query_term_frequencies,
inverse_document_frequency,
)
}
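
The function above requests L1-normalized frequencies for the merged sketches and L2-normalized frequencies for each query. The body of `get_hash_frequencies` is not shown in this diff, so the following is only a sketch of what the two normalizations mean, assuming the input is a map from hash value to abundance. `Norm` and `normalize_abundances` are hypothetical names standing in for the real `Normalization` enum and helper.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the `Normalization` enum imported in the diff.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Norm {
    L1,
    L2,
}

/// Illustrative helper (an assumption, not the commit's code): convert raw
/// hash abundances into frequencies normalized by the L1 or L2 norm.
pub fn normalize_abundances(abunds: &HashMap<u64, u64>, norm: Norm) -> HashMap<u64, f64> {
    let denom = match norm {
        // L1: divide by the sum of abundances, so the frequencies sum to 1.
        Norm::L1 => abunds.values().map(|&a| a as f64).sum::<f64>(),
        // L2: divide by the Euclidean length, so the squared frequencies sum to 1.
        Norm::L2 => abunds
            .values()
            .map(|&a| (a as f64).powi(2))
            .sum::<f64>()
            .sqrt(),
    };
    if denom == 0.0 {
        // Empty or all-zero input: nothing to normalize.
        return HashMap::new();
    }
    abunds.iter().map(|(&h, &a)| (h, a as f64 / denom)).collect()
}
```

With abundances 3 and 4, the L1 frequencies are 3/7 and 4/7, while the L2 frequencies are 0.6 and 0.8. L1 output reads as a probability distribution, which suits the overlap probability; L2 output is a unit vector, which is the usual choice for TF-IDF style scoring.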

/// Search many queries against a list of signatures.
///
/// Note: this function loads all _queries_ into memory, and iterates over
@@ -21,6 +143,7 @@ pub fn multisearch(
selection: Selection,
allow_failed_sigpaths: bool,
estimate_ani: bool,
estimate_prob_overlap: bool,
output: Option<String>,
) -> Result<(), Box<dyn std::error::Error>> {
// Load all queries into memory at once.
@@ -48,7 +171,7 @@ pub fn multisearch(
let mut new_selection = selection;
new_selection.set_scaled(expected_scaled);

- let queries = query_collection.load_sketches(&new_selection)?;
+ let queries: Vec<SmallSignature> = query_collection.load_sketches(&new_selection)?;

// Load all against sketches into memory at once.
let against_collection = load_collection(
@@ -58,7 +181,25 @@ pub fn multisearch(
allow_failed_sigpaths,
)?;

- let against = against_collection.load_sketches(&new_selection)?;
+ let againsts: Vec<SmallSignature> = against_collection.load_sketches(&new_selection)?;

let (
n_comparisons,
query_merged_frequencies,
against_merged_frequencies,
query_term_frequencies,
inverse_document_frequency,
) = if estimate_prob_overlap {
compute_prob_overlap_stats(&queries, &againsts)
} else {
(
0.0,
Default::default(),
Default::default(),
Default::default(),
Default::default(),
)
};

// set up a multi-producer, single-consumer channel.
let (send, recv) =
@@ -75,7 +216,7 @@ pub fn multisearch(

let processed_cmp = AtomicUsize::new(0);

- let send = against
+ let send = againsts
.par_iter()
.filter_map(|against| {
let mut results = vec![];
@@ -115,6 +256,30 @@ pub fn multisearch(
let mut match_containment_ani = None;
let mut average_containment_ani = None;
let mut max_containment_ani = None;
let mut prob_overlap: Option<f64> = None;
let mut prob_overlap_adjusted: Option<f64> = None;
let mut containment_adjusted: Option<f64> = None;
let mut containment_adjusted_log10: Option<f64> = None;
let mut tf_idf_score: Option<f64> = None;

// Compute probability overlap stats if requested
if estimate_prob_overlap {
let prob_stats = compute_single_prob_overlap(
query,
against,
n_comparisons,
&query_merged_frequencies,
&against_merged_frequencies,
&query_term_frequencies,
&inverse_document_frequency,
containment_query_in_target,
);
prob_overlap = Some(prob_stats.prob_overlap);
prob_overlap_adjusted = Some(prob_stats.prob_overlap_adjusted);
containment_adjusted = Some(prob_stats.containment_adjusted);
containment_adjusted_log10 = Some(prob_stats.containment_adjusted_log10);
tf_idf_score = Some(prob_stats.tf_idf_score);
}

// estimate ANI values
if estimate_ani {
@@ -142,6 +307,11 @@ pub fn multisearch(
match_containment_ani,
average_containment_ani,
max_containment_ani,
prob_overlap: prob_overlap,
prob_overlap_adjusted: prob_overlap_adjusted,
containment_adjusted: containment_adjusted,
containment_adjusted_log10: containment_adjusted_log10,
tf_idf_score: tf_idf_score,
})
}
}
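
The `tf_idf_score` column above combines per-query term frequencies with an inverse document frequency computed over all againsts. The `search_significance` module itself is not rendered in this view, so the following is a sketch under stated assumptions rather than the commit's implementation: documents are modeled as plain hash sets, the IDF uses the common smoothed form ln((1 + N) / (1 + df)) + 1, and the score is the sum of tf × idf over the overlapping hashes. `smooth_idf` and `tf_idf_score` are hypothetical names.

```rust
use std::collections::{HashMap, HashSet};

/// Assumed smoothed IDF: ln((1 + N) / (1 + df)) + 1, where N is the number
/// of documents and df is the number of documents containing the hash.
pub fn smooth_idf(documents: &[HashSet<u64>]) -> HashMap<u64, f64> {
    let n_docs = documents.len() as f64;
    let mut doc_freq: HashMap<u64, f64> = HashMap::new();
    for doc in documents {
        for &hashval in doc {
            *doc_freq.entry(hashval).or_insert(0.0) += 1.0;
        }
    }
    doc_freq
        .into_iter()
        .map(|(hashval, df)| (hashval, ((1.0 + n_docs) / (1.0 + df)).ln() + 1.0))
        .collect()
}

/// Assumed scoring rule: sum tf * idf over the hashes shared by the query
/// and the match; hashes missing from either map contribute zero.
pub fn tf_idf_score(
    overlapping_hashvals: &[u64],
    term_frequencies: &HashMap<u64, f64>,
    inverse_document_frequency: &HashMap<u64, f64>,
) -> f64 {
    overlapping_hashvals
        .iter()
        .map(|h| {
            let tf = term_frequencies.get(h).copied().unwrap_or(0.0);
            let idf = inverse_document_frequency.get(h).copied().unwrap_or(0.0);
            tf * idf
        })
        .sum()
}
```

Under this formula a hash present in every against document gets an IDF of exactly 1, and rarer hashes get larger weights, so matches that share uncommon hashes score higher than matches that share ubiquitous ones.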
21 changes: 21 additions & 0 deletions src/pairwise.rs
@@ -80,6 +80,12 @@ pub fn pairwise(
let containment_q1_in_q2 = overlap / query1_size;
let containment_q2_in_q1 = overlap / query2_size;

let prob_overlap = None;
let prob_overlap_adjusted = None;
let containment_adjusted = None;
let containment_adjusted_log10 = None;
let tf_idf_score = None;

if containment_q1_in_q2 > threshold || containment_q2_in_q1 > threshold {
let max_containment = containment_q1_in_q2.max(containment_q2_in_q1);
let jaccard = overlap / (query1_size + query2_size - overlap);
@@ -113,6 +119,11 @@ pub fn pairwise(
match_containment_ani,
average_containment_ani,
max_containment_ani,
prob_overlap,
prob_overlap_adjusted,
containment_adjusted,
containment_adjusted_log10,
tf_idf_score,
})
.unwrap();
}
@@ -127,6 +138,11 @@ pub fn pairwise(
let mut match_containment_ani = None;
let mut average_containment_ani = None;
let mut max_containment_ani = None;
let prob_overlap = None;
let prob_overlap_adjusted = None;
let containment_adjusted = None;
let containment_adjusted_log10 = None;
let tf_idf_score = None;

if estimate_ani {
query_containment_ani = Some(1.0);
@@ -151,6 +167,11 @@ pub fn pairwise(
match_containment_ani,
average_containment_ani,
max_containment_ani,
prob_overlap,
prob_overlap_adjusted,
containment_adjusted,
containment_adjusted_log10,
tf_idf_score,
})
.unwrap();
}
10 changes: 10 additions & 0 deletions src/python/sourmash_plugin_branchwater/__init__.py
@@ -456,6 +456,12 @@ def __init__(self, p):
p.add_argument(
"-a", "--ani", action="store_true", help="estimate ANI from containment"
)
p.add_argument(
"-p",
"--prob-significant-overlap",
action="store_true",
help="estimate probability of overlap for significance ranking of search results, of the specific query and match, given all queries and possible matches",
)

def main(self, args):
print_version()
@@ -468,6 +474,9 @@ def main(self, args):
notify(
f"searching all sketches in '{args.query_paths}' against '{args.against_paths}' using {num_threads} threads"
)
notify(
f"estimate ani? {args.ani} / estimate probability of overlap? {args.prob_significant_overlap}"
)

super().main(args)
status = sourmash_plugin_branchwater.do_multisearch(
@@ -478,6 +487,7 @@ def main(self, args):
args.scaled,
args.moltype,
args.ani,
args.prob_significant_overlap,
args.output,
)
if status == 0: