From 0dd65d6f2c1fbe803c3f8465a5f4ec4f2c0c1dc2 Mon Sep 17 00:00:00 2001 From: Olga Botvinnik Date: Tue, 12 Nov 2024 13:29:33 -0800 Subject: [PATCH] Add probability of overlap and weighted containment for Multisearch matches (#458) * Add probability of overlap and weighted containment to multisearch result * Start writing prob_overlap * Couldn't figure out how to get prob_overlap.rs to import .. putting into utils.rs for now * Trying to get prob overlap to at least import properly * Start writing a merge_all_minhashes function * Write in commented code what needs to happen * Remove mut from unused variables for now * wrote function to merge all minhashes of a vector of signatures * Added merge_all_minhashes to multisearch * Add crates for stable calculation of log values * Add dependencies for stable calculation of log values in Cargo.lock * Add rust decimal with math feature * Add function to get probability of overlap between specific intersection hashes of all queries and all database minhash * Call probability of overlap between queries and database * I'm getting too confused by rust_decimal .. 
let's go back to using the standard library * Add adjusted prob_overlap to MultiSearchResult * Getting prob_overlap to actually work * Add failing test for test_multisearch.py * Fix n_comparisons to be float, remove commented out pseudocode * Remove unnecessary parens * Added prob_overlap, prob_overlap_adjusted, containment_adjusted, containment_adjusted_log10 values to test_multisearch * Add print statements * Add containment_adjusted_log10 * Fix compiler errors * Fix rounding for prob_overlap, prob_overlap_adjusted, containment_adjusted, containment_adjusted_log10 * Move probability of overlap code into separate search_significance module * add tf_idf_score to test_multisearch.py * Add tf_idf_score to MultiSearchResult * Make separate "againsts" as Vec * Get TF-IDF running * remove print statements and commented out code * Remove print statements, commented out code, add todos * Fix optional boolean types for prob_overlap and tf idf * Add multisearch test of protein with abundance * Remove part_001 from signature filename * Delete old part_001 file * Remove too big sig from test data * Add test of probability of overlap with multisearch * Add --prob argument * Precompute frequencies for queries and againsts, save as HashMaps for fast lookups * Use L2 norm for tf idf, add more print messages * Use par_iter whenever possible * Remove logsumexp from files * Add failing test to make sure prob_overlap only gets computed when --prob-overlap specified * Remove logsumexp from rust file * Try to make prob_overlap calculation optional * Make prob_overlap an optional column * remove unused and commented out code * add comment for estimate_prob_overlap * Remove `let` keyword to stop "shadowing" the variables * add par_bridge() after iter_mins() for parallel computation * Remove `let` from creating precomputed HashMaps for search significance and TF-IDF * Remove checking for non-existence of prob_overlap when it really should be there * remove unused 'mut' * Add float_round 
function * Fix missing bracket * Rename unused hashval variable -> _hashval * Update protein fasta paths in test_sketch.py ... but also run black formatting * Add comment about minhash not being defined * remove commented out code * Add clarification about squaring 1 * Apply `cargo fix --lib -p sourmash_plugin_branchwater` * Remove unused import * Just kidding, that import was used * Fix SmallSignature import * Fix weirdness for test_simple_ani and test_simple_prob_overlap caused by merge conflicts * Run black and fix zip True/False in test_against_multisigfile * whitespace * formatting * "syn" package appeared twice * Trailing whitespace * Add protein k5 signature * Apply black formatting to everything * Merge black-applied python test files * Missed some merge markers * Missed more merge markers... * Fix black in test_multisearch.py * Remove commented out code * unwrap -> expect * Modularize the probability of overlap computation into functions * set values for prob_overlap results in the if statement * Add longer argument name and description * Cargo fmt * Borrow 'selection' * Clone selection * Add longer argument name * Use `new_selection` to set scaled * Add @pytest.mark.xfail(reason="should work, bug") to `test_fastgather.py:test_against_multisigfile` * Revert test_against_multisigfile back to main * Remove .clone() from selection --- Cargo.lock | 371 ++++++++++++++++-- Cargo.toml | 2 + src/lib.rs | 5 +- src/multisearch.rs | 176 ++++++++- src/pairwise.rs | 21 + .../sourmash_plugin_branchwater/__init__.py | 10 + .../tests/test-data/snap25.protein.k5.sig | 1 + src/python/tests/test_multisearch.py | 331 +++++++++++++--- src/python/tests/test_sketch.py | 5 +- src/search_significance.rs | 210 ++++++++++ src/utils/mod.rs | 15 + 11 files changed, 1054 insertions(+), 93 deletions(-) create mode 100644 src/python/tests/test-data/snap25.protein.k5.sig create mode 100644 src/search_significance.rs diff --git a/Cargo.lock b/Cargo.lock index 449522b5..72ed890a 100644 --- 
a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + [[package]] name = "ahash" version = "0.8.11" @@ -130,6 +141,12 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "assert_cmd" version = "2.0.16" @@ -198,7 +215,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn", + "syn 2.0.85", ] [[package]] @@ -213,6 +230,42 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "borsh" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6362ed55def622cddc70a4746a68554d7b687713770de539e59a739b249f8ed" +dependencies = [ + "borsh-derive", + "cfg_aliases", +] + +[[package]] +name = "borsh-derive" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3ef8005764f53cd4dca619f5bf64cafd4664dada50ece25e4d81de54c80cc0b" +dependencies = [ + "once_cell", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.85", + "syn_derive", +] + [[package]] name = "bstr" version = "1.9.1" @@ 
-239,6 +292,28 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "bytecheck" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23cdc57ce23ac53c931e88a43d06d070a6fd142f2617be5855eb75efc9beb1c2" +dependencies = [ + "bytecheck_derive", + "ptr_meta", + "simdutf8", +] + +[[package]] +name = "bytecheck_derive" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3db406d29fbcd95542e92559bed4d8ad92636d1ca8b3b72ede10b4bcc010e659" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "bytecount" version = "0.6.8" @@ -257,6 +332,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" + [[package]] name = "bzip2" version = "0.4.4" @@ -313,6 +394,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.38" @@ -431,7 +518,7 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -448,7 +535,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -472,7 +559,7 
@@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -545,6 +632,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "getrandom" version = "0.2.15" @@ -567,7 +660,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -576,13 +669,22 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +dependencies = [ + "ahash 0.7.8", +] + [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ - "ahash", + "ahash 0.8.11", "allocator-api2", "rayon", ] @@ -639,12 +741,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.5", "rayon", ] @@ -901,7 +1003,7 @@ checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -1057,7 +1159,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -1153,7 +1255,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.85", ] [[package]] @@ -1176,6 +1278,38 @@ dependencies = [ "indexmap", ] +[[package]] +name = "proc-macro-crate" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +dependencies = [ + "toml_edit", +] + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -1195,7 +1329,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -1215,11 +1349,31 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.85", "version_check", "yansi", ] +[[package]] +name = "ptr_meta" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0738ccf7ea06b608c10564b31debd4f5bc5e197fc8bfe088f68ae5ce81e7a4f1" +dependencies = [ + "ptr_meta_derive", +] + +[[package]] +name = "ptr_meta_derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b845dbfca988fa33db069c0e230574d15a3088f147a87b64c7589eb662c9ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = 
"pyo3" version = "0.22.6" @@ -1268,7 +1422,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -1281,7 +1435,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -1293,6 +1447,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.8.5" @@ -1408,6 +1568,44 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +[[package]] +name = "rend" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71fe3824f5629716b1589be05dacd749f6aa084c87e00e016714a8cdfccc997c" +dependencies = [ + "bytecheck", +] + +[[package]] +name = "rkyv" +version = "0.7.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9008cd6385b9e161d8229e1f6549dd23c3d022f132a2ea37ac3a10ac4935779b" +dependencies = [ + "bitvec", + "bytecheck", + "bytes", + "hashbrown 0.12.3", + "ptr_meta", + "rend", + "rkyv_derive", + "seahash", + "tinyvec", + "uuid", +] + +[[package]] +name = "rkyv_derive" +version = "0.7.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "503d1d27590a2b0a3a4ca4c94755aa2875657196ecbf401a42eff41d7de532c0" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "roaring" version = "0.10.6" @@ -1434,6 +1632,32 @@ version = "0.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "082f11ffa03bbef6c2c6ea6bea1acafaade2fd9050ae0234ab44a2153742b058" +[[package]] +name = "rust_decimal" +version = "1.36.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b082d80e3e3cc52b2ed634388d436fe1f4de6af5786cc2de9ba9737527bdf555" +dependencies = [ + "arrayvec", + "borsh", + "bytes", + "num-traits", + "rand", + "rkyv", + "serde", + "serde_json", +] + +[[package]] +name = "rust_decimal_macros" +version = "1.36.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da991f231869f34268415a49724c6578e740ad697ba0999199d6f22b3949332c" +dependencies = [ + "quote", + "rust_decimal", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -1459,9 +1683,9 @@ version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef8108bdaf5b590d2ea261c6ca9b1795cbf253d0733b2e209b7990c95ed23843" dependencies = [ - "ahash", + "ahash 0.8.11", "fixedbitset", - "hashbrown", + "hashbrown 0.14.5", "indexmap", "ndarray", "num-traits", @@ -1488,6 +1712,12 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "seahash" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" + [[package]] name = "serde" version = "1.0.214" @@ -1505,7 +1735,7 @@ checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -1539,6 +1769,12 @@ dependencies = [ "wide", ] +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + [[package]] name = "simple-error" version = "0.3.1" @@ -1621,6 +1857,8 @@ dependencies = [ "predicates", "pyo3", "rayon", + "rust_decimal", + "rust_decimal_macros", "rustworkx-core", "serde", "serde_json", @@ -1658,6 +1896,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" 
+dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.85" @@ -1669,6 +1918,24 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn_derive" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1329189c02ff984e9736652b1631330da25eaa6bc639089ed4915d25446cbe7b" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.85", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "target-lexicon" version = "0.12.14" @@ -1711,7 +1978,39 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.85", +] + +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" + +[[package]] +name = "toml_edit" +version = "0.22.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow", ] [[package]] @@ -1742,7 +2041,7 @@ checksum = "1f718dfaf347dcb5b983bfc87608144b0bad87970aebcbea5ce44d2a30c08e63" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.85", ] [[package]] @@ -1769,6 
+2068,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" + [[package]] name = "vcpkg" version = "0.2.15" @@ -1833,7 +2138,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 2.0.85", "wasm-bindgen-shared", ] @@ -1855,7 +2160,7 @@ checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.85", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1977,6 +2282,24 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +dependencies = [ + "memchr", +] + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "xz2" version = "0.1.7" @@ -2009,7 +2332,7 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.85", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 90475560..667c8369 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,8 @@ camino = "1.1.9" glob = "0.3.1" rustworkx-core = "0.15.1" streaming-stats = "0.2.3" +rust_decimal = { version = "1.36.0", features = ["maths"] } +rust_decimal_macros = "1.36.0" [dev-dependencies] assert_cmd = "2.0.16" diff --git a/src/lib.rs b/src/lib.rs index 
604b91da..741b4980 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,7 @@ mod manysearch_rocksdb; mod manysketch; mod multisearch; mod pairwise; +mod search_significance; mod singlesketch; use camino::Utf8PathBuf as PathBuf; @@ -231,7 +232,7 @@ fn do_check(index: String, quick: bool) -> anyhow::Result { } #[pyfunction] -#[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, estimate_ani, output_path=None))] +#[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, estimate_ani, estimate_prob_overlap, output_path=None))] #[allow(clippy::too_many_arguments)] fn do_multisearch( querylist_path: String, @@ -241,6 +242,7 @@ fn do_multisearch( scaled: Option, moltype: String, estimate_ani: bool, + estimate_prob_overlap: bool, output_path: Option, ) -> anyhow::Result { let _ = env_logger::try_init(); @@ -255,6 +257,7 @@ fn do_multisearch( selection, allow_failed_sigpaths, estimate_ani, + estimate_prob_overlap, output_path, ) { Ok(_) => Ok(0), diff --git a/src/multisearch.rs b/src/multisearch.rs index 27eeab35..f0befb16 100644 --- a/src/multisearch.rs +++ b/src/multisearch.rs @@ -3,12 +3,134 @@ use anyhow::Result; use rayon::prelude::*; use sourmash::selection::Selection; use sourmash::signature::SigsTrait; +use sourmash::sketch::minhash::KmerMinHash; +use std::collections::HashMap; use std::sync::atomic; use std::sync::atomic::AtomicUsize; +use crate::search_significance::{ + compute_inverse_document_frequency, get_hash_frequencies, get_prob_overlap, + get_term_frequency_inverse_document_frequency, merge_all_minhashes, Normalization, +}; +use crate::utils::multicollection::SmallSignature; use crate::utils::{csvwriter_thread, load_collection, MultiSearchResult, ReportType}; use sourmash::ani_utils::ani_from_containment; +#[derive(Default, Clone, Debug)] +struct ProbOverlapStats { + prob_overlap: f64, + prob_overlap_adjusted: f64, + containment_adjusted: f64, + containment_adjusted_log10: f64, + tf_idf_score: f64, 
+} + +/// Computes probability overlap statistics for a single pair of signatures +fn compute_single_prob_overlap( + query: &SmallSignature, + against: &SmallSignature, + n_comparisons: f64, + query_merged_frequencies: &HashMap, + against_merged_frequencies: &HashMap, + query_term_frequencies: &HashMap>, + inverse_document_frequency: &HashMap, + containment_query_in_target: f64, +) -> ProbOverlapStats { + let overlapping_hashvals: Vec = query + .minhash + .intersection(&against.minhash) + .expect("Intersection of query and against minhashes") + .0; + + let prob_overlap = get_prob_overlap( + &overlapping_hashvals, + query_merged_frequencies, + against_merged_frequencies, + ); + + let prob_overlap_adjusted = prob_overlap * n_comparisons; + let containment_adjusted = containment_query_in_target / prob_overlap_adjusted; + + ProbOverlapStats { + prob_overlap, + prob_overlap_adjusted, + containment_adjusted, + containment_adjusted_log10: containment_adjusted.log10(), + tf_idf_score: get_term_frequency_inverse_document_frequency( + &overlapping_hashvals, + &query_term_frequencies[&query.md5sum], + inverse_document_frequency, + ), + } +} + +/// Computes probability overlap statistics for queries and against signatures +/// Estimate probability of overlap between query sig and against sig, using +/// underlying distribution of hashvals for all queries and all againsts +fn compute_prob_overlap_stats( + queries: &Vec, + againsts: &Vec, +) -> ( + f64, + HashMap, + HashMap, + HashMap>, + HashMap, +) { + let n_comparisons = againsts.len() as f64 * queries.len() as f64; + + // Combine all the queries and against into a single signature each + eprintln!("Merging queries ..."); + let queries_merged_mh: KmerMinHash = + merge_all_minhashes(queries).expect("Merging query minhashes"); + eprintln!("\tDone.\n"); + + eprintln!("Merging against ..."); + let against_merged_mh: KmerMinHash = + merge_all_minhashes(againsts).expect("Merging against minhashes"); + eprintln!("\tDone.\n"); + + // 
Compute IDF + eprintln!("Computing Inverse Document Frequency (IDF) of hashes in all againsts ..."); + let inverse_document_frequency = + compute_inverse_document_frequency(&against_merged_mh, againsts, Some(true)); + eprintln!("\tDone.\n"); + + // Compute frequencies + eprintln!("Computing frequency of hashvals across all againsts (L1 Norm) ..."); + let against_merged_frequencies = + get_hash_frequencies(&against_merged_mh, Some(Normalization::L1)); + eprintln!("\tDone.\n"); + + eprintln!("Computing frequency of hashvals across all queries (L1 Norm) ..."); + let query_merged_frequencies = + get_hash_frequencies(&queries_merged_mh, Some(Normalization::L1)); + eprintln!("\tDone.\n"); + + // Compute term frequencies + eprintln!("Computing hashval term frequencies within each query (L2 Norm) ..."); + let query_term_frequencies = HashMap::from( + queries + .par_iter() + .map(|query| { + ( + query.md5sum.clone(), + get_hash_frequencies(&query.minhash, Some(Normalization::L2)), + ) + }) + .collect::>>(), + ); + eprintln!("\tDone.\n"); + + ( + n_comparisons, + query_merged_frequencies, + against_merged_frequencies, + query_term_frequencies, + inverse_document_frequency, + ) +} + /// Search many queries against a list of signatures. /// /// Note: this function loads all _queries_ into memory, and iterates over @@ -21,6 +143,7 @@ pub fn multisearch( selection: Selection, allow_failed_sigpaths: bool, estimate_ani: bool, + estimate_prob_overlap: bool, output: Option, ) -> Result<(), Box> { // Load all queries into memory at once. @@ -48,7 +171,7 @@ pub fn multisearch( let mut new_selection = selection; new_selection.set_scaled(expected_scaled); - let queries = query_collection.load_sketches(&new_selection)?; + let queries: Vec = query_collection.load_sketches(&new_selection)?; // Load all against sketches into memory at once. 
let against_collection = load_collection( @@ -58,7 +181,25 @@ pub fn multisearch( allow_failed_sigpaths, )?; - let against = against_collection.load_sketches(&new_selection)?; + let againsts: Vec = against_collection.load_sketches(&new_selection)?; + + let ( + n_comparisons, + query_merged_frequencies, + against_merged_frequencies, + query_term_frequencies, + inverse_document_frequency, + ) = if estimate_prob_overlap { + compute_prob_overlap_stats(&queries, &againsts) + } else { + ( + 0.0, + Default::default(), + Default::default(), + Default::default(), + Default::default(), + ) + }; // set up a multi-producer, single-consumer channel. let (send, recv) = @@ -75,7 +216,7 @@ pub fn multisearch( let processed_cmp = AtomicUsize::new(0); - let send = against + let send = againsts .par_iter() .filter_map(|against| { let mut results = vec![]; @@ -115,6 +256,30 @@ pub fn multisearch( let mut match_containment_ani = None; let mut average_containment_ani = None; let mut max_containment_ani = None; + let mut prob_overlap: Option = None; + let mut prob_overlap_adjusted: Option = None; + let mut containment_adjusted: Option = None; + let mut containment_adjusted_log10: Option = None; + let mut tf_idf_score: Option = None; + + // Compute probability overlap stats if requested + if estimate_prob_overlap { + let prob_stats = compute_single_prob_overlap( + query, + against, + n_comparisons, + &query_merged_frequencies, + &against_merged_frequencies, + &query_term_frequencies, + &inverse_document_frequency, + containment_query_in_target, + ); + prob_overlap = Some(prob_stats.prob_overlap); + prob_overlap_adjusted = Some(prob_stats.prob_overlap_adjusted); + containment_adjusted = Some(prob_stats.containment_adjusted); + containment_adjusted_log10 = Some(prob_stats.containment_adjusted_log10); + tf_idf_score = Some(prob_stats.tf_idf_score); + } // estimate ANI values if estimate_ani { @@ -142,6 +307,11 @@ pub fn multisearch( match_containment_ani, average_containment_ani, 
max_containment_ani, + prob_overlap: prob_overlap, + prob_overlap_adjusted: prob_overlap_adjusted, + containment_adjusted: containment_adjusted, + containment_adjusted_log10: containment_adjusted_log10, + tf_idf_score: tf_idf_score, }) } } diff --git a/src/pairwise.rs b/src/pairwise.rs index 4bcaf979..ae1f0c44 100644 --- a/src/pairwise.rs +++ b/src/pairwise.rs @@ -80,6 +80,12 @@ pub fn pairwise( let containment_q1_in_q2 = overlap / query1_size; let containment_q2_in_q1 = overlap / query2_size; + let prob_overlap = None; + let prob_overlap_adjusted = None; + let containment_adjusted = None; + let containment_adjusted_log10 = None; + let tf_idf_score = None; + if containment_q1_in_q2 > threshold || containment_q2_in_q1 > threshold { let max_containment = containment_q1_in_q2.max(containment_q2_in_q1); let jaccard = overlap / (query1_size + query2_size - overlap); @@ -113,6 +119,11 @@ pub fn pairwise( match_containment_ani, average_containment_ani, max_containment_ani, + prob_overlap, + prob_overlap_adjusted, + containment_adjusted, + containment_adjusted_log10, + tf_idf_score, }) .unwrap(); } @@ -127,6 +138,11 @@ pub fn pairwise( let mut match_containment_ani = None; let mut average_containment_ani = None; let mut max_containment_ani = None; + let prob_overlap = None; + let prob_overlap_adjusted = None; + let containment_adjusted = None; + let containment_adjusted_log10 = None; + let tf_idf_score = None; if estimate_ani { query_containment_ani = Some(1.0); @@ -151,6 +167,11 @@ pub fn pairwise( match_containment_ani, average_containment_ani, max_containment_ani, + prob_overlap, + prob_overlap_adjusted, + containment_adjusted, + containment_adjusted_log10, + tf_idf_score, }) .unwrap(); } diff --git a/src/python/sourmash_plugin_branchwater/__init__.py b/src/python/sourmash_plugin_branchwater/__init__.py index 10d025d1..597f165f 100755 --- a/src/python/sourmash_plugin_branchwater/__init__.py +++ b/src/python/sourmash_plugin_branchwater/__init__.py @@ -456,6 +456,12 @@ 
def __init__(self, p): p.add_argument( "-a", "--ani", action="store_true", help="estimate ANI from containment" ) + p.add_argument( + "-p", + "--prob-significant-overlap", + action="store_true", + help="estimate probability of overlap for significance ranking of search results, of the specific query and match, given all queries and possible matches", + ) def main(self, args): print_version() @@ -468,6 +474,9 @@ def main(self, args): notify( f"searching all sketches in '{args.query_paths}' against '{args.against_paths}' using {num_threads} threads" ) + notify( + f"estimate ani? {args.ani} / estimate probability of overlap? {args.prob_significant_overlap}" + ) super().main(args) status = sourmash_plugin_branchwater.do_multisearch( @@ -478,6 +487,7 @@ def main(self, args): args.scaled, args.moltype, args.ani, + args.prob_significant_overlap, args.output, ) if status == 0: diff --git a/src/python/tests/test-data/snap25.protein.k5.sig b/src/python/tests/test-data/snap25.protein.k5.sig new file mode 100644 index 00000000..7b3e2b15 --- /dev/null +++ b/src/python/tests/test-data/snap25.protein.k5.sig @@ -0,0 +1 @@ +[{"class":"sourmash_signature","email":"","hash_function":"0.murmur64","filename":"snap25_isoforms.fa.gz","name":"sp|P60880|SNP25_HUMAN Synaptosomal-associated protein 25 OS=Homo sapiens OX=9606 GN=SNAP25 PE=1 
SV=1","license":"CC0","signatures":[{"num":0,"ksize":15,"seed":42,"max_hash":18446744073709551615,"mins":[79346193818611200,168240307515063181,357295714612793700,641852037324162118,757825252518801166,761754490623839199,782972479557803271,1134440823624009328,1238331059033764487,1254610907380596406,1309336592694834642,1488523118727361571,1544082811812416977,1568253699990021557,1658984334600444418,1709274749294532135,1725715057422578614,1748266176848414158,1794485572705419172,2005207447504439670,2093296111509890469,2156245377329835554,2307727469781016498,2333333180688968175,2506013222061606991,2511561125746392744,2718007339939391753,2731354710303524664,3025951993260798495,3075042818450726583,3222282998850823207,3251781867611681496,3285712925441827798,3372160289121841704,3410764471830675569,3469683582533228427,3568179583144872890,3817411180268475919,3832866563906945973,3926438270827644511,3972496174422951309,4198127979842701307,4229135236175220105,4563954488897712676,4574387402656656788,4858761604483042588,4861078395450885090,4862453625296885659,4938768157140792348,4940648565563262316,5025891486612838557,5284728128538357471,5291748543886084200,5312773244129532133,5489682938535252219,5565681234780460509,5597785333425413357,5749683527313916829,5763746750876013840,5801120205042806309,5930088150709737320,5966797875517575437,6054790317481672279,6064819107459850094,6124424333141656815,6147437215035628965,6177403700354481799,6192662239929692478,6277807665402207951,6416731679946736855,6568722061559040921,6574637067156095585,6708596540464185825,6761203704526339578,7133206769024413920,7176171771918367811,7376549987227901757,7414787657647879690,7476142955102908158,7550327893577082883,7792856864787756566,7872443909294406405,7929025846580547475,7936959560137202483,7961647123264386971,8039152728421069661,8064791593271213771,8348568630857036658,8359508351319989063,8361982179542468995,8393526643130404355,8463700396331089193,8478282771070816787,8660828607838927667,8993315981493348921,91
83836592097640600,9226634884850375914,9245338018822560940,9368011954310053617,9418794578254213195,9442313346766117567,9460327337755655906,9463651552479664171,9650476130928207101,9852726890362380853,10020588535310946764,10231023708369489377,10417565345044462094,10547029214113848348,10589909858893293689,10716410829307121909,10999634792518244156,11061757285087944115,11117134124993193199,11131793045660178737,11190878110662832075,11266832837962751311,11291079305691832958,11333187768475619981,11483151272223137978,11672112649106120622,11679914994728284823,11681355222208589932,11810037984856334441,11863211896355109725,11936280568784502094,11990930873863737737,12056248353048495671,12117448516765825038,12135852999721843698,12237959570443212233,12496527595010127170,12537094089821085690,12545446077839030519,12644834334436341548,12719023208607745021,12733227728333486742,12871510040491615692,12985319509726495005,12997653467499793320,13144928240150978025,13394961343841494301,13459434389575211230,13461519825068144253,13544960794008603327,13629118005878225239,13671095274392390639,13695593196674902553,13865585724736991815,13916093450396042927,13945786155201993794,14043845928073073963,14120889159756865474,14241142514783205194,14531377358346009042,14889711942178191148,14933180350534022941,14940752760381923340,15015566227947745405,15249672574439659477,15278809512327941563,15307746890888995231,15504020558113784426,15582483944700946616,15586820216875746554,15594166437659914988,15605004559893578169,15625436695734261243,15633400670318138036,15795014318937641345,15810460302233344444,15995428954747916472,16076768208013694430,16195191230604588860,16199408572444659404,16214925939897936789,16251075642319005348,16313134541005961893,16350957978523201308,16541465940433933615,16592320992795060376,16698620512247263592,16816432235009607511,16941959214187235640,16946773118975929739,16988519134430486231,17033858036098508796,17047008984976409532,17058264154905588169,17064170205635237550,17113856912873235
068,17119298316217558973,17621151508648857175,17641713836505334678,17695103044946679095,17751471400271654090,17901334208844605936,17970119671529042692,17974861901755320079,18052376351337125503,18054317455131746015,18150884563146655585],"md5sum":"0324da9db9490bbeacf556679bfeb7d1","molecule":"protein"}],"version":0.4},{"class":"sourmash_signature","email":"","hash_function":"0.murmur64","filename":"snap25_isoforms.fa.gz","name":"sp|P60880-2|SNP25_HUMAN Isoform 2 of Synaptosomal-associated protein 25 OS=Homo sapiens OX=9606 GN=SNAP25","license":"CC0","signatures":[{"num":0,"ksize":15,"seed":42,"max_hash":18446744073709551615,"mins":[79346193818611200,168240307515063181,357295714612793700,641852037324162118,757825252518801166,761754490623839199,854504706134318900,1134440823624009328,1238331059033764487,1309336592694834642,1488523118727361571,1709274749294532135,1725715057422578614,1748266176848414158,1794485572705419172,2005207447504439670,2024817690031712741,2093296111509890469,2156245377329835554,2307727469781016498,2333333180688968175,2511561125746392744,2718007339939391753,2731354710303524664,2871204148701307313,2909459394036752481,3025951993260798495,3075042818450726583,3222282998850823207,3251781867611681496,3285712925441827798,3410764471830675569,3568179583144872890,3817411180268475919,3926438270827644511,3955021198924481793,3972496174422951309,4131349106154158105,4198127979842701307,4229135236175220105,4563954488897712676,4574387402656656788,4858761604483042588,4861078395450885090,4862453625296885659,4938768157140792348,5025891486612838557,5263235457522296298,5284728128538357471,5291748543886084200,5312773244129532133,5327179019825988353,5565681234780460509,5597785333425413357,5749683527313916829,5801120205042806309,5930088150709737320,6124424333141656815,6147437215035628965,6177403700354481799,6192662239929692478,6277807665402207951,6416731679946736855,6574637067156095585,6708596540464185825,6761203704526339578,7029529918800785203,7080555158824450960,7133206769
024413920,7158061106260034664,7176171771918367811,7376549987227901757,7414787657647879690,7476142955102908158,7550327893577082883,7696557663341238599,7792362671970898878,7792856864787756566,7872443909294406405,7929025846580547475,7936959560137202483,7961647123264386971,7978816125245756653,8039152728421069661,8064791593271213771,8348568630857036658,8359508351319989063,8361982179542468995,8383343990235193930,8393526643130404355,8463700396331089193,8478282771070816787,8660828607838927667,8981756865334327930,8993315981493348921,9040729209501707076,9183836592097640600,9226634884850375914,9245338018822560940,9368011954310053617,9418794578254213195,9442313346766117567,9460327337755655906,9650476130928207101,9852726890362380853,10020588535310946764,10124878224596917315,10231023708369489377,10547029214113848348,10589909858893293689,10716410829307121909,10795830229704962545,10999634792518244156,11000921393843515695,11117134124993193199,11131793045660178737,11266832837962751311,11291079305691832958,11333187768475619981,11483151272223137978,11630022289020613755,11672112649106120622,11679914994728284823,11681355222208589932,11810037984856334441,11863211896355109725,11990930873863737737,12117448516765825038,12135852999721843698,12237959570443212233,12496527595010127170,12537094089821085690,12545446077839030519,12644834334436341548,12719023208607745021,12733227728333486742,12871510040491615692,12985319509726495005,12997653467499793320,13015784697162252300,13137961283196836927,13144928240150978025,13394961343841494301,13459434389575211230,13461519825068144253,13544960794008603327,13559428842285846742,13629118005878225239,13671095274392390639,13695593196674902553,13865585724736991815,13916093450396042927,14043845928073073963,14120889159756865474,14241142514783205194,14531377358346009042,14648959863108810383,14798958964394729107,14889711942178191148,14933180350534022941,14940752760381923340,15015566227947745405,15249672574439659477,15278809512327941563,15307746890888995231,1558248394
4700946616,15586820216875746554,15594166437659914988,15605004559893578169,15625436695734261243,15633400670318138036,15743546848986775135,15795014318937641345,15810460302233344444,15861183366162130000,16118309445659596513,16195191230604588860,16199408572444659404,16214925939897936789,16313134541005961893,16350957978523201308,16541465940433933615,16592320992795060376,16816432235009607511,16941959214187235640,16946773118975929739,17033858036098508796,17047008984976409532,17058264154905588169,17064170205635237550,17621151508648857175,17641713836505334678,17695103044946679095,17751471400271654090,17901334208844605936,17970119671529042692,17974861901755320079,18008872121808310705,18052376351337125503,18054317455131746015,18150884563146655585,18170310193389148668],"md5sum":"cff76481ff27f6da26ccc726309bf3dc","molecule":"protein"}],"version":0.4},{"class":"sourmash_signature","email":"","hash_function":"0.murmur64","filename":"snap25a_mxe_exon_human.part_001.fa.gz","name":"snap25a_mxe_exon_human","license":"CC0","signatures":[{"num":0,"ksize":15,"seed":42,"max_hash":18446744073709551615,"mins":[79346193818611200,854504706134318900,2005207447504439670,2024817690031712741,2871204148701307313,2909459394036752481,3817411180268475919,3955021198924481793,5263235457522296298,5312773244129532133,5327179019825988353,6177403700354481799,7029529918800785203,7158061106260034664,7696557663341238599,7792362671970898878,8383343990235193930,9040729209501707076,10124878224596917315,10795830229704962545,11000921393843515695,11630022289020613755,13015784697162252300,13137961283196836927,13559428842285846742,14648959863108810383,18008872121808310705,18170310193389148668],"md5sum":"e0c1adac94c5d573c5e4fca31c02e843","abundances":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"molecule":"protein"}],"version":0.4},{"class":"sourmash_signature","email":"","hash_function":"0.murmur64","filename":"snap25b_mxe_exon_human.part_001.fa.gz","name":"snap25b_mxe_exon_human","license":"CC0","signat
ures":[{"num":0,"ksize":15,"seed":42,"max_hash":18446744073709551615,"mins":[79346193818611200,782972479557803271,1254610907380596406,1544082811812416977,1658984334600444418,2005207447504439670,2506013222061606991,3372160289121841704,3817411180268475919,3832866563906945973,4940648565563262316,5312773244129532133,5966797875517575437,6054790317481672279,6064819107459850094,6177403700354481799,6568722061559040921,9463651552479664171,10417565345044462094,11190878110662832075,11936280568784502094,12056248353048495671,13945786155201993794,15504020558113784426,15995428954747916472,16076768208013694430,16988519134430486231,17119298316217558973],"md5sum":"86822d529534d380abc86b36449b2c85","abundances":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"molecule":"protein"}],"version":0.4}] \ No newline at end of file diff --git a/src/python/tests/test_multisearch.py b/src/python/tests/test_multisearch.py index b6bddbfd..dfc65ee2 100644 --- a/src/python/tests/test_multisearch.py +++ b/src/python/tests/test_multisearch.py @@ -13,6 +13,10 @@ ) +def float_round(string: str, ndigits=None): + return round(float(string), ndigits) + + def test_installed(runtmp): with pytest.raises(utils.SourmashCommandFailed): runtmp.sourmash("scripts", "multisearch") @@ -49,6 +53,8 @@ def test_simple_no_ani(runtmp, zip_query, zip_db): print(dd) for idx, row in dd.items(): + assert not ("prob_overlap" in row) + # identical? 
if row["match_name"] == row["query_name"]: assert row["query_md5"] == row["match_md5"], row @@ -64,14 +70,12 @@ def test_simple_no_ani(runtmp, zip_query, zip_db): # confirm hand-checked numbers q = row["query_name"].split()[0] m = row["match_name"].split()[0] - cont = float(row["containment"]) - jaccard = float(row["jaccard"]) - maxcont = float(row["max_containment"]) + cont = float_round(row["containment"], 4) + jaccard = float_round(row["jaccard"], 4) + maxcont = float_round(row["max_containment"], 4) + intersect_hashes = int(row["intersect_hashes"]) - jaccard = round(jaccard, 4) - cont = round(cont, 4) - maxcont = round(maxcont, 4) print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}") if q == "NC_011665.1" and m == "NC_009661.1": @@ -87,7 +91,7 @@ def test_simple_no_ani(runtmp, zip_query, zip_db): assert intersect_hashes == 2529 -def test_simple_ani(runtmp, zip_query, zip_db, indexed_query, indexed_against): +def test_simple_prob_overlap(runtmp, zip_query, zip_db, indexed_query, indexed_against): # test basic execution! query_list = runtmp.output("query.txt") against_list = runtmp.output("against.txt") @@ -113,6 +117,102 @@ def test_simple_ani(runtmp, zip_query, zip_db, indexed_query, indexed_against): if indexed_against: against_list = index_siglist(runtmp, against_list, runtmp.output("db")) + runtmp.sourmash( + "scripts", "multisearch", query_list, against_list, "-o", output, "--prob" + ) + assert os.path.exists(output) + + df = pandas.read_csv(output) + assert len(df) == 5 + + dd = df.to_dict(orient="index") + print(dd) + + for idx, row in dd.items(): + # identical? 
+ if row["match_name"] == row["query_name"]: + assert row["query_md5"] == row["match_md5"], row + assert float(row["containment"] == 1.0) + assert float(row["jaccard"] == 1.0) + assert float(row["max_containment"] == 1.0) + assert "query_containment_ani" not in row + assert "match_containment_ani" not in row + assert "average_containment_ani" not in row + assert "max_containment_ani" not in row + + if row["match_name"] == "NC_011665.1": + assert float_round(row["prob_overlap"], 7) == 4.67e-05 + assert float_round(row["prob_overlap_adjusted"], 7) == 0.0004206 + assert float_round(row["containment_adjusted"], 4) == 2377.5947 + assert float_round(row["containment_adjusted_log10"], 4) == 2377.5947 + assert float_round(row["tf_idf_score"], 4) == 1.4974 + + else: + # confirm hand-checked numbers + q = row["query_name"].split()[0] + m = row["match_name"].split()[0] + cont = float_round(row["containment"], 4) + jaccard = float_round(row["jaccard"], 4) + maxcont = float_round(row["max_containment"], 4) + prob_overlap = float_round(row["prob_overlap"], 7) + prob_overlap_adjusted = float_round(row["prob_overlap_adjusted"], 7) + containment_adjusted = float_round(row["containment_adjusted"], 4) + containment_adjusted_log10 = float_round( + row["containment_adjusted_log10"], 4 + ) + tf_idf_score = float_round(row["tf_idf_score"], 4) + intersect_hashes = int(row["intersect_hashes"]) + + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}") + + if q == "NC_011665.1" and m == "NC_009661.1": + assert jaccard == 0.3207 + assert cont == 0.4828 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + assert prob_overlap == 2.26e-05 + assert prob_overlap_adjusted == 0.0002031 + assert containment_adjusted == 2377.5947 + assert containment_adjusted_log10 == 3.3761 + assert tf_idf_score == 0.6217 + + if q == "NC_009661.1" and m == "NC_011665.1": + assert jaccard == 0.3207 + assert cont == 0.4885 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + assert 
prob_overlap == 2.26e-05 + assert prob_overlap_adjusted == 0.0002031 + assert containment_adjusted == 2405.6096 + assert containment_adjusted_log10 == 3.3812 + assert tf_idf_score == 0.6290 + + +def test_simple_ani(runtmp, zip_query, zip_db, indexed_query, indexed_against): + # test basic execution! + query_list = runtmp.output("query.txt") + against_list = runtmp.output("against.txt") + + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + + make_file_list(query_list, [sig2, sig47, sig63]) + make_file_list(against_list, [sig2, sig47, sig63]) + + output = runtmp.output("out.csv") + + if zip_db: + against_list = zip_siglist(runtmp, against_list, runtmp.output("db.zip")) + if zip_query: + query_list = zip_siglist(runtmp, query_list, runtmp.output("query.zip")) + + if indexed_query: + query_list = index_siglist(runtmp, query_list, runtmp.output("q_db")) + + if indexed_against: + against_list = index_siglist(runtmp, against_list, runtmp.output("db")) + runtmp.sourmash( "scripts", "multisearch", query_list, against_list, "-o", output, "--ani" ) @@ -125,6 +225,8 @@ def test_simple_ani(runtmp, zip_query, zip_db, indexed_query, indexed_against): print(dd) for idx, row in dd.items(): + assert not ("prob_overlap" in row) + # identical? 
if row["match_name"] == row["query_name"]: assert row["query_md5"] == row["match_md5"], row @@ -140,22 +242,15 @@ def test_simple_ani(runtmp, zip_query, zip_db, indexed_query, indexed_against): # confirm hand-checked numbers q = row["query_name"].split()[0] m = row["match_name"].split()[0] - cont = float(row["containment"]) - jaccard = float(row["jaccard"]) - maxcont = float(row["max_containment"]) + cont = float_round(row["containment"], 4) + jaccard = float_round(row["jaccard"], 4) + maxcont = float_round(row["max_containment"], 4) intersect_hashes = int(row["intersect_hashes"]) - q1_ani = float(row["query_containment_ani"]) - q2_ani = float(row["match_containment_ani"]) - avg_ani = float(row["average_containment_ani"]) - max_ani = float(row["max_containment_ani"]) + q1_ani = float_round(row["query_containment_ani"], 4) + q2_ani = float_round(row["match_containment_ani"], 4) + avg_ani = float_round(row["average_containment_ani"], 4) + max_ani = float_round(row["max_containment_ani"], 4) - jaccard = round(jaccard, 4) - cont = round(cont, 4) - maxcont = round(maxcont, 4) - q1_ani = round(q1_ani, 4) - q2_ani = round(q2_ani, 4) - avg_ani = round(avg_ani, 4) - max_ani = round(max_ani, 4) print( q, m, @@ -660,6 +755,7 @@ def test_empty_query(runtmp, capfd): captured = capfd.readouterr() print(captured.err) assert "No query signatures loaded, exiting." in captured.err + # @CTB def test_nomatch_query_warn(runtmp, capfd, zip_query): @@ -879,6 +975,9 @@ def test_simple_prot(runtmp): print(dd) for idx, row in dd.items(): + # Make sure prob_overlap is only run when requested + assert not ("prob_overlap" in row) + # identical? 
if row["match_name"] == row["query_name"]: assert row["query_md5"] == row["match_md5"], row @@ -894,22 +993,15 @@ def test_simple_prot(runtmp): # confirm hand-checked numbers q = row["query_name"].split()[0] m = row["match_name"].split()[0] - cont = float(row["containment"]) - jaccard = float(row["jaccard"]) - maxcont = float(row["max_containment"]) + cont = float_round(row["containment"], 4) + jaccard = float_round(row["jaccard"], 4) + maxcont = float_round(row["max_containment"], 4) intersect_hashes = int(row["intersect_hashes"]) - q1_ani = float(row["query_containment_ani"]) - q2_ani = float(row["match_containment_ani"]) - avg_ani = float(row["average_containment_ani"]) - max_ani = float(row["max_containment_ani"]) + q1_ani = float_round(row["query_containment_ani"], 4) + q2_ani = float_round(row["match_containment_ani"], 4) + avg_ani = float_round(row["average_containment_ani"], 4) + max_ani = float_round(row["max_containment_ani"], 4) - jaccard = round(jaccard, 4) - cont = round(cont, 4) - maxcont = round(maxcont, 4) - q1_ani = round(q1_ani, 4) - q2_ani = round(q2_ani, 4) - avg_ani = round(avg_ani, 4) - max_ani = round(max_ani, 4) print( q, m, @@ -944,6 +1036,133 @@ def test_simple_prot(runtmp): assert max_ani == 0.886 +def test_prob_overlap_prot_with_abundance(runtmp): + # test basic execution with protein sigs + sigs = get_test_data("snap25.protein.k5.sig") + + output = runtmp.output("out.csv") + + runtmp.sourmash( + "scripts", + "multisearch", + sigs, + sigs, + "-o", + output, + "--moltype", + "protein", + "-k", + "5", + "--scaled", + "1", + "--prob", + ) + assert os.path.exists(output) + + df = pandas.read_csv(output) + assert len(df) == 16 + + dd = df.to_dict(orient="index") + print(dd) + + for idx, row in dd.items(): + # identical? 
+ if row["match_name"] == row["query_name"]: + assert row["query_md5"] == row["match_md5"], row + assert float(row["containment"] == 1.0) + assert float(row["jaccard"] == 1.0) + assert float(row["max_containment"] == 1.0) + + else: + # confirm hand-checked numbers + q = row["query_name"].split()[0] + m = row["match_name"].split()[0] + cont = float_round(row["containment"], 4) + jaccard = float_round(row["jaccard"], 4) + maxcont = float_round(row["max_containment"], 4) + intersect_hashes = int(row["intersect_hashes"]) + prob_overlap = float_round(row["prob_overlap"], 8) + prob_overlap_adjusted = float_round(row["prob_overlap_adjusted"], 8) + containment_adjusted = float_round(row["containment_adjusted"], 4) + containment_adjusted_log10 = float_round( + row["containment_adjusted_log10"], 4 + ) + tf_idf_score = float_round(row["tf_idf_score"], 4) + + print( + q, + m, + f"{jaccard:.04}", + f"{cont:.04}", + f"{maxcont:.04}", + intersect_hashes, + f"{prob_overlap:.04}", + f"{prob_overlap_adjusted:.04}", + f"{containment_adjusted:.04}", + f"{containment_adjusted_log10:.04}", + f"{tf_idf_score:.04}", + ) + + if q == "snap25a_mxe_exon_human" and m == "snap25b_mxe_exon_human": + assert jaccard == 0.098 + assert cont == 0.1786 + assert maxcont == 0.1786 + assert intersect_hashes == 5 + assert prob_overlap == 9.21e-05 + assert prob_overlap_adjusted == 0.0014736 + assert containment_adjusted == 121.1808 + assert containment_adjusted_log10 == 2.0834 + assert tf_idf_score == 0.1786 + + if q == "snap25b_mxe_exon_human" and m == "snap25a_mxe_exon_human": + # Check that inverse for snap25b vs snap25a exon is true + assert jaccard == 0.098 + assert cont == 0.1786 + assert maxcont == 0.1786 + assert intersect_hashes == 5 + assert prob_overlap == 9.21e-05 + assert prob_overlap_adjusted == 0.0014736 + assert containment_adjusted == 121.1808 + assert containment_adjusted_log10 == 2.0834 + assert tf_idf_score == 0.1786 + + if q == "snap25a_mxe_exon_human" and m == 
"sp|P60880-2|SNP25_HUMAN": + # P60880-2 is isoform SNAP25A, including the snap25a exon + assert jaccard == 0.1386 + assert cont == 1.0 + assert maxcont == 1.0 + assert intersect_hashes == 28 + assert prob_overlap == 0.00051576 + assert prob_overlap_adjusted == 0.00825213 + assert containment_adjusted == 121.1808 + assert containment_adjusted_log10 == 2.0834 + assert tf_idf_score == 1.4196 + + if q == "snap25b_mxe_exon_human" and m == "sp|P60880|SNP25_HUMAN": + # P60880 is isoform SNAP25B, including the snap25b exon + assert jaccard == 0.1386 + assert cont == 1.0 + assert maxcont == 1.0 + assert intersect_hashes == 28 + assert prob_overlap == 0.00051576 + assert prob_overlap_adjusted == 0.00825213 + assert containment_adjusted == 121.1808 + assert containment_adjusted_log10 == 2.0834 + assert tf_idf_score == 1.4196 + + if q == "sp|P60880-2|SNP25_HUMAN" and m == "sp|P60880|SNP25_HUMAN": + # P60880 is isoform SNAP25B, including the snap25b exon + assert jaccard == 0.7339 + assert cont == 0.8465 + assert maxcont == 0.8465 + assert intersect_hashes == 171 + assert prob_overlap == 0.00314981 + assert prob_overlap_adjusted == 0.05039695 + assert containment_adjusted == 16.7973 + assert containment_adjusted_log10 == 1.2252 + assert tf_idf_score == 1.2663 + + def test_simple_dayhoff(runtmp): # test basic execution with dayhoff sigs sigs = get_test_data("dayhoff.zip") @@ -989,22 +1208,15 @@ def test_simple_dayhoff(runtmp): # confirm hand-checked numbers q = row["query_name"].split()[0] m = row["match_name"].split()[0] - cont = float(row["containment"]) - jaccard = float(row["jaccard"]) - maxcont = float(row["max_containment"]) + cont = float_round(row["containment"], 4) + jaccard = float_round(row["jaccard"], 4) + maxcont = float_round(row["max_containment"], 4) intersect_hashes = int(row["intersect_hashes"]) - q1_ani = float(row["query_containment_ani"]) - q2_ani = float(row["match_containment_ani"]) - avg_ani = float(row["average_containment_ani"]) - max_ani = 
float(row["max_containment_ani"]) + q1_ani = float_round(row["query_containment_ani"], 4) + q2_ani = float_round(row["match_containment_ani"], 4) + avg_ani = float_round(row["average_containment_ani"], 4) + max_ani = float_round(row["max_containment_ani"], 4) - jaccard = round(jaccard, 4) - cont = round(cont, 4) - maxcont = round(maxcont, 4) - q1_ani = round(q1_ani, 4) - q2_ani = round(q2_ani, 4) - avg_ani = round(avg_ani, 4) - max_ani = round(max_ani, 4) print( q, m, @@ -1084,22 +1296,15 @@ def test_simple_hp(runtmp): # confirm hand-checked numbers q = row["query_name"].split()[0] m = row["match_name"].split()[0] - cont = float(row["containment"]) - jaccard = float(row["jaccard"]) - maxcont = float(row["max_containment"]) + cont = float_round(row["containment"], 4) + jaccard = float_round(row["jaccard"], 4) + maxcont = float_round(row["max_containment"], 4) intersect_hashes = int(row["intersect_hashes"]) - q1_ani = float(row["query_containment_ani"]) - q2_ani = float(row["match_containment_ani"]) - avg_ani = float(row["average_containment_ani"]) - max_ani = float(row["max_containment_ani"]) + q1_ani = float_round(row["query_containment_ani"], 4) + q2_ani = float_round(row["match_containment_ani"], 4) + avg_ani = float_round(row["average_containment_ani"], 4) + max_ani = float_round(row["max_containment_ani"], 4) - jaccard = round(jaccard, 4) - cont = round(cont, 4) - maxcont = round(maxcont, 4) - q1_ani = round(q1_ani, 4) - q2_ani = round(q2_ani, 4) - avg_ani = round(avg_ani, 4) - max_ani = round(max_ani, 4) print( q, m, diff --git a/src/python/tests/test_sketch.py b/src/python/tests/test_sketch.py index 3c610a56..b8cf4e23 100644 --- a/src/python/tests/test_sketch.py +++ b/src/python/tests/test_sketch.py @@ -961,7 +961,7 @@ def test_manysketch_prefix2(runtmp, capfd): dna_prefix = os.path.join( fa_path, "short" ) # need to avoid matching short-protein.fa - prot_prefix = os.path.join(fa_path, "*protein") + prot_prefix = os.path.join(fa_path, "*protein.fa") 
zip_exclude = os.path.join(fa_path, "*zip") # make prefix input file @@ -1042,6 +1042,7 @@ def test_manysketch_prefix2(runtmp, capfd): for sig in sigs: assert sig.name in expected_signames if sig.name == "short": + # minhash is not defined? How does this test work? - @olgabot assert sig, minhash.hashes == sig1.minhash.hashes if sig.name == "short_protein": assert sig == sig2 @@ -1118,7 +1119,7 @@ def test_manysketch_prefix_duplicated_force(runtmp, capfd): dna_prefix = os.path.join( fa_path, "short" ) # need to avoid matching short-protein.fa - prot_prefix = os.path.join(fa_path, "*protein") + prot_prefix = os.path.join(fa_path, "*protein*fa") zip_exclude = os.path.join(fa_path, "*zip") # make prefix input file diff --git a/src/search_significance.rs b/src/search_significance.rs new file mode 100644 index 00000000..275fb8c9 --- /dev/null +++ b/src/search_significance.rs @@ -0,0 +1,210 @@ +// Functions to compute statistical significance of search results + +use rayon::prelude::*; + +use crate::utils::multicollection::SmallSignature; +use sourmash::signature::SigsTrait; +use sourmash::sketch::minhash::KmerMinHash; +use sourmash::Error; +use std::collections::{HashMap, HashSet}; +use std::fmt::{self, Display, Formatter}; + +pub enum Normalization { + // L1 norm is the equivalent of frequencies/probabilities, as the counts + // are divided by the length of the Vec<> object, or mathematically, the + // number of items in the Vec object, assuming unit length for each + L1, + + // L2 norm divides the counts by the sum of squares of all counts + // L2 norm is the Euclidean distance of the counts in N-dimensional vector space + // When track_abundance=False, L1 and L2 norms are equivalent, since all "counts" + // are 1, and even if you square it, 1^2 = 1 + L2, +} + +impl Display for Normalization { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::L1 => write!(f, "L1 Normalization"), + Self::L2 => write!(f, "L2 Normalization"), + } + } +} + +pub
fn get_hash_frequencies<'a>( + minhash: &KmerMinHash, + normalization: Option, +) -> HashMap { + let minhash_abunds: HashMap = minhash + .to_vec_abunds() + .into_par_iter() + .map(|(hashval, abund)| (hashval, abund as f64)) + .collect(); + + let abund_normalization: f64 = match normalization { + Some(Normalization::L1) => minhash.sum_abunds() as f64, + Some(Normalization::L2) => minhash_abunds + .par_iter() + .map(|(_hashval, abund)| abund * abund) + .sum::() as f64, + // TODO: this should probably be an error + _ => 0.0, + }; + + let frequencies: HashMap = HashMap::from( + minhash_abunds + .par_iter() + .map(|(hashval, abund)| + // TODO: add a match statement here to error out properly if the hashval was not found + // in the minhash_abunds for some reason (shouldn't happen but ... computers be crazy) + ( + *hashval, + abund / abund_normalization + )) + .collect::>(), + ); + + return frequencies; +} + +// #[cfg(feature = "maths")] +pub fn get_prob_overlap( + hashvals: &Vec, + query_frequencies: &HashMap, + against_frequencies: &HashMap, +) -> f64 { + // It's not guaranteed to me that the MinHashes from the query and database are in the same order, so iterate over one of them + // and use a hashmap to retrieve the frequency value of the other + let prob_overlap = hashvals + .par_iter() + .map(|hashval| query_frequencies[hashval] * against_frequencies[hashval]) + .sum(); + + return prob_overlap; +} + +// TODO: How to accept SourmashSignature objects? Signature.minhash is Option<&KmerMinHash>, +// so it's not guaranteed for a SourmashSignature to have a minhash object. Is there a way to +// only accept SourmashSignature objects that have `.minhash` present? 
+pub fn merge_all_minhashes(sigs: &Vec) -> Result { + if sigs.is_empty() { + eprintln!("Signature list is empty"); + std::process::exit(1); + } + + let first_sig = &sigs[0]; + + // Use the first signature to instantiate the merging of all minhashes + let mut combined_mh = KmerMinHash::new( + first_sig.minhash.scaled().try_into().unwrap(), + first_sig.minhash.ksize().try_into().unwrap(), + first_sig.minhash.hash_function(), + // accessing first_sig.minhash.seed is private -> hardcode instead + first_sig.minhash.seed(), + first_sig.minhash.track_abundance(), + first_sig.minhash.num(), + ); + + let hashes_with_abund: Vec<(u64, u64)> = sigs + .par_iter() + .map(|sig| sig.minhash.to_vec_abunds()) + .flatten() + .collect(); + + _ = combined_mh.add_many_with_abund(&hashes_with_abund); + + Ok(combined_mh) +} + +pub fn compute_inverse_document_frequency( + against_merged_mh: &KmerMinHash, + againsts: &Vec, + smooth_idf: Option, +) -> HashMap { + // Compute inverse document frequency (IDF) of all + // Inverse document frequency tells us how unique this hashval is to the query database + // When the value is near 0, then this hashval appears in all signatures + // When the value is very large, equal to the number of signatures, then the hashval is + // unique to a single signature + + // Total number of documents in the corpus + let n_signatures = againsts.len() as f64; + + let againsts_hashes: Vec> = againsts + .par_iter() + .map(|sig| HashSet::from_iter(sig.minhash.iter_mins())) + .collect::>>(); + + // Number of documents where hashvals appear + // hashmap of: { hashval: n_sigs_with_hashval } + let document_frequency: HashMap<&u64, f64> = HashMap::from( + against_merged_mh + .iter_mins() + .par_bridge() + .map(|hashval| { + ( + hashval, + againsts_hashes + .par_iter() + .map(|hashset| f64::from(u32::from(hashset.contains(&hashval)))) + .sum(), + ) + }) + .collect::>(), + ); + + let inverse_document_frequency: HashMap = HashMap::from( + document_frequency + .par_iter() + 
.map(|(hashval, n_sigs_with_hashval)| { + ( + **hashval, + match smooth_idf { + // Add 1 to not totally ignore terms that appear in all documents + // scikit-learn documentation (assumed to implement best practices for document classification): + // > "The effect of adding “1” to the idf in the equation above is that terms with zero idf, + // > i.e., terms that occur in all documents in a training set, will not be entirely ignored." + // Source: https://scikit-learn.org/1.5/modules/generated/sklearn.feature_extraction.text.TfidfTransformer.html + Some(true) => { + ((1.0 + n_signatures) / (1.0 + n_sigs_with_hashval)).ln() + 1.0 + } + Some(false) => (n_signatures / (n_sigs_with_hashval)).ln() + 1.0, + _ => 1.0, + }, + ) + }) + .collect::>(), + ); + + return inverse_document_frequency; +} + +pub fn get_term_frequency_inverse_document_frequency( + hashvals: &Vec, + query_term_frequencies: &HashMap, + inverse_document_frequency: &HashMap, +) -> f64 { + // Implementation of tf-idf for hashvals and signatures + // https://en.wikipedia.org/wiki/Tf%E2%80%93idf + // Square the abundances to use an L2 norm -> why? 
+ // Because this is the default setting in scikit-learn's battle-tested tf-idf methods: + // https://scikit-learn.org/1.5/modules/generated/sklearn.feature_extraction.text.TfidfTransformer.html + // https://scikit-learn.org/1.5/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html + + // Multiply each hashval's term frequency and inverse document frequency, and sum the products + let tf_idf: HashMap<&u64, f64> = HashMap::from( + hashvals + .par_iter() + .map(|hashval| { + ( + hashval, + query_term_frequencies[hashval] * inverse_document_frequency[hashval], + ) + }) + .collect::>(), + ); + + let tf_idf_score: f64 = tf_idf.values().sum(); + + return tf_idf_score; +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 2044b26f..a039b9cc 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -11,6 +11,7 @@ use camino::Utf8PathBuf as PathBuf; use csv::Writer; use glob::glob; use serde::{Deserialize, Serialize}; +// use rust_decimal::{MathematicalOps, Decimal}; use std::cmp::{Ordering, PartialOrd}; use std::collections::BinaryHeap; use std::fs::{create_dir_all, File}; @@ -1122,6 +1123,20 @@ pub struct MultiSearchResult { pub average_containment_ani: Option, #[serde(skip_serializing_if = "Option::is_none")] pub max_containment_ani: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub prob_overlap: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prob_overlap_adjusted: Option, + #[serde(skip_serializing_if = "Option::is_none")] + // max_containment / prob_overlap -> Bigger means less likely to be random + pub containment_adjusted: Option, + #[serde(skip_serializing_if = "Option::is_none")] + // logged version is easier to plot/prioritize + pub containment_adjusted_log10: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tf_idf_score: Option, } pub fn open_stdout_or_file(output: Option) -> Box {