diff --git a/src/fastagather.rs b/src/fastagather.rs index cec45146..6bec94bd 100644 --- a/src/fastagather.rs +++ b/src/fastagather.rs @@ -1,17 +1,26 @@ use crate::utils::buildutils::BuildCollection; use anyhow::{bail, Result}; +use needletail::parse_fastx_file; +use sourmash::selection::Selection; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crate::utils::{ + consume_query_by_gather, load_collection, load_sketches_above_threshold, write_prefetch, + ReportType, +}; + #[allow(clippy::too_many_arguments)] pub fn fastagather( - input_filename: String, + query_filename: String, + against_filepath: String, input_moltype: String, threshold_bp: u64, selection: &Selection, - gather_output: Option, prefetch_output: Option, + gather_output: Option, allow_failed_sigpaths: bool, ) -> Result<()> { - // to start, implement straightforward record --> sketch --> gather // other ideas: // - add full-file (lower resolution) prefetch first, to reduce search space @@ -19,22 +28,23 @@ pub fn fastagather( // Build signature templates based on parsed parameters let sig_template_result = BuildCollection::from_selection(selection); - let mut sig_templates = match sig_template_result { - Ok(sig_templates) => sig_templates, + let mut sig_template = match sig_template_result { + Ok(sig_template) => sig_template, Err(e) => { bail!("Failed to build template signatures: {}", e); } }; - if sigs.is_empty() { - bail!("No signatures to build for the given parameters."); + if sig_template.size() != 1 { + bail!("FASTAgather requires a single signature type for search."); } let input_moltype = input_moltype.to_ascii_lowercase(); let mut against_selection = selection; - let scaled = query_mh.scaled(); - against_selection.set_scaled(scaled); + // get scaled from selection here + let scaled = selection.scaled().unwrap(); // rm this unwrap? + against_selection.set_scaled(scaled as u32); // calculate the minimum number of hashes based on desired threshold let threshold_hashes = { @@ -53,41 +63,44 @@ pub fn fastagather( ReportType::Against, allow_failed_sigpaths, )?; - + let failed_records = AtomicUsize::new(0); // open file and start iterating through sequences // Open fasta file reader - let mut reader = match parse_fastx_file(filename) { + let mut reader = match parse_fastx_file(query_filename.clone()) { Ok(r) => r, Err(err) => { - bail!("Error opening file {}: {:?}", filename, err); + bail!("Error opening file {}: {:?}", query_filename, err); } }; - + // later: can we parallelize across records or sigs? Do we want to batch groups of records for improved gather efficiency? while let Some(record_result) = reader.next() { - // clone sig_templates for use - sigs = sig_templates.clone(); + // clone sig_templates for use + let sigcoll = sig_template.clone(); match record_result { Ok(record) => { - if let Err(err) = sigs.build_singleton_sigs( - record, - input_moltype, - filename.to_string(), - ) { + if let Err(err) = + sigcoll.build_singleton_sigs(record, &input_moltype, query_filename.clone()) + { eprintln!( "Error building signatures from file: {}, {:?}", - filename, err + query_filename, err ); - failed_records.fetch_add(1, atomic::Ordering::SeqCst); + failed_records.fetch_add(1, Ordering::SeqCst); } - for (rec, query_sig) in sigs.iter(){ + // in each iteration, this should just be a single signature made from the single record + for query_sig in sigcoll.sigs.iter() { let query_md5 = query_sig.md5sum(); - let query_mh = sig.minhash().expect("could not get minhash from sig"); + let query_mh = query_sig.minhash().expect("could not get minhash from sig"); let query_name = query_sig.name(); // this is actually just record.id --> so maybe don't get it from sig here? // now do prefetch/gather - let prefetch_result = load_sketches_above_threshold(against_collection, &query_mh, threshold_hashes)?; + let prefetch_result = load_sketches_above_threshold( + against_collection, + &query_mh, + threshold_hashes, + )?; let matchlist = prefetch_result.0; let skipped_paths = prefetch_result.1; let failed_paths = prefetch_result.2; @@ -97,7 +110,7 @@ pub fn fastagather( query_filename.clone(), query_name.clone(), query_md5, - prefetch_output, + prefetch_output.clone(), &matchlist, ) .ok(); @@ -106,11 +119,11 @@ pub fn fastagather( consume_query_by_gather( query_name, query_filename, - query_mh, + query_mh.clone(), scaled as u32, matchlist, threshold_hashes, - gather_output, + gather_output.clone(), ) .ok(); } diff --git a/src/lib.rs b/src/lib.rs index afa5b857..32e03e14 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ use crate::utils::build_selection; use crate::utils::is_revindex_database; mod check; mod cluster; +mod fastagather; mod fastgather; mod fastmultigather; mod fastmultigather_rocksdb; @@ -359,6 +360,41 @@ fn do_cluster( } } +#[pyfunction] +#[allow(clippy::too_many_arguments)] +#[pyo3(signature = (query_filename, against_filepath, input_moltype, threshold_bp, ksize, scaled, moltype, output_path_prefetch=None, output_path_gather=None))] +fn do_fastagather( + query_filename: String, + against_filepath: String, + input_moltype: String, + threshold_bp: u64, + ksize: u8, + scaled: Option, + moltype: String, + output_path_prefetch: Option, + output_path_gather: Option, +) -> anyhow::Result { + let selection = build_selection(ksize, scaled, &moltype); + let allow_failed_sigpaths = true; + + match fastagather::fastagather( + query_filename, + against_filepath, + input_moltype, + threshold_bp, + &selection, + output_path_prefetch, + output_path_gather, + allow_failed_sigpaths, + ) { + Ok(_) => Ok(0), + Err(e) => { + eprintln!("Error: {e}"); + Ok(1) + } + } +} + /// Module interface for the `sourmash_plugin_branchwater` extension module. #[pymodule] @@ -374,5 +410,6 @@ fn sourmash_plugin_branchwater(_py: Python, m: &Bound<'_, PyModule>) -> PyResult m.add_function(wrap_pyfunction!(do_pairwise, m)?)?; m.add_function(wrap_pyfunction!(do_cluster, m)?)?; m.add_function(wrap_pyfunction!(do_singlesketch, m)?)?; + m.add_function(wrap_pyfunction!(do_fastagather, m)?)?; Ok(()) } diff --git a/src/python/sourmash_plugin_branchwater/__init__.py b/src/python/sourmash_plugin_branchwater/__init__.py index 3de354aa..2fc8f75e 100755 --- a/src/python/sourmash_plugin_branchwater/__init__.py +++ b/src/python/sourmash_plugin_branchwater/__init__.py @@ -800,3 +800,94 @@ def main(self, args): notify(f"...clustering is done! results in '{args.output}'") notify(f" cluster counts in '{args.cluster_sizes}'") return status + + +class Branchwater_Fastagather(CommandLinePlugin): + command = "fastagather" + description = "massively parallel gather directly from FASTA" + + def __init__(self, p): + super().__init__(p) + p.add_argument("query_fa", help="FASTA file") + p.add_argument("against_paths", help="database file of sketches") + p.add_argument( + "-I", + "--input-moltype", + "--input-molecule-type", + choices=["DNA", "dna", "protein"], + default="DNA", + help="molecule type of input sequence (DNA or protein)", + ) + p.add_argument( + "-o", + "--output-gather", + required=True, + help="save gather output (minimum metagenome cover) to this file", + ) + p.add_argument( + "--output-prefetch", help="save prefetch output (all overlaps) to this file" + ) + p.add_argument( + "-t", + "--threshold-bp", + default=4000, + type=float, + help="threshold in estimated base pairs, for reporting matches (default: 4kb)", + ) + p.add_argument( + "-k", + "--ksize", + default=31, + type=int, + help="k-mer size at which to do comparisons (default: 31)", + ) + p.add_argument( + "-s", + "--scaled", + default=1000, + type=int, + help="scaled factor at which to do comparisons (default: 1000)", + ) + p.add_argument( + "-m", + "--moltype", + default="DNA", + choices=["DNA", "protein", "dayhoff", "hp"], + help="molecule type for search (DNA, protein, dayhoff, or hp; default DNA)", + ) + p.add_argument( + "-c", + "--cores", + default=0, + type=int, + help="number of cores to use (default is all available)", + ) + + def main(self, args): + print_version() + notify( + f"ksize: {args.ksize} / scaled: {args.scaled} / moltype: {args.moltype} / threshold bp: {args.threshold_bp}" + ) + + num_threads = set_thread_pool(args.cores) + + notify( + f"gathering all sketches in '{args.query_sig}' against '{args.against_paths}' using {num_threads} threads" + ) + super().main(args) + status = sourmash_plugin_branchwater.do_fastagather( + args.query_fa, + args.against_paths, + args.input_moltype, + int(args.threshold_bp), + args.ksize, + args.scaled, + args.moltype, + args.output_gather, + args.output_prefetch, + ) + if status == 0: + notify(f"...fastgather is done! gather results in '{args.output_gather}'") + if args.output_prefetch: + notify(f"prefetch results in '{args.output_prefetch}'") + return status diff --git a/src/python/tests/sourmash_tst_utils.py b/src/python/tests/sourmash_tst_utils.py index 86c97c57..26364f63 100644 --- a/src/python/tests/sourmash_tst_utils.py +++ b/src/python/tests/sourmash_tst_utils.py @@ -7,8 +7,7 @@ import collections import pprint -import pkg_resources -from pkg_resources import Requirement, resource_filename, ResolutionError +import importlib.metadata import traceback from io import open # pylint: disable=redefined-builtin from io import StringIO @@ -84,24 +83,13 @@ def _runscript(scriptname): namespace = {"__name__": "__main__"} namespace["sys"] = globals()["sys"] - try: - pkg_resources.load_entry_point("sourmash", "console_scripts", scriptname)() - return 0 - except pkg_resources.ResolutionError: - pass - - path = scriptpath() - - scriptfile = os.path.join(path, scriptname) - if os.path.isfile(scriptfile): - if os.path.isfile(scriptfile): - exec( # pylint: disable=exec-used - compile(open(scriptfile).read(), scriptfile, "exec"), namespace - ) - return 0 - - return -1 - + entry_points = importlib.metadata.entry_points( + group="console_scripts", name="sourmash" + ) + assert len(entry_points) == 1 + smash_cli = tuple(entry_points)[0].load() + smash_cli() + return 0 ScriptResults = collections.namedtuple("ScriptResults", ["status", "out", "err"]) diff --git a/src/utils/buildutils.rs b/src/utils/buildutils.rs index 8d75391c..2309ad2d 100644 --- a/src/utils/buildutils.rs +++ b/src/utils/buildutils.rs @@ -67,6 +67,12 @@ impl MultiSelection { selections: selections?, }) } + + pub fn from_selection(selection: Selection) -> Self { + MultiSelection { + selections: vec![selection], + } + } } pub trait MultiSelect { @@ -95,7 +101,7 @@ pub struct BuildRecord { num: u32, #[getset(get = "pub")] - scaled: u64, + scaled: u32, #[getset(get = "pub", set = "pub")] n_hashes: Option, @@ -190,7 +196,7 @@ impl BuildRecord { ksize: record.ksize(), moltype: record.moltype().to_string(), num: *record.num(), - scaled: *record.scaled() as u64, + scaled: *record.scaled() as u32, with_abundance: record.with_abundance(), ..Self::default_dna() // ignore remaining fields } @@ -213,7 +219,7 @@ impl BuildRecord { if let Some(scaled) = selection.scaled() { // num sigs have self.scaled = 0, don't include them - valid = valid && self.scaled != 0 && self.scaled <= scaled as u64; + valid = valid && self.scaled != 0 && self.scaled <= scaled as u32; } if let Some(num) = selection.num() { @@ -223,7 +229,7 @@ impl BuildRecord { valid } - pub fn params(&self) -> (u32, String, bool, u32, u64) { + pub fn params(&self) -> (u32, String, bool, u32, u32) { ( self.ksize, self.moltype.clone(), @@ -285,7 +291,7 @@ impl BuildManifest { self.records.clear(); } - pub fn summarize_params(&self) -> HashSet<(u32, String, bool, u32, u64)> { + pub fn summarize_params(&self) -> HashSet<(u32, String, bool, u32, u32)> { self.iter().map(|record| record.params()).collect() } @@ -500,7 +506,7 @@ impl BuildCollection { Ok(()) } - pub fn summarize_params(&self) -> HashSet<(u32, String, bool, u32, u64)> { + pub fn summarize_params(&self) -> HashSet<(u32, String, bool, u32, u32)> { let params: HashSet<_> = self.manifest.iter().map(|record| record.params()).collect(); // Print a description of the summary @@ -520,7 +526,7 @@ impl BuildCollection { let mut moltype: Option = None; let mut track_abundance: Option = None; let mut num: Option = None; - let mut scaled: Option = None; + let mut scaled: Option = None; let mut seed: Option = None; for item in p_str.split(',') { @@ -645,6 +651,52 @@ impl BuildCollection { collection } + pub fn from_selection(selection: &Selection) -> Result { + let mut collection = BuildCollection::new(); + + // Set a default ksize if none is provided + let ksizes = if let Some(ksize) = selection.ksize() { + vec![ksize] + } else { + vec![21] // Default ksize + }; + + // Default moltype if not provided + let moltype = selection + .moltype() + .clone() + .ok_or("Moltype must be specified in selection")?; + + for ksize in ksizes { + let mut record = match moltype { + HashFunctions::Murmur64Dna => BuildRecord::default_dna(), + HashFunctions::Murmur64Protein => BuildRecord::default_protein(), + HashFunctions::Murmur64Dayhoff => BuildRecord::default_dayhoff(), + HashFunctions::Murmur64Hp => BuildRecord::default_hp(), + _ => { + return Err(format!("Unsupported moltype '{:?}' in selection", moltype)); + } + }; + + // Apply selection parameters to the BuildRecord + record.ksize = ksize; + if let Some(track_abundance) = selection.abund() { + record.with_abundance = track_abundance; + } + if let Some(num) = selection.num() { + record.num = num; + } + if let Some(scaled) = selection.scaled() { + record.scaled = scaled; + } + + // Add the template signature and record to the collection + collection.add_template_sig_from_record(&record); + } + + Ok(collection) + } + pub fn add_template_sig_from_record(&mut self, record: &BuildRecord) { // Adjust ksize for protein, dayhoff, or hp, which require tripling the k-mer size. let adjusted_ksize = match record.moltype.as_str() { @@ -1375,4 +1427,121 @@ mod tests { assert_eq!(added_dayhoff_record.ksize, 10); assert_eq!(added_dayhoff_record.with_abundance, true); } + + #[test] + fn test_from_selection_dna_with_defaults() { + // Create a selection with DNA moltype and default parameters + let selection = Selection::builder() + .moltype(HashFunctions::Murmur64Dna) + .build(); + + // Call from_selection + let build_collection = BuildCollection::from_selection(&selection) + .expect("Failed to create BuildCollection from selection"); + + // Validate that the collection is not empty + assert!( + !build_collection.is_empty(), + "BuildCollection should not be empty" + ); + + // Validate that the manifest contains the correct record + assert_eq!( + build_collection.manifest.size(), + 1, + "Expected one record in the manifest" + ); + + let record = &build_collection.manifest.records[0]; + assert_eq!(record.moltype, "dna", "Expected moltype to be 'dna'"); + assert_eq!(record.ksize, 21, "Expected default ksize to be 21"); + assert!( + !record.with_abundance, + "Expected default abundance to be false" + ); + } + + #[test] + fn test_from_selection_with_custom_parameters() { + // Create a selection with custom parameters + let selection = Selection::builder() + .moltype(HashFunctions::Murmur64Protein) + .ksize(31) + .abund(true) + .scaled(1000) + .build(); + + // Call from_selection + let build_collection = BuildCollection::from_selection(&selection) + .expect("Failed to create BuildCollection from selection"); + + // Validate that the collection is not empty + assert!( + !build_collection.is_empty(), + "BuildCollection should not be empty" + ); + + // Validate that the manifest contains the correct record + assert_eq!( + build_collection.manifest.size(), + 1, + "Expected one record in the manifest" + ); + + let record = &build_collection.manifest.records[0]; + assert_eq!( + record.moltype, "protein", + "Expected moltype to be 'protein'" + ); + assert_eq!(record.ksize, 31, "Expected ksize to be 31"); + assert!(record.with_abundance, "Expected abundance to be true"); + assert_eq!(record.scaled, 1000, "Expected scaled to be 1000"); + } + + #[test] + fn test_from_selection_multiple_ksizes() { + // Create a selection with multiple ksizes + let selection = Selection::builder() + .moltype(HashFunctions::Murmur64Dayhoff) + .ksize(21) // Simulate multiple ksizes by changing test logic + .build(); + + // Call from_selection + let build_collection = BuildCollection::from_selection(&selection) + .expect("Failed to create BuildCollection from selection"); + + // Validate that the collection contains the correct number of records + assert!( + !build_collection.is_empty(), + "BuildCollection should not be empty" + ); + + assert_eq!( + build_collection.manifest.size(), + 1, + "Expected one record in the manifest" + ); + + let record = &build_collection.manifest.records[0]; + assert_eq!( + record.moltype, "dayhoff", + "Expected moltype to be 'dayhoff'" + ); + assert_eq!(record.ksize, 21, "Expected ksize to be 21"); + } + + #[test] + fn test_from_selection_missing_moltype() { + // Create a selection without a moltype + let selection = Selection::builder().ksize(31).build(); + + // Call from_selection and expect an error + let result = BuildCollection::from_selection(&selection); + assert!(result.is_err(), "Expected an error due to missing moltype"); + assert_eq!( + result.unwrap_err(), + "Moltype must be specified in selection", + "Unexpected error message" + ); + } }