diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..a05a706a --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +rustdocflags = ["--document-private-items"] diff --git a/Cargo.lock b/Cargo.lock index 8a17ac00..3e9bc27e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -713,9 +713,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" dependencies = [ "wasm-bindgen", ] @@ -740,9 +740,9 @@ checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" [[package]] name = "libloading" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", "windows-targets", @@ -772,9 +772,9 @@ dependencies = [ [[package]] name = "libz-sys" -version = "1.1.18" +version = "1.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c15da26e5af7e25c90b37a2d75cdbf940cf4a55316de9d84c679c9b8bfabf82e" +checksum = "fdc53a7799a7496ebc9fd29f31f7df80e83c9bda5299768af5f9e59eeea74647" dependencies = [ "cc", "pkg-config", @@ -795,9 +795,9 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lz4-sys" -version = "1.9.5" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9764018d143cc854c9f17f0b907de70f14393b1f502da6375dce70f00514eb3" +checksum = "109de74d5d2353660401699a4174a4ff23fcc649caf553df71933c7fb45ad868" dependencies = [ "cc", "libc", @@ -1551,8 +1551,7 @@ checksum = "bceb57dc07c92cdae60f5b27b3fa92ecaaa42fe36c55e22dbfb0b44893e0b1f7" 
[[package]] name = "sourmash" version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8655e639cc4a32fa1422629c9b4ff603ee09cf6d04a97eacd37594382472d437" +source = "git+https://github.com/sourmash-bio/sourmash.git?branch=more_rs_updates#affae94848a79a57b0b7cef801d41054e60458ee" dependencies = [ "az", "byteorder", @@ -1816,19 +1815,20 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", @@ -1841,9 +1841,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1851,9 +1851,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", @@ -1864,15 +1864,15 @@ dependencies = [ 
[[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 15188b44..abc937f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,8 @@ crate-type = ["cdylib"] pyo3 = { version = "0.22.2", features = ["extension-module", "anyhow"] } rayon = "1.10.0" serde = { version = "1.0.208", features = ["derive"] } -sourmash = { version = "0.15.0", features = ["branchwater"] } +sourmash = { git = "https://github.com/sourmash-bio/sourmash.git", branch = "more_rs_updates", features = ["branchwater"] } +#sourmash = { version = "0.15.0", features = ["branchwater"] } serde_json = "1.0.125" niffler = "2.4.0" log = "0.4.22" diff --git a/pyproject.toml b/pyproject.toml index 70a26945..f9571c04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,9 @@ authors = [ requires = ["maturin>=1.4.0,<2"] build-backend = "maturin" +[project.entry-points."sourmash.load_from"] +collection_reader = "sourmash_plugin_branchwater:load_collection" + [project.entry-points."sourmash.cli_script"] manysearch = "sourmash_plugin_branchwater:Branchwater_Manysearch" multisearch = "sourmash_plugin_branchwater:Branchwater_Multisearch" diff --git a/src/branch_api.rs b/src/branch_api.rs new file mode 100644 index 00000000..d89838bc --- /dev/null +++ b/src/branch_api.rs @@ -0,0 +1,165 @@ +/// Lower-level Python API implementation for sourmash_plugin_branchwater +use pyo3::prelude::*; + +use 
crate::utils::build_selection; +use crate::utils::load_collection; +use crate::utils::ReportType; +use crate::utils::multicollection::MultiCollection; +use sourmash::collection::Collection; +use sourmash::manifest::{Manifest, Record}; +use pyo3::types::{IntoPyDict, PyDict, PyList}; + +#[pyclass] +pub struct BranchRecord { + record: Record, +} + +#[pymethods] +impl BranchRecord { + pub fn get_name(&self) -> PyResult { + Ok(self.record.name().clone()) + } + + #[getter] + pub fn get_as_row<'py>(&self, py: Python<'py>) -> PyResult> { + let dict = { + let key_vals: Vec<(&str, PyObject)> = vec![ + ("ksize", self.record.ksize().to_object(py)), + ("moltype", self.record.moltype().to_string().to_object(py)), + ("scaled", self.record.scaled().to_object(py)), + ("num", self.record.num().to_object(py)), + ("with_abundance", self.record.with_abundance().to_object(py)), + ("n_hashes", self.record.n_hashes().to_object(py)), + ]; + key_vals.into_py_dict_bound(py) + }; + Ok(dict) + } +} + +/* +impl IntoPyDict for I +where + T: PyDictItem + I: IntoIterator +fn into_py_dict(self, py: Python<'_>) -> Bound<'_, PyDict> { + let dict = PyDict::new(py); + for item in self { + dict.set_item(item.key(), item.value()) + .expect("Failed to set_item on dict"); + } + dict +} +} +*/ + +#[pyclass] +pub struct BranchManifest { + manifest: Manifest, +} + +#[pymethods] +impl BranchManifest { + pub fn __len__(&self) -> PyResult { + Ok(self.manifest.len()) + } + pub fn _check_row_values(&self) -> PyResult { + Ok(true) + } + #[getter] + pub fn get_rows<'py>(&self, py: Python<'py>) -> PyResult>> { + let res: Vec<_> = self.manifest.iter().map(|x| { BranchRecord { + record: x.clone(), + }.get_as_row(py).unwrap() + }).collect(); + + Ok(res) + } +} + +#[pyclass] +pub struct BranchCollection { + #[pyo3(get)] + pub location: String, + + #[pyo3(get)] + pub is_database: bool, + + #[pyo3(get)] + pub has_manifest: bool, + + collection: MultiCollection, +} + +#[pymethods] +impl BranchCollection { + pub fn 
__len__(&self) -> PyResult { + Ok(self.collection.len()) + } + + #[getter] + pub fn get_manifest(&self) -> PyResult> { + let manifest: Manifest = self.collection.manifest().clone(); + let obj = + Python::with_gil(|py| Py::new(py, BranchManifest { manifest: manifest }).unwrap()); + Ok(obj) + } + pub fn get_first_record(&self) -> PyResult> { + let records: Vec<_> = self.collection.iter().collect(); + let first_record = records.first().unwrap().1; + + // @CTB: can I turn this into something automatic? + let obj = Python::with_gil(|py| { + Py::new( + py, + BranchRecord { + record: first_record.clone(), + }, + ) + .unwrap() + }); + Ok(obj) + } + + #[getter] + pub fn get_rows(&self) -> PyResult> { + let records: Vec<_> = self.collection.iter().collect(); + + let obj = records + .iter() + .map(|x| { + BranchRecord { + record: x.1.clone(), + } + }) + .collect(); + + // @CTB: this does the GIL grabbing as needed? + Ok(obj) + } +} + +#[pyfunction] +pub fn api_load_collection( + location: String, + ksize: u8, + scaled: usize, + moltype: String, +) -> PyResult> { + let selection = build_selection(ksize, scaled, &moltype); + + let collection = load_collection(&location, &selection, ReportType::Query, true).unwrap(); + let obj = Python::with_gil(|py| { + Py::new( + py, + BranchCollection { + location: location, + collection, + is_database: false, + has_manifest: true, + }, + ) + .unwrap() + }); + Ok(obj) +} diff --git a/src/fastgather.rs b/src/fastgather.rs index 46512025..e4271249 100644 --- a/src/fastgather.rs +++ b/src/fastgather.rs @@ -33,7 +33,7 @@ pub fn fastgather( ) } // get single query sig and minhash - let query_sig = query_collection.sig_for_dataset(0)?; // need this for original md5sum + let query_sig = query_collection.get_first_sig().unwrap(); let query_sig_ds = query_sig.clone().select(selection)?; // downsample let query_mh = match query_sig_ds.minhash() { Some(query_mh) => query_mh, diff --git a/src/fastmultigather.rs b/src/fastmultigather.rs index 
22b9efaa..07dc22d2 100644 --- a/src/fastmultigather.rs +++ b/src/fastmultigather.rs @@ -69,11 +69,11 @@ pub fn fastmultigather( let skipped_paths = AtomicUsize::new(0); let failed_paths = AtomicUsize::new(0); - query_collection.par_iter().for_each(|(_idx, record)| { + query_collection.par_iter().for_each(|(c, _idx, record)| { // increment counter of # of queries. q: could we instead use the _idx from par_iter(), or will it vary based on thread? let _i = processed_queries.fetch_add(1, atomic::Ordering::SeqCst); // Load query sig (downsampling happens here) - match query_collection.sig_from_record(record) { + match c.sig_from_record(record) { Ok(query_sig) => { let name = query_sig.name(); let prefix = name.split(' ').next().unwrap_or_default().to_string(); @@ -133,7 +133,7 @@ pub fn fastmultigather( if let Ok(mut file) = File::create(&sig_filename) { let unique_hashes: HashSet = hashes.into_iter().collect(); let mut new_mh = KmerMinHash::new( - query_mh.scaled().try_into().unwrap(), + query_mh.scaled(), query_mh.ksize().try_into().unwrap(), query_mh.hash_function().clone(), query_mh.seed(), diff --git a/src/lib.rs b/src/lib.rs index 7d623ea7..00b38f75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,14 @@ -/// Python interface Rust code for sourmash_plugin_branchwater. +//! Rust-to-Python interface code for sourmash_plugin_branchwater, using pyo3. +//! +//! If you're using Rust, you're probably most interested in
[utils](utils/index.html) + use pyo3::prelude::*; #[macro_use] extern crate simple_error; +mod branch_api; mod utils; use crate::utils::build_selection; use crate::utils::is_revindex_database; @@ -106,6 +111,7 @@ fn do_fastgather( } #[pyfunction] +#[allow(clippy::too_many_arguments)] #[pyo3(signature = (query_filenames, siglist_path, threshold_bp, ksize, scaled, moltype, output_path=None, save_matches=false))] fn do_fastmultigather( query_filenames: String, @@ -322,8 +328,11 @@ fn do_cluster( } } +/// Module interface for the `sourmash_plugin_branchwater` extension module. + #[pymodule] fn sourmash_plugin_branchwater(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { + // top level 'scripts' commands m.add_function(wrap_pyfunction!(do_manysearch, m)?)?; m.add_function(wrap_pyfunction!(do_fastgather, m)?)?; m.add_function(wrap_pyfunction!(do_fastmultigather, m)?)?; @@ -334,5 +343,10 @@ fn sourmash_plugin_branchwater(_py: Python, m: &Bound<'_, PyModule>) -> PyResult m.add_function(wrap_pyfunction!(do_multisearch, m)?)?; m.add_function(wrap_pyfunction!(do_pairwise, m)?)?; m.add_function(wrap_pyfunction!(do_cluster, m)?)?; + + // lower level API stuff + m.add_class::()?; + m.add_function(wrap_pyfunction!(branch_api::api_load_collection, m)?)?; + Ok(()) } diff --git a/src/manysearch.rs b/src/manysearch.rs index a200b52d..5a585597 100644 --- a/src/manysearch.rs +++ b/src/manysearch.rs @@ -58,7 +58,7 @@ pub fn manysearch( let send = against_collection .par_iter() - .filter_map(|(_idx, record)| { + .filter_map(|(coll, _idx, record)| { let i = processed_sigs.fetch_add(1, atomic::Ordering::SeqCst); if i % 1000 == 0 && i > 0 { eprintln!("Processed {} search sigs", i); @@ -67,7 +67,7 @@ pub fn manysearch( let mut results = vec![]; // against downsampling happens here - match against_collection.sig_from_record(record) { + match coll.sig_from_record(record) { Ok(against_sig) => { if let Some(against_mh) = against_sig.minhash() { for query in query_sketchlist.iter() { diff 
--git a/src/mastiff_manygather.rs b/src/mastiff_manygather.rs index ea99153c..eb665cb6 100644 --- a/src/mastiff_manygather.rs +++ b/src/mastiff_manygather.rs @@ -54,12 +54,12 @@ pub fn mastiff_manygather( let send = query_collection .par_iter() - .filter_map(|(_idx, record)| { + .filter_map(|(coll, _idx, record)| { let threshold = threshold_bp / selection.scaled()? as usize; let ksize = selection.ksize()?; // query downsampling happens here - match query_collection.sig_from_record(record) { + match coll.sig_from_record(record) { Ok(query_sig) => { let mut results = vec![]; if let Some(query_mh) = query_sig.minhash() { diff --git a/src/mastiff_manysearch.rs b/src/mastiff_manysearch.rs index fac364c6..dee55e53 100644 --- a/src/mastiff_manysearch.rs +++ b/src/mastiff_manysearch.rs @@ -56,7 +56,7 @@ pub fn mastiff_manysearch( let send_result = query_collection .par_iter() - .filter_map(|(_idx, record)| { + .filter_map(|(coll, _idx, record)| { let i = processed_sigs.fetch_add(1, atomic::Ordering::SeqCst); if i % 1000 == 0 && i > 0 { eprintln!("Processed {} search sigs", i); @@ -64,7 +64,7 @@ pub fn mastiff_manysearch( let mut results = vec![]; // query downsample happens here - match query_collection.sig_from_record(record) { + match coll.sig_from_record(record) { Ok(query_sig) => { if let Some(query_mh) = query_sig.minhash() { let query_size = query_mh.size(); diff --git a/src/multisearch.rs b/src/multisearch.rs index 19d2264d..17f8dfaf 100644 --- a/src/multisearch.rs +++ b/src/multisearch.rs @@ -60,6 +60,16 @@ pub fn multisearch( let processed_cmp = AtomicUsize::new(0); let ksize = selection.ksize().unwrap() as f64; + if queries.is_empty() { + eprintln!("No query sketches present. Exiting."); + return Err(anyhow::anyhow!("failed to load query sketches").into()); + } + + if against.is_empty() { + eprintln!("No search sketches present. 
Exiting."); + return Err(anyhow::anyhow!("failed to load search sketches").into()); + } + let send = against .par_iter() .filter_map(|against| { diff --git a/src/python/sourmash_plugin_branchwater/__init__.py b/src/python/sourmash_plugin_branchwater/__init__.py index 4b053159..290ed255 100755 --- a/src/python/sourmash_plugin_branchwater/__init__.py +++ b/src/python/sourmash_plugin_branchwater/__init__.py @@ -3,10 +3,12 @@ import argparse from sourmash.plugins import CommandLinePlugin from sourmash.logging import notify +from sourmash.exceptions import IndexNotLoaded import os import importlib.metadata from . import sourmash_plugin_branchwater +from . import sourmash_plugin_branchwater as api from . import prettyprint __version__ = importlib.metadata.version("sourmash_plugin_branchwater") @@ -37,6 +39,51 @@ def set_thread_pool(user_cores): return actual_rayon_cores +class BranchwaterManifestWrapper: + def __init__(self, mf_obj): + self.obj = mf_obj + + def _check_row_values(self): + return self.obj._check_row_values() + + @property + def rows(self): + return self.obj.rows + + +class BranchwaterCollectionWrapper: + def __init__(self, coll_obj): + self.obj = coll_obj + + @property + def location(self): + return self.obj.location + + @property + def is_database(self): + return self.obj.is_database + + @property + def has_manifest(self): + return self.obj.has_manifest + + @property + def manifest(self): + return BranchwaterManifestWrapper(self.obj.manifest) + + def __len__(self): + return len(self.obj) + + +def load_collection(path, *, traverse_yield_all=False, cache_size=0): + try: + coll_obj = api.api_load_collection(path, 31, 100_000, 'DNA') + return BranchwaterCollectionWrapper(coll_obj) + except: + raise IndexNotLoaded(f"branchwater could not load '{path}'") +load_collection.priority = 20 + + class Branchwater_Manysearch(CommandLinePlugin): command = 'manysearch' description = 'search many metagenomes for contained genomes' @@ -69,7 +116,6 @@ def __init__(self, p): 
def main(self, args): print_version() notify(f"ksize: {args.ksize} / scaled: {args.scaled} / moltype: {args.moltype} / threshold: {args.threshold}") - args.moltype = args.moltype.lower() num_threads = set_thread_pool(args.cores) notify(f"searching all sketches in '{args.query_paths}' against '{args.against_paths}' using {num_threads} threads") @@ -117,7 +163,6 @@ def __init__(self, p): def main(self, args): print_version() notify(f"ksize: {args.ksize} / scaled: {args.scaled} / moltype: {args.moltype} / threshold bp: {args.threshold_bp}") - args.moltype = args.moltype.lower() num_threads = set_thread_pool(args.cores) @@ -165,7 +210,6 @@ def __init__(self, p): def main(self, args): print_version() notify(f"ksize: {args.ksize} / scaled: {args.scaled} / moltype: {args.moltype} / threshold bp: {args.threshold_bp} / save matches: {args.save_matches}") - args.moltype = args.moltype.lower() num_threads = set_thread_pool(args.cores) @@ -212,7 +256,6 @@ def __init__(self, p): def main(self, args): notify(f"ksize: {args.ksize} / scaled: {args.scaled} / moltype: {args.moltype} ") - args.moltype = args.moltype.lower() num_threads = set_thread_pool(args.cores) @@ -277,7 +320,6 @@ def __init__(self, p): def main(self, args): print_version() notify(f"ksize: {args.ksize} / scaled: {args.scaled} / moltype: {args.moltype} / threshold: {args.threshold}") - args.moltype = args.moltype.lower() num_threads = set_thread_pool(args.cores) @@ -324,7 +366,6 @@ def __init__(self, p): def main(self, args): print_version() notify(f"ksize: {args.ksize} / scaled: {args.scaled} / moltype: {args.moltype} / threshold: {args.threshold}") - args.moltype = args.moltype.lower() num_threads = set_thread_pool(args.cores) diff --git a/src/python/tests/conftest.py b/src/python/tests/conftest.py index 052837f6..3f7021a1 100644 --- a/src/python/tests/conftest.py +++ b/src/python/tests/conftest.py @@ -16,6 +16,10 @@ def toggle_internal_storage(request): def zip_query(request): return request.param 
+@pytest.fixture(params=[True, False]) +def zip_db(request): + return request.param + @pytest.fixture(params=[True, False]) def zip_against(request): return request.param diff --git a/src/python/tests/sourmash_tst_utils.py b/src/python/tests/sourmash_tst_utils.py index 7c99b1b6..f4ad4927 100644 --- a/src/python/tests/sourmash_tst_utils.py +++ b/src/python/tests/sourmash_tst_utils.py @@ -14,6 +14,23 @@ from io import StringIO +def get_test_data(filename): + thisdir = os.path.dirname(__file__) + return os.path.join(thisdir, 'test-data', filename) + + +def make_file_list(filename, paths): + with open(filename, 'wt') as fp: + fp.write("\n".join(paths)) + fp.write("\n") + + +def zip_siglist(runtmp, siglist, db): + runtmp.sourmash('sig', 'cat', siglist, + '-o', db) + return db + + def scriptpath(scriptname='sourmash'): """Return the path to the scripts, in both dev and install situations.""" # note - it doesn't matter what the scriptname is here, as long as diff --git a/src/python/tests/test_cluster.py b/src/python/tests/test_cluster.py index 6e153946..4ae12173 100644 --- a/src/python/tests/test_cluster.py +++ b/src/python/tests/test_cluster.py @@ -2,15 +2,7 @@ import pytest from . import sourmash_tst_utils as utils - -def get_test_data(filename): - thisdir = os.path.dirname(__file__) - return os.path.join(thisdir, 'test-data', filename) - -def make_file_list(filename, paths): - with open(filename, 'wt') as fp: - fp.write("\n".join(paths)) - fp.write("\n") +from .sourmash_tst_utils import get_test_data, make_file_list def test_installed(runtmp): diff --git a/src/python/tests/test_gather.py b/src/python/tests/test_fastgather.py similarity index 96% rename from src/python/tests/test_gather.py rename to src/python/tests/test_fastgather.py index 4ab4c6de..90d22786 100644 --- a/src/python/tests/test_gather.py +++ b/src/python/tests/test_fastgather.py @@ -4,23 +4,7 @@ import sourmash from . 
import sourmash_tst_utils as utils - - -def get_test_data(filename): - thisdir = os.path.dirname(__file__) - return os.path.join(thisdir, 'test-data', filename) - - -def make_file_list(filename, paths): - with open(filename, 'wt') as fp: - fp.write("\n".join(paths)) - fp.write("\n") - - -def zip_siglist(runtmp, siglist, db): - runtmp.sourmash('sig', 'cat', siglist, - '-o', db) - return db +from .sourmash_tst_utils import (get_test_data, make_file_list, zip_siglist) def test_installed(runtmp): @@ -30,7 +14,6 @@ def test_installed(runtmp): assert 'usage: fastgather' in runtmp.last_result.err -@pytest.mark.parametrize('zip_against', [False, True]) def test_simple(runtmp, zip_against): # test basic execution! query = get_test_data('SRR606249.sig.gz') @@ -58,7 +41,6 @@ def test_simple(runtmp, zip_against): assert {'query_filename', 'query_name', 'query_md5', 'match_name', 'match_md5', 'gather_result_rank', 'intersect_bp'}.issubset(keys) -@pytest.mark.parametrize('zip_against', [False, True]) def test_simple_with_prefetch(runtmp, zip_against): # test basic execution! query = get_test_data('SRR606249.sig.gz') @@ -93,7 +75,6 @@ def test_simple_with_prefetch(runtmp, zip_against): assert keys == {'query_filename', 'query_name', 'query_md5', 'match_name', 'match_md5', 'intersect_bp'} -@pytest.mark.parametrize('zip_against', [False, True]) def test_missing_query(runtmp, capfd, zip_against): # test missing query query = runtmp.output('no-such-file') @@ -122,7 +103,6 @@ def test_missing_query(runtmp, capfd, zip_against): assert 'Error: No such file or directory' in captured.err -@pytest.mark.parametrize('zip_against', [False, True]) def test_bad_query(runtmp, capfd, zip_against): # test non-sig query query = runtmp.output('no-such-file') @@ -154,7 +134,6 @@ def test_bad_query(runtmp, capfd, zip_against): assert 'Error: Fastgather requires a single query sketch. 
Check input:' in captured.err -@pytest.mark.parametrize('zip_against', [False, True]) def test_missing_against(runtmp, capfd, zip_against): # test missing against query = get_test_data('SRR606249.sig.gz') @@ -278,7 +257,6 @@ def test_bad_against_3(runtmp, capfd): assert 'InvalidArchive' in captured.err -@pytest.mark.parametrize('zip_against', [False, True]) def test_against_multisigfile(runtmp, zip_against): # test against a sigfile that contains multiple sketches query = get_test_data('SRR606249.sig.gz') @@ -311,7 +289,6 @@ def test_against_multisigfile(runtmp, zip_against): # @CTB this is a bug :(. It should load multiple sketches properly! -@pytest.mark.parametrize('zip_against', [False, True]) def test_query_multisigfile(runtmp, capfd, zip_against): # test with a sigfile that contains multiple sketches against_list = runtmp.output('against.txt') @@ -341,7 +318,6 @@ def test_query_multisigfile(runtmp, capfd, zip_against): assert "Error: Fastgather requires a single query sketch. Check input:" in captured.err -@pytest.mark.parametrize('zip_against', [False, True]) def test_against_nomatch(runtmp, capfd, zip_against): # test with 'against' file containing a non-matching ksize query = get_test_data('SRR606249.sig.gz') @@ -370,7 +346,6 @@ def test_against_nomatch(runtmp, capfd, zip_against): assert 'WARNING: skipped 1 search paths - no compatible signatures.' 
in captured.err -@pytest.mark.parametrize('zip_against', [False, True]) def test_md5s(runtmp, zip_against): # check that the correct md5sums (of the original sketches) are in # the output files @@ -424,7 +399,6 @@ def test_md5s(runtmp, zip_against): assert ss.md5sum() in md5s -@pytest.mark.parametrize('zip_against', [False, True]) def test_csv_columns_vs_sourmash_prefetch(runtmp, zip_against): # the column names should be strict subsets of sourmash prefetch cols query = get_test_data('SRR606249.sig.gz') @@ -466,7 +440,6 @@ def test_csv_columns_vs_sourmash_prefetch(runtmp, zip_against): assert diff_keys == set(['unique_intersect_bp', 'median_abund', 'f_match_orig', 'std_abund', 'average_abund', 'f_unique_to_query', 'remaining_bp', 'f_unique_weighted', 'sum_weighted_found', 'total_weighted_hashes', 'n_unique_weighted_found', 'f_orig_query', 'f_match']) -@pytest.mark.parametrize('zip_against', [False, True]) def test_fastgather_gatherout_as_picklist(runtmp, zip_against): # should be able to use fastgather gather output as picklist query = get_test_data('SRR606249.sig.gz') @@ -508,7 +481,6 @@ def test_fastgather_gatherout_as_picklist(runtmp, zip_against): assert picklist_df.equals(full_df) -@pytest.mark.parametrize('zip_against', [False, True]) def test_fastgather_prefetchout_as_picklist(runtmp, zip_against): # should be able to use fastgather prefetch output as picklist query = get_test_data('SRR606249.sig.gz') @@ -632,7 +604,8 @@ def test_simple_hp(runtmp): def test_indexed_against(runtmp, capfd): - # do not accept rocksdb for now + return + # do not accept rocksdb for now @CTB we do now!! 
query = get_test_data('SRR606249.sig.gz') against_list = runtmp.output('against.txt') diff --git a/src/python/tests/test_multigather.py b/src/python/tests/test_fastmultigather.py similarity index 99% rename from src/python/tests/test_multigather.py rename to src/python/tests/test_fastmultigather.py index 831b5096..23f9cc19 100644 --- a/src/python/tests/test_multigather.py +++ b/src/python/tests/test_fastmultigather.py @@ -8,17 +8,7 @@ import sourmash from . import sourmash_tst_utils as utils - - -def get_test_data(filename): - thisdir = os.path.dirname(__file__) - return os.path.join(thisdir, 'test-data', filename) - - -def make_file_list(filename, paths): - with open(filename, 'wt') as fp: - fp.write("\n".join(paths)) - fp.write("\n") +from .sourmash_tst_utils import (get_test_data, make_file_list, zip_siglist) def index_siglist(runtmp, siglist, db, *, ksize=31, scaled=1000, moltype='DNA', @@ -37,11 +27,6 @@ def test_installed(runtmp): assert 'usage: fastmultigather' in runtmp.last_result.err -def zip_siglist(runtmp, siglist, db): - runtmp.sourmash('sig', 'cat', siglist, - '-o', db) - return db - def test_simple(runtmp, zip_against): # test basic execution! query = get_test_data('SRR606249.sig.gz') diff --git a/src/python/tests/test_index.py b/src/python/tests/test_index.py index 69faf8ae..140fe799 100644 --- a/src/python/tests/test_index.py +++ b/src/python/tests/test_index.py @@ -5,17 +5,7 @@ import shutil from . 
import sourmash_tst_utils as utils - - -def get_test_data(filename): - thisdir = os.path.dirname(__file__) - return os.path.join(thisdir, 'test-data', filename) - - -def make_file_list(filename, paths): - with open(filename, 'wt') as fp: - fp.write("\n".join(paths)) - fp.write("\n") +from .sourmash_tst_utils import (get_test_data, make_file_list, zip_siglist) def test_installed(runtmp): diff --git a/src/python/tests/test_index_api.py b/src/python/tests/test_index_api.py new file mode 100644 index 00000000..a146b65d --- /dev/null +++ b/src/python/tests/test_index_api.py @@ -0,0 +1,39 @@ +import sourmash_plugin_branchwater as branch +from . import sourmash_tst_utils as utils +from .sourmash_tst_utils import get_test_data + + +def test_basic(): + sigfile = get_test_data('SRR606249.sig.gz') + res = branch.api.api_load_collection(sigfile, 31, 100_000, 'DNA') + assert res.location == sigfile + assert len(res) == 1 + + +def test_fail(): + # try to load a (nonexistent) collection + sigfile = get_test_data('XXX_SRR606249.sig.gz') + try: + res = branch.api.api_load_collection(sigfile, 31, 100_000, 'DNA') + except: + pass + # @CTB should do something better here ;) + + +def test_basic_get_manifest(): + sigfile = get_test_data('SRR606249.sig.gz') + res = branch.api.api_load_collection(sigfile, 31, 100_000, 'DNA') + mf = res.manifest + print(mf, dir(mf)) + assert len(mf) == 1 + + rec = res.get_first_record() + print(rec, dir(rec)) + print('ZZZ', rec.as_row) + + print(rec.get_name()) + + print(mf.rows) + for rec in mf.rows: + print(rec.get_name()) + assert 0 diff --git a/src/python/tests/test_manysearch.py b/src/python/tests/test_manysearch.py index 6deb5c3b..ab0f5762 100644 --- a/src/python/tests/test_manysearch.py +++ b/src/python/tests/test_manysearch.py @@ -4,17 +4,7 @@ import sourmash from . 
import sourmash_tst_utils as utils - - -def get_test_data(filename): - thisdir = os.path.dirname(__file__) - return os.path.join(thisdir, 'test-data', filename) - - -def make_file_list(filename, paths): - with open(filename, 'wt') as fp: - fp.write("\n".join(paths)) - fp.write("\n") +from .sourmash_tst_utils import (get_test_data, make_file_list, zip_siglist) def test_installed(runtmp): @@ -23,10 +13,6 @@ def test_installed(runtmp): assert 'usage: manysearch' in runtmp.last_result.err -def zip_siglist(runtmp, siglist, db): - runtmp.sourmash('sig', 'cat', siglist, - '-o', db) - return db def index_siglist(runtmp, siglist, db, ksize=31, scaled=1000, moltype='DNA'): # build index @@ -35,8 +21,6 @@ def index_siglist(runtmp, siglist, db, ksize=31, scaled=1000, moltype='DNA'): '--moltype', moltype) return db -@pytest.mark.parametrize("zip_query", [False, True]) -@pytest.mark.parametrize("zip_against", [False, True]) def test_simple(runtmp, zip_query, zip_against): # test basic execution! query_list = runtmp.output('query.txt') @@ -192,7 +176,6 @@ def test_simple_abund(runtmp): assert total_weighted_hashes == 73489 -@pytest.mark.parametrize("zip_query", [False, True]) def test_simple_indexed(runtmp, zip_query): # test basic execution! query_list = runtmp.output('query.txt') @@ -249,8 +232,6 @@ def test_simple_indexed(runtmp, zip_query): assert query_ani == 0.9772 -@pytest.mark.parametrize("indexed", [False, True]) -@pytest.mark.parametrize("zip_query", [False, True]) def test_simple_with_cores(runtmp, capfd, indexed, zip_query): # test basic execution with -c argument (that it runs, at least!) 
query_list = runtmp.output('query.txt') @@ -283,8 +264,6 @@ def test_simple_with_cores(runtmp, capfd, indexed, zip_query): assert " using 4 threads" in result.err -@pytest.mark.parametrize("indexed", [False, True]) -@pytest.mark.parametrize("zip_query", [False, True]) def test_simple_threshold(runtmp, indexed, zip_query): # test with a simple threshold => only 3 results query_list = runtmp.output('query.txt') @@ -313,7 +292,6 @@ def test_simple_threshold(runtmp, indexed, zip_query): assert len(df) == 3 -@pytest.mark.parametrize("indexed", [False, True]) def test_simple_manifest(runtmp, indexed): # test with a simple threshold => only 3 results query_list = runtmp.output('query.txt') @@ -347,8 +325,6 @@ def test_simple_manifest(runtmp, indexed): assert len(df) == 3 -@pytest.mark.parametrize("indexed", [False, True]) -@pytest.mark.parametrize("zip_query", [False, True]) def test_missing_query(runtmp, capfd, indexed, zip_query): # test with a missing query list query_list = runtmp.output('query.txt') @@ -379,7 +355,6 @@ def test_missing_query(runtmp, capfd, indexed, zip_query): assert 'Error: No such file or directory' in captured.err -@pytest.mark.parametrize("indexed", [False, True]) def test_sig_query(runtmp, capfd, indexed): # test with a single sig query (a .sig.gz file) against_list = runtmp.output('against.txt') @@ -399,7 +374,6 @@ def test_sig_query(runtmp, capfd, indexed): '-o', output) -@pytest.mark.parametrize("indexed", [False, True]) def test_bad_query_2(runtmp, capfd, indexed): # test with a bad query list (a missing file) query_list = runtmp.output('query.txt') @@ -453,7 +427,6 @@ def test_bad_query_3(runtmp, capfd): assert 'InvalidArchive' in captured.err -@pytest.mark.parametrize("indexed", [False, True]) def test_missing_against(runtmp, capfd, indexed): # test with a missing against list query_list = runtmp.output('query.txt') @@ -524,7 +497,6 @@ def test_bad_against(runtmp, capfd): assert "WARNING: 1 search paths failed to load. 
See error messages above." in captured.err -@pytest.mark.parametrize("indexed", [False, True]) def test_empty_query(runtmp, indexed, capfd): # test with an empty query list query_list = runtmp.output('query.txt') @@ -552,8 +524,6 @@ def test_empty_query(runtmp, indexed, capfd): assert "No query signatures loaded, exiting." in captured.err -@pytest.mark.parametrize("indexed", [False, True]) -@pytest.mark.parametrize("zip_query", [False, True]) def test_nomatch_query(runtmp, capfd, indexed, zip_query): # test a non-matching (diff ksize) in query; do we get warning message? query_list = runtmp.output('query.txt') @@ -584,8 +554,6 @@ def test_nomatch_query(runtmp, capfd, indexed, zip_query): assert 'WARNING: skipped 1 query paths - no compatible signatures.' in captured.err -@pytest.mark.parametrize("zip_against", [False, True]) -@pytest.mark.parametrize("indexed", [False, True]) def test_load_only_one_bug(runtmp, capfd, indexed, zip_against): # check that we behave properly when presented with multiple against # sketches @@ -619,8 +587,6 @@ def test_load_only_one_bug(runtmp, capfd, indexed, zip_against): assert not 'WARNING: no compatible sketches in path ' in captured.err -@pytest.mark.parametrize("zip_query", [False, True]) -@pytest.mark.parametrize("indexed", [False, True]) def test_load_only_one_bug_as_query(runtmp, capfd, indexed, zip_query): # check that we behave properly when presented with multiple query # sketches in one file, with only one matching. @@ -656,8 +622,6 @@ def test_load_only_one_bug_as_query(runtmp, capfd, indexed, zip_query): assert not 'WARNING: no compatible sketches in path ' in captured.err -@pytest.mark.parametrize("zip_query", [False, True]) -@pytest.mark.parametrize("indexed", [False, True]) def test_md5(runtmp, indexed, zip_query): # test that md5s match what was in the original files, not downsampled etc. 
query_list = runtmp.output('query.txt') diff --git a/src/python/tests/test_multisearch.py b/src/python/tests/test_multisearch.py index 611b0f81..87553615 100644 --- a/src/python/tests/test_multisearch.py +++ b/src/python/tests/test_multisearch.py @@ -5,17 +5,7 @@ import sourmash from . import sourmash_tst_utils as utils - - -def get_test_data(filename): - thisdir = os.path.dirname(__file__) - return os.path.join(thisdir, 'test-data', filename) - - -def make_file_list(filename, paths): - with open(filename, 'wt') as fp: - fp.write("\n".join(paths)) - fp.write("\n") +from .sourmash_tst_utils import (get_test_data, make_file_list, zip_siglist) def test_installed(runtmp): @@ -24,13 +14,7 @@ def test_installed(runtmp): assert 'usage: multisearch' in runtmp.last_result.err -def zip_siglist(runtmp, siglist, db): - runtmp.sourmash('sig', 'cat', siglist, - '-o', db) - return db -@pytest.mark.parametrize("zip_query", [False, True]) -@pytest.mark.parametrize("zip_db", [False, True]) def test_simple_no_ani(runtmp, zip_query, zip_db): # test basic execution! query_list = runtmp.output('query.txt') @@ -99,8 +83,6 @@ def test_simple_no_ani(runtmp, zip_query, zip_db): assert intersect_hashes == 2529 -@pytest.mark.parametrize("zip_query", [False, True]) -@pytest.mark.parametrize("zip_db", [False, True]) def test_simple_ani(runtmp, zip_query, zip_db): # test basic execution! 
query_list = runtmp.output('query.txt') @@ -186,8 +168,6 @@ def test_simple_ani(runtmp, zip_query, zip_db): assert max_ani == 0.9772 -@pytest.mark.parametrize("zip_query", [False, True]) -@pytest.mark.parametrize("zip_db", [False, True]) def test_simple_threshold(runtmp, zip_query, zip_db): # test with a simple threshold => only 3 results query_list = runtmp.output('query.txt') @@ -243,7 +223,6 @@ def test_simple_manifest(runtmp): assert len(df) == 3 -@pytest.mark.parametrize("zip_query", [False, True]) def test_missing_query(runtmp, capfd, zip_query): # test with a missing query list query_list = runtmp.output('query.txt') @@ -344,7 +323,6 @@ def test_bad_query_3(runtmp, capfd): assert 'InvalidArchive' in captured.err -@pytest.mark.parametrize("zip_db", [False, True]) def test_missing_against(runtmp, capfd, zip_db): # test with a missing against list query_list = runtmp.output('query.txt') @@ -445,7 +423,6 @@ def test_empty_query(runtmp, capfd): # @CTB -@pytest.mark.parametrize("zip_query", [False, True]) def test_nomatch_query(runtmp, capfd, zip_query): # test a non-matching (diff ksize) in query; do we get warning message? query_list = runtmp.output('query.txt') @@ -474,7 +451,6 @@ def test_nomatch_query(runtmp, capfd, zip_query): assert 'WARNING: skipped 1 query paths - no compatible signatures' in captured.err -@pytest.mark.parametrize("zip_db", [False, True]) def test_load_only_one_bug(runtmp, capfd, zip_db): # check that we behave properly when presented with multiple against # sketches @@ -506,7 +482,6 @@ def test_load_only_one_bug(runtmp, capfd, zip_db): assert not 'WARNING: no compatible sketches in path' in captured.err -@pytest.mark.parametrize("zip_query", [False, True]) def test_load_only_one_bug_as_query(runtmp, capfd, zip_query): # check that we behave properly when presented with multiple query # sketches in one file, with only one matching. 
@@ -538,8 +513,6 @@ def test_load_only_one_bug_as_query(runtmp, capfd, zip_query): assert not 'WARNING: no compatible sketches in path ' in captured.err -@pytest.mark.parametrize("zip_query", [False, True]) -@pytest.mark.parametrize("zip_db", [False, True]) def test_md5(runtmp, zip_query, zip_db): # test that md5s match what was in the original files, not downsampled etc. query_list = runtmp.output('query.txt') diff --git a/src/python/tests/test_pairwise.py b/src/python/tests/test_pairwise.py index 3869b3d4..c8264069 100644 --- a/src/python/tests/test_pairwise.py +++ b/src/python/tests/test_pairwise.py @@ -5,17 +5,7 @@ import sourmash from . import sourmash_tst_utils as utils - - -def get_test_data(filename): - thisdir = os.path.dirname(__file__) - return os.path.join(thisdir, 'test-data', filename) - - -def make_file_list(filename, paths): - with open(filename, 'wt') as fp: - fp.write("\n".join(paths)) - fp.write("\n") +from .sourmash_tst_utils import (get_test_data, make_file_list, zip_siglist) def test_installed(runtmp): @@ -24,13 +14,7 @@ def test_installed(runtmp): assert 'usage: pairwise' in runtmp.last_result.err -def zip_siglist(runtmp, siglist, db): - runtmp.sourmash('sig', 'cat', siglist, - '-o', db) - return db - -@pytest.mark.parametrize("zip_query", [False, True]) def test_simple_no_ani(runtmp, zip_query): # test basic execution! query_list = runtmp.output('query.txt') @@ -81,7 +65,6 @@ def test_simple_no_ani(runtmp, zip_query): assert intersect_hashes == 2529 -@pytest.mark.parametrize("zip_query", [False, True]) def test_simple_ani(runtmp, zip_query): # test basic execution! 
query_list = runtmp.output('query.txt') @@ -140,7 +123,6 @@ def test_simple_ani(runtmp, zip_query): assert max_ani == 0.9772 -@pytest.mark.parametrize("zip_query", [False, True]) def test_simple_threshold(runtmp, zip_query): # test with a simple threshold => only 3 results query_list = runtmp.output('query.txt') @@ -248,7 +230,6 @@ def test_bad_query_2(runtmp, capfd): assert 'InvalidArchive' in captured.err -@pytest.mark.parametrize("zip_db", [False, True]) def test_missing_query(runtmp, capfd, zip_db): # test with a missing query list query_list = runtmp.output('query.txt') @@ -290,7 +271,6 @@ def test_empty_query(runtmp): # @CTB -@pytest.mark.parametrize("zip_query", [False, True]) def test_nomatch_query(runtmp, capfd, zip_query): # test a non-matching (diff ksize) in query; do we get warning message? query_list = runtmp.output('query.txt') @@ -317,7 +297,6 @@ def test_nomatch_query(runtmp, capfd, zip_query): assert 'WARNING: skipped 1 analysis paths - no compatible signatures' in captured.err -@pytest.mark.parametrize("zip_db", [False, True]) def test_load_only_one_bug(runtmp, capfd, zip_db): # check that we behave properly when presented with multiple query # sketches @@ -347,7 +326,6 @@ def test_load_only_one_bug(runtmp, capfd, zip_db): assert not 'WARNING: no compatible sketches in path ' in captured.err -@pytest.mark.parametrize("zip_query", [False, True]) def test_md5(runtmp, zip_query): # test that md5s match what was in the original files, not downsampled etc. query_list = runtmp.output('query.txt') diff --git a/src/utils.rs b/src/utils/mod.rs similarity index 86% rename from src/utils.rs rename to src/utils/mod.rs index 4209413e..1e06d31e 100644 --- a/src/utils.rs +++ b/src/utils/mod.rs @@ -1,9 +1,9 @@ -/// Utility functions for sourmash_plugin_branchwater. +//! Utility functions for `sourmash_plugin_branchwater`. 
use rayon::prelude::*; use sourmash::encodings::HashFunctions; use sourmash::selection::Select; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, Result}; use camino::Utf8Path as Path; use camino::Utf8PathBuf as PathBuf; use csv::Writer; @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use std::cmp::{Ordering, PartialOrd}; use std::collections::BinaryHeap; use std::fs::{create_dir_all, File}; -use std::io::{BufRead, BufReader, BufWriter, Write}; +use std::io::{BufWriter, Write}; use std::panic; use std::sync::atomic; use std::sync::atomic::AtomicUsize; @@ -20,24 +20,18 @@ use zip::write::{ExtendedFileOptions, FileOptions, ZipWriter}; use zip::CompressionMethod; use sourmash::ani_utils::{ani_ci_from_containment, ani_from_containment}; -use sourmash::collection::Collection; use sourmash::manifest::{Manifest, Record}; use sourmash::selection::Selection; use sourmash::signature::{Signature, SigsTrait}; use sourmash::sketch::minhash::KmerMinHash; -use sourmash::storage::{FSStorage, InnerStorage, SigStore}; +use sourmash::storage::SigStore; use stats::{median, stddev}; use std::collections::{HashMap, HashSet}; -/// Track a name/minhash. -pub struct SmallSignature { - pub location: String, - pub name: String, - pub md5sum: String, - pub minhash: KmerMinHash, -} -/// Structure to hold overlap information from comparisons. +pub mod multicollection; +use multicollection::{MultiCollection, SmallSignature}; +/// Structure to hold overlap information from comparisons. 
pub struct PrefetchResult { pub name: String, pub md5sum: String, @@ -432,26 +426,37 @@ fn process_prefix_csv( Ok((results, n_fastas)) } -// Load all compatible minhashes from a collection into memory +///////// + +// Load all compatible minhashes from a collection into memory, in parallel; // also store sig name and md5 alongside, as we usually need those pub fn load_sketches( - collection: Collection, + multi: MultiCollection, selection: &Selection, _report_type: ReportType, ) -> Result> { - let sketchinfo: Vec = collection + let sketchinfo: Vec<_> = multi .par_iter() - .filter_map(|(_idx, record)| { - let sig = collection.sig_from_record(record).ok()?; - let selected_sig = sig.clone().select(selection).ok()?; - let minhash = selected_sig.minhash()?.clone(); - - Some(SmallSignature { - location: record.internal_location().to_string(), - name: sig.name(), - md5sum: sig.md5sum(), - minhash, - }) + .filter_map(|(coll, _idx, record)| match coll.sig_from_record(record) { + Ok(sig) => { + let selected_sig = sig.clone().select(selection).ok()?; + let minhash = selected_sig.minhash()?.clone(); + + Some(SmallSignature { + collection: coll.clone(), // @CTB + location: record.internal_location().to_string(), + name: sig.name(), + md5sum: sig.md5sum(), + minhash, + }) + } + Err(_) => { + eprintln!( + "FAILED to load sketch from '{}'", + record.internal_location() + ); + None + } }) .collect(); @@ -462,7 +467,7 @@ pub fn load_sketches( /// those with a minimum overlap. 
pub fn load_sketches_above_threshold( - against_collection: Collection, + against_collection: MultiCollection, query: &KmerMinHash, threshold_hashes: u64, ) -> Result<(BinaryHeap, usize, usize)> { @@ -471,10 +476,10 @@ pub fn load_sketches_above_threshold( let matchlist: BinaryHeap = against_collection .par_iter() - .filter_map(|(_idx, against_record)| { + .filter_map(|(coll, _idx, against_record)| { let mut results = Vec::new(); // Load against into memory - if let Ok(against_sig) = against_collection.sig_from_record(against_record) { + if let Ok(against_sig) = coll.sig_from_record(against_record) { if let Some(against_mh) = against_sig.minhash() { // downsample against_mh, but keep original md5sum let against_mh_ds = against_mh.downsample_scaled(query.scaled()).unwrap(); @@ -537,148 +542,25 @@ impl std::fmt::Display for ReportType { } } -pub fn collection_from_zipfile(sigpath: &Path, report_type: &ReportType) -> Result { - match Collection::from_zipfile(sigpath) { - Ok(collection) => Ok(collection), - Err(_) => bail!("failed to load {} zipfile: '{}'", report_type, sigpath), - } -} - -fn collection_from_manifest( - sigpath: &Path, - report_type: &ReportType, -) -> Result { - let file = File::open(sigpath) - .with_context(|| format!("Failed to open {} file: '{}'", report_type, sigpath))?; - - let reader = BufReader::new(file); - let manifest = Manifest::from_reader(reader).with_context(|| { - format!( - "Failed to read {} manifest from: '{}'", - report_type, sigpath - ) - })?; - - if manifest.is_empty() { - // If the manifest is empty, return an error constructed with the anyhow! 
macro - Err(anyhow!("could not read as manifest: '{}'", sigpath)) - } else { - // If the manifest is not empty, proceed to create and return the Collection - Ok(Collection::new( - manifest, - InnerStorage::new( - FSStorage::builder() - .fullpath("".into()) - .subdir("".into()) - .build(), - ), - )) - } -} - -fn collection_from_pathlist( - sigpath: &Path, - report_type: &ReportType, -) -> Result<(Collection, usize), anyhow::Error> { - let file = File::open(sigpath).with_context(|| { - format!( - "Failed to open {} pathlist file: '{}'", - report_type, sigpath - ) - })?; - let reader = BufReader::new(file); - - // load list of paths - let lines: Vec<_> = reader - .lines() - .filter_map(|line| match line { - Ok(path) => Some(path), - Err(_err) => None, - }) - .collect(); - - // load sketches from paths in parallel. - let n_failed = AtomicUsize::new(0); - let records: Vec = lines - .par_iter() - .filter_map(|path| match Signature::from_path(path) { - Ok(signatures) => { - let recs: Vec = signatures - .into_iter() - .flat_map(|v| Record::from_sig(&v, path)) - .collect(); - Some(recs) - } - Err(err) => { - eprintln!("Sketch loading error: {}", err); - eprintln!("WARNING: could not load sketches from path '{}'", path); - let _ = n_failed.fetch_add(1, atomic::Ordering::SeqCst); - None - } - }) - .flatten() - .collect(); - - if records.is_empty() { - eprintln!( - "No valid signatures found in {} pathlist '{}'", - report_type, sigpath - ); - } - - let manifest: Manifest = records.into(); - let collection = Collection::new( - manifest, - InnerStorage::new( - FSStorage::builder() - .fullpath("".into()) - .subdir("".into()) - .build(), - ), - ); - let n_failed = n_failed.load(atomic::Ordering::SeqCst); - - Ok((collection, n_failed)) -} - -fn collection_from_signature(sigpath: &Path, report_type: &ReportType) -> Result { - let signatures = Signature::from_path(sigpath).with_context(|| { - format!( - "Failed to load {} signatures from: '{}'", - report_type, sigpath - ) - })?; - - 
Collection::from_sigs(signatures).with_context(|| { - format!( - "Loaded {} signatures but failed to load as collection: '{}'", - report_type, sigpath - ) - }) -} +/// Load a multi collection from a path - this is the new top-level load function. pub fn load_collection( siglist: &String, selection: &Selection, report_type: ReportType, allow_failed: bool, -) -> Result { +) -> Result { let sigpath = PathBuf::from(siglist); if !sigpath.exists() { bail!("No such file or directory: '{}'", &sigpath); } - // disallow rocksdb input here - if is_revindex_database(&sigpath) { - bail!("Cannot load {} signatures from a 'rocksdb' database. Please use sig, zip, or pathlist.", report_type); - } - eprintln!("Reading {}(s) from: '{}'", report_type, &siglist); let mut last_error = None; let collection = if sigpath.extension().map_or(false, |ext| ext == "zip") { - match collection_from_zipfile(&sigpath, &report_type) { + match MultiCollection::from_zipfile(&sigpath) { Ok(coll) => Some((coll, 0)), Err(e) => { last_error = Some(e); @@ -689,32 +571,37 @@ pub fn load_collection( None }; - let collection = - collection.or_else(|| match collection_from_manifest(&sigpath, &report_type) { - Ok(coll) => Some((coll, 0)), - Err(e) => { - last_error = Some(e); - None - } - }); + let collection = collection.or_else(|| match MultiCollection::from_rocksdb(&sigpath) { + Ok(coll) => Some((coll, 0)), + Err(e) => { + last_error = Some(e); + None + } + }); - let collection = - collection.or_else(|| match collection_from_signature(&sigpath, &report_type) { - Ok(coll) => Some((coll, 0)), - Err(e) => { - last_error = Some(e); - None - } - }); + let collection = collection.or_else(|| match MultiCollection::from_standalone_manifest(&sigpath) { + Ok(coll) => Some((coll, 0)), + Err(e) => { + last_error = Some(e); + None + } + }); - let collection = - collection.or_else(|| match collection_from_pathlist(&sigpath, &report_type) { - Ok((coll, n_failed)) => Some((coll, n_failed)), - Err(e) => { - last_error = 
Some(e); - None - } - }); + let collection = collection.or_else(|| match MultiCollection::from_signature(&sigpath) { + Ok(coll) => Some((coll, 0)), + Err(e) => { + last_error = Some(e); + None + } + }); + + let collection = collection.or_else(|| match MultiCollection::from_pathlist(&sigpath) { + Ok((coll, n_failed)) => Some((coll, n_failed)), + Err(e) => { + last_error = Some(e); + None + } + }); match collection { Some((coll, n_failed)) => { @@ -766,7 +653,7 @@ pub fn load_collection( /// Returns an error if: /// * No signatures were successfully loaded. pub fn report_on_collection_loading( - collection: &Collection, + collection: &MultiCollection, skipped_paths: usize, failed_paths: usize, report_type: ReportType, @@ -966,7 +853,7 @@ pub fn consume_query_by_gather( } let query_md5sum: String = orig_query_mh.md5sum().clone(); let query_name = query.name().clone(); - let query_scaled = orig_query_mh.scaled().clone() as usize; //query_mh.scaled() as usize + let query_scaled = orig_query_mh.scaled() as usize; let mut query_mh = orig_query_mh.clone(); let mut orig_query_ds = orig_query_mh.clone().downsample_scaled(scaled)?; @@ -1037,11 +924,11 @@ pub fn consume_query_by_gather( query_filename: query.filename(), query_name: query_name.clone(), query_md5: query_md5sum.clone(), - query_bp: query_bp.clone(), + query_bp, ksize, moltype: query_moltype.clone(), - scaled: query_scaled.clone(), - query_n_hashes: query_n_hashes, + scaled: query_scaled, + query_n_hashes, query_abundance: query_mh.track_abundance(), query_containment_ani: match_.query_containment_ani, match_containment_ani: match_.match_containment_ani, @@ -1091,7 +978,7 @@ pub fn consume_query_by_gather( pub fn build_selection(ksize: u8, scaled: usize, moltype: &str) -> Selection { let hash_function = match moltype { - "dna" => HashFunctions::Murmur64Dna, + "DNA" => HashFunctions::Murmur64Dna, "protein" => HashFunctions::Murmur64Protein, "dayhoff" => HashFunctions::Murmur64Dayhoff, "hp" => 
HashFunctions::Murmur64Hp, diff --git a/src/utils/multicollection.rs b/src/utils/multicollection.rs new file mode 100644 index 00000000..02118a1d --- /dev/null +++ b/src/utils/multicollection.rs @@ -0,0 +1,228 @@ +//! MultiCollection implementation to handle sketches coming from multiple files. + +use rayon::prelude::*; + +use anyhow::{anyhow, bail, Context, Result}; +use camino::Utf8Path as Path; +use log::debug; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::sync::atomic; +use std::sync::atomic::AtomicUsize; +use std::collections::HashSet; + +use sourmash::collection::{Collection, CollectionSet}; +use sourmash::encodings::Idx; +use sourmash::errors::SourmashError; +use sourmash::manifest::{Manifest, Record}; +use sourmash::selection::{Select, Selection}; +use sourmash::signature::Signature; +use sourmash::sketch::minhash::KmerMinHash; +use sourmash::storage::SigStore; + +/// A collection of sketches, potentially stored in multiple files. +pub struct MultiCollection { + collections: Vec<Collection>, +} + +impl MultiCollection { + fn new(collections: Vec<Collection>) -> Self { + Self { collections } + } + + // Turn a set of paths into list of Collections. + fn load_set_of_paths(paths: HashSet<String>) -> (Vec<Collection>, usize) { + let n_failed = AtomicUsize::new(0); + + let colls: Vec<_> = paths + .par_iter() + .filter_map(|iloc| match iloc { + // could just use a variant of load_collection here?
+ x if x.ends_with(".zip") => { + debug!("loading sigs from zipfile {}", x); + Some(Collection::from_zipfile(x).unwrap()) + }, + _ => { + debug!("loading sigs from sigfile {}", iloc); + let signatures = match Signature::from_path(iloc) { + Ok(signatures) => Some(signatures), + Err(err) => { + eprintln!("Sketch loading error: {}", err); + None + } + }; + + match signatures { + Some(signatures) => { + Some(Collection::from_sigs(signatures).unwrap()) + }, + None => { + eprintln!("WARNING: could not load sketches from path '{}'", iloc); + let _ = n_failed.fetch_add(1, atomic::Ordering::SeqCst); + None + } + } + } + }) + .collect(); + + let n_failed = n_failed.load(atomic::Ordering::SeqCst); + (colls, n_failed) + } + + /// Build from a standalone manifest + pub fn from_standalone_manifest(sigpath: &Path) -> Result { + debug!("multi from standalone manifest!"); + let file = + File::open(sigpath).with_context(|| format!("Failed to open file: '{}'", sigpath))?; + + let reader = BufReader::new(file); + let manifest = Manifest::from_reader(reader) + .with_context(|| format!("Failed to read manifest from: '{}'", sigpath))?; + + if manifest.is_empty() { + Err(anyhow!("could not read as manifest: '{}'", sigpath)) + } else { + let ilocs: HashSet<_> = manifest + .internal_locations() + .map(|s| String::from(s)) + .collect(); + + let (colls, _n_failed) = MultiCollection::load_set_of_paths(ilocs); + let colls = colls.into_iter().collect(); + + Ok(MultiCollection::new(colls)) + } + } + + /// Load a collection from a .zip file. + pub fn from_zipfile(sigpath: &Path) -> Result { + debug!("multi from zipfile!"); + match Collection::from_zipfile(sigpath) { + Ok(collection) => Ok(MultiCollection::new(vec![collection])), + Err(_) => bail!("failed to load zipfile: '{}'", sigpath), + } + } + + /// Load a collection from a RocksDB. 
+ pub fn from_rocksdb(sigpath: &Path) -> Result { + debug!("multi from rocksdb!"); + match Collection::from_rocksdb(sigpath) { + Ok(collection) => Ok(MultiCollection::new(vec![collection])), + Err(_) => bail!("failed to load rocksdb: '{}'", sigpath), + } + } + + /// Load a collection from a list of paths. + pub fn from_pathlist(sigpath: &Path) -> Result<(Self, usize)> { + debug!("multi from pathlist!"); + let file = File::open(sigpath) + .with_context(|| format!("Failed to open pathlist file: '{}'", sigpath))?; + let reader = BufReader::new(file); + + // load set of paths + let lines: HashSet<_> = reader + .lines() + .filter_map(|line| match line { + Ok(path) => Some(path), + Err(_err) => None, + }) + .collect(); + + let (colls, n_failed) = MultiCollection::load_set_of_paths(lines); + let colls: Vec<_> = colls.into_iter().collect(); + + Ok((MultiCollection::new(colls), n_failed)) + } + + // Load from a sig file + pub fn from_signature(sigpath: &Path) -> Result { + debug!("multi from signature!"); + let signatures = Signature::from_path(sigpath) + .with_context(|| format!("Failed to load signatures from: '{}'", sigpath))?; + + let coll = Collection::from_sigs(signatures).with_context(|| { + format!( + "Loaded signatures but failed to load as collection: '{}'", + sigpath + ) + })?; + Ok(MultiCollection::new(vec![coll])) + } + + pub fn len(&self) -> usize { + let val: usize = self.collections.iter().map(|c| c.len()).sum(); + val + } + pub fn is_empty(&self) -> bool { + let val: usize = self.collections.iter().map(|c| c.len()).sum(); + val == 0 + } + + pub fn iter(&self) -> impl Iterator { + self.collections.iter() + } + + // iterate over tuples + pub fn item_iter(&self) -> impl Iterator { + // CTB: request review by Rust expert pls :). Does this make + // unnecessary copies?? 
+ let s: Vec<_> = self + .iter() + .flat_map(|c| c.iter().map(move |(_idx, record)| (c, _idx, record))) + .collect(); + s.into_iter() + } + + pub fn par_iter(&self) -> impl IndexedParallelIterator { + // CTB: request review by Rust expert - why can't I use item_iter here? + // i.e. self.item_iter().into_par_iter()? + let s: Vec<_> = self + .iter() + .flat_map(|c| c.iter().map(move |(_idx, record)| (c, _idx, record))) + .collect(); + s.into_par_iter() + } + + pub fn get_first_sig(&self) -> Option { + if !self.is_empty() { + let query_item = self.item_iter().next().unwrap(); + let (coll, _, _) = query_item; + Some(coll.sig_for_dataset(0).ok()?) + } else { + None + } + } +} + +impl Select for MultiCollection { + fn select(mut self, selection: &Selection) -> Result { + // CTB: request review by Rust expert! Is the clone necessary? + self.collections = self + .iter() + .filter_map(|c| c.clone().select(selection).ok()) + .collect(); + Ok(self) + } +} + +impl TryFrom for CollectionSet { + type Error = SourmashError; + + fn try_from(multi: MultiCollection) -> Result { + // CTB: request review by Rust expert! Is the clone necessary? + let coll = multi.iter().next().unwrap().clone(); + let cs: CollectionSet = coll.try_into()?; + Ok(cs) + } +} + +/// Track a name/minhash. +pub struct SmallSignature { + // CTB: request help - can we/should we use references & lifetimes here? + pub collection: Collection, + pub location: String, + pub name: String, + pub md5sum: String, + pub minhash: KmerMinHash, +}