From a76dfc24f6999a897589160d7b3be8426286faad Mon Sep 17 00:00:00 2001 From: JeanMainguy Date: Mon, 4 Dec 2023 23:36:27 +0100 Subject: [PATCH] add test and refactor main function --- binette/bin_manager.py | 92 ++-------- binette/bin_quality.py | 52 +----- binette/binette.py | 271 ++++++++++++++++------------ binette/cds.py | 8 +- binette/io_manager.py | 27 ++- tests/bin_manager_test.py | 359 ++++++++++++++++++++++++++++++++++++- tests/bin_quality_test.py | 315 +++++++++++++++++++++++++++++++- tests/cds_test.py | 8 +- tests/diamonds_test.py | 12 +- tests/io_manager_test.py | 32 ++++ tests/main_binette_test.py | 319 ++++++++++++++++++++++++++++++++ 11 files changed, 1240 insertions(+), 255 deletions(-) create mode 100644 tests/main_binette_test.py diff --git a/binette/bin_manager.py b/binette/bin_manager.py index 9f75848..c35b9c2 100644 --- a/binette/bin_manager.py +++ b/binette/bin_manager.py @@ -81,23 +81,30 @@ def __and__(self, other: 'Bin') -> 'Bin': return Bin(contigs, origin, name) - def add_length(self, length: float) -> None: + def add_length(self, length: int) -> None: """ - Add the length attribute to the Bin object. + Add the length attribute to the Bin object if the provided length is a positive integer. :param length: The length value to add. :return: None """ - self.length = length + if isinstance(length, int) and length > 0: + self.length = length + else: + raise ValueError("Length should be a positive integer.") - def add_N50(self, n50: float) -> None: + def add_N50(self, n50: int) -> None: """ Add the N50 attribute to the Bin object. :param n50: The N50 value to add. 
:return: None """ - self.N50 = n50 + if isinstance(n50, int) and n50 >= 0: + self.N50 = n50 + else: + raise ValueError("N50 should be a positive integer.") + def add_quality(self, completeness: float, contamination: float, contamination_weight: float) -> None: """ @@ -106,7 +113,6 @@ def add_quality(self, completeness: float, contamination: float, contamination_w :param completeness: The completeness value. :param contamination: The contamination value. :param contamination_weight: The weight assigned to contamination in the score calculation. - :return: None """ self.completeness = completeness self.contamination = contamination @@ -260,50 +266,6 @@ def from_bin_sets_to_bin_graph(bin_name_to_bin_set: Dict[str, set]) -> nx.Graph: return G -def get_bin_graph(bins: List[Bin]) -> nx.Graph: - """ - Creates a bin graph from a list of Bin objects. - - :param bins: A list of Bin objects representing bins. - - :return: A networkx Graph representing the bin graph of overlapping bins. - """ - G = nx.Graph() - G.add_nodes_from((b.id for b in bins)) - - for i, (bin1, bin2) in enumerate(itertools.combinations(bins, 2)): - - if bin1.overlaps_with(bin2): - # logging.info(f"{bin1} overlaps with {bin2}") - G.add_edge( - bin1.id, - bin2.id, - ) - return G - - -def get_bin_graph_with_attributes(bins: List[Bin], contig_to_length: Dict[str, int]) -> nx.Graph: - """ - Creates a graph from a list of Bin objects with additional attributes. - - :param bins: A list of Bin objects representing bins. - :param contig_to_length: A dictionary mapping contig names to their lengths. - - :return: A networkx Graph representing the bin graph with attributes. 
- """ - G = nx.Graph() - G.add_nodes_from((b.id for b in bins)) - - for i, (bin1, bin2) in enumerate(itertools.combinations(bins, 2)): - if bin1.overlaps_with(bin2): - - contigs = bin1.contigs & bin2.contigs - shared_length = sum((contig_to_length[c] for c in contigs)) - max_shared_length_prct = 100 - 100 * (shared_length / min((bin1.length, bin2.length))) - - G.add_edge(bin1.id, bin2.id, weight=max_shared_length_prct) - return G - def get_all_possible_combinations(clique: Iterable) -> Iterable[Tuple]: """ @@ -353,8 +315,7 @@ def get_difference_bins(G: nx.Graph) -> Set[Bin]: difference_bins = set() for clique in nx.clique.find_cliques(G): - # TODO should not use combinations but another method of itertools - # to get all possible combination in all possible order. + bins_combinations = get_all_possible_combinations(clique) for bins in bins_combinations: @@ -376,7 +337,7 @@ def get_union_bins(G: nx.Graph, max_conta: int = 50) -> Set[Bin]: """ Retrieves the union bins from a given graph. - :param G: A networkx Graph representing the graph. + :param G: A networkx Graph representing the graph of bins. :param max_conta: Maximum allowed contamination value for a bin to be included in the union. :return: A set of Bin objects representing the union bins. @@ -400,31 +361,6 @@ def get_union_bins(G: nx.Graph, max_conta: int = 50) -> Set[Bin]: return union_bins -def create_intersec_diff_bins(G: nx.Graph) -> Set[Bin]: - """ - Creates intersection and difference bins from a given graph. - - :param G: A networkx Graph representing the graph. - - :return: A set of Bin objects representing the intersection and difference bins. 
- """ - new_bins = set() - - for clique in nx.clique.find_cliques(G): - bins_combinations = get_all_possible_combinations(clique) - for bins in bins_combinations: - - # intersection - intersec_bin = bins[0].intersection(*bins[1:]) - new_bins.add(intersec_bin) - - # difference - for bin_a in bins: - bin_diff = bin_a.difference(*(b for b in bins if b != bin_a)) - new_bins.add(bin_diff) - - return new_bins - def select_best_bins(bins: List[Bin]) -> List[Bin]: """ Selects the best bins from a list of bins based on their scores, N50 values, and IDs. diff --git a/binette/bin_quality.py b/binette/bin_quality.py index 3fb0c20..b9303cf 100644 --- a/binette/bin_quality.py +++ b/binette/bin_quality.py @@ -135,51 +135,6 @@ def add_bin_size_and_N50(bins: Iterable[Bin], contig_to_size: Dict[str,int]): bin_obj.add_N50(n50) - -def get_bin_size_and_N50(bin_obj, contig_to_size: Dict[str, int]): - """ - Calculate and add bin size and N50 to a bin object. - - :param bin_obj: The bin object to calculate size and N50 for. - :type bin_obj: Any - :param contig_to_size: Dictionary mapping contig names to their sizes. - :type contig_to_size: Dict[str, int] - """ - lengths = [contig_to_size[c] for c in bin_obj.contigs] - n50 = compute_N50(lengths) - - bin_obj.add_length(sum(lengths)) - bin_obj.add_N50(n50) - - -def add_bin_metrics_in_parallel(bins: List, contig_info: Dict, threads: int, contamination_weight: float): - """ - Add bin metrics in parallel for a list of bins. - - :param bins: List of bin objects. - :type bins: List - :param contig_info: Dictionary containing contig information. - :type contig_info: Dict - :param threads: Number of threads to use for parallel processing. - :type threads: int - :param contamination_weight: Weight for contamination assessment. - :type contamination_weight: float - :return: Set of processed bin objects. 
- :rtype: Set - """ - chunk_size = int(len(bins) / threads) + 1 - print("CHUNK SIZE TO PARALLELIZE", chunk_size) - results = [] - with cf.ProcessPoolExecutor(max_workers=threads) as tpe: - for i, bins_chunk in enumerate(chunks(bins, chunk_size)): - print(f"chunk {i}, {len(bins_chunk)} bins") - results.append(tpe.submit(add_bin_metrics, *(bins_chunk, contig_info, contamination_weight))) - - processed_bins = {bin_o for r in results for bin_o in r.result()} - return processed_bins - - - def add_bin_metrics(bins: List, contig_info: Dict, contamination_weight: float, threads: int = 1): """ Add metrics to a list of bins. @@ -235,7 +190,8 @@ def assess_bins_quality_by_chunk(bins: List, contig_to_aa_length: Dict, contamination_weight: float, postProcessor:modelPostprocessing.modelProcessor = None, - threads: int = 1,): + threads: int = 1, + chunk_size: int = 2500): """ Assess the quality of bins in chunks. @@ -249,10 +205,10 @@ def assess_bins_quality_by_chunk(bins: List, :param contamination_weight: Weight for contamination assessment. :param postProcessor: post-processor from checkm2 :param threads: Number of threads for parallel processing (default is 1). + :param chunk_size: The size of each chunk. 
""" - n = 2500 - for i, chunk_bins_iter in enumerate(chunks(bins, n)): + for i, chunk_bins_iter in enumerate(chunks(bins, chunk_size)): chunk_bins = set(chunk_bins_iter) logging.debug(f"chunk {i}: assessing quality of {len(chunk_bins)}") assess_bins_quality( diff --git a/binette/binette.py b/binette/binette.py index 776b3ff..013bcf4 100755 --- a/binette/binette.py +++ b/binette/binette.py @@ -16,6 +16,7 @@ import pkg_resources from binette import contig_manager, cds, diamond, bin_quality, bin_manager, io_manager as io +from typing import List, Dict, Set, Tuple def init_logging(verbose, debug): @@ -39,7 +40,7 @@ def init_logging(verbose, debug): ) -def parse_arguments(): +def parse_arguments(args): """Parse script arguments.""" program_version = pkg_resources.get_distribution("Binette").version @@ -112,60 +113,24 @@ def parse_arguments(): parser.add_argument("--version", action="version", version=program_version) - args = parser.parse_args() + args = parser.parse_args(args) return args -def main(): - "Orchestrate the execution of the program" - - args = parse_arguments() - - init_logging(args.verbose, args.debug) - - # Setup input parameters # - - bin_dirs = args.bin_dirs - contig2bin_tables = args.contig2bin_tables - contigs_fasta = args.contigs - threads = args.threads - outdir = args.outdir - low_mem = args.low_mem - contamination_weight = args.contamination_weight - - min_completeness = args.min_completeness +def parse_input_files(bin_dirs: List[str], contig2bin_tables: List[str], contigs_fasta: str) -> Tuple[Dict[str, List], List, Dict[str, List], Dict[str, int]]: + """ + Parses input files to retrieve information related to bins and contigs. - # High quality threshold used just to log number of high quality bins. - hq_max_conta = 5 - hq_min_completeness = 90 - - # Temporary files # - out_tmp_dir = os.path.join(outdir, "temporary_files") - os.makedirs(out_tmp_dir, exist_ok=True) + :param bin_dirs: List of paths to directories containing bin FASTA files. 
+ :param contig2bin_tables: List of paths to contig-to-bin tables. + :param contigs_fasta: Path to the contigs FASTA file. - faa_file = os.path.join(out_tmp_dir, "assembly_proteins.faa") - diamond_result_file = os.path.join(out_tmp_dir, "diamond_result.tsv") - diamond_log = os.path.join(out_tmp_dir, "diamond_run.log") - - # Output files # - outdir_final_bin_set = os.path.join(outdir, "final_bins") - os.makedirs(outdir_final_bin_set, exist_ok=True) - - final_bin_report = os.path.join(outdir, "final_bins_quality_reports.tsv") - - # Flag parameters - resume = args.resume - debug = args.debug - - if resume and not os.path.isfile(faa_file): - logging.error(f"Protein file {faa_file} does not exist. Resuming is not possible") - exit(1) - - if resume and not os.path.isfile(diamond_result_file): - logging.error(f"Diamond result file {diamond_result_file} does not exist. Resuming is not possible") - exit(1) - - # Loading input bin sets + :return: A tuple containing: + - Dictionary mapping bin set names to lists of bins. + - List of original bins. + - Dictionary mapping bins to lists of contigs. + - Dictionary mapping contig names to their lengths. 
+ """ if bin_dirs: logging.info("Parsing bin directories.") @@ -183,10 +148,32 @@ def main(): original_bins = bin_manager.dereplicate_bin_sets(bin_set_name_to_bins.values()) contigs_in_bins = bin_manager.get_contigs_in_bins(original_bins) - logging.info("Parsing contig fasta file: {contigs_fasta}") + logging.info(f"Parsing contig fasta file: {contigs_fasta}") contigs_object = contig_manager.parse_fasta_file(contigs_fasta) contig_to_length = {seq.name: len(seq) for seq in contigs_object if seq.name in contigs_in_bins} + return bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length + + +def manage_protein_alignement(faa_file: str, contigs_fasta: str, contig_to_length: Dict[str, List], + contigs_in_bins: Dict[str, List], diamond_result_file: str, + checkm2_db: str, threads: int, resume: bool, low_mem: bool) -> Tuple[Dict[str, int], Dict[str, int]]: + """ + Predicts or reuses proteins prediction and runs diamond on them. + + :param faa_file: The path to the .faa file. + :param contigs_fasta: The path to the contigs FASTA file. + :param contig_to_length: Dictionary mapping contig names to their lengths. + :param contigs_in_bins: Dictionary mapping bin names to lists of contigs. + :param diamond_result_file: The path to the diamond result file. + :param checkm2_db: The path to the CheckM2 database. + :param threads: Number of threads for parallel processing. + :param resume: Boolean indicating whether to resume the process. + :param low_mem: Boolean indicating whether to use low memory mode. + + :return: A tuple containing dictionaries - contig_to_kegg_counter and contig_to_genes. 
+ """ + # Predict or reuse proteins prediction and run diamond on them if resume: logging.info(f"Parsing faa file: {faa_file}.") @@ -197,12 +184,14 @@ def main(): contigs_iterator = (s for s in contig_manager.parse_fasta_file(contigs_fasta) if s.name in contigs_in_bins) contig_to_genes = cds.predict(contigs_iterator, faa_file, threads) - if not args.checkm2_db: + if checkm2_db: + diamond_db_path = checkm2_db + else: # get checkm2 db stored in checkm2 install diamond_db_path = diamond.get_checkm2_db() - else: - diamond_db_path = args.checkm2_db - + + diamond_log = f"{os.path.splitext(diamond_result_file)[0]}.log" + diamond.run( faa_file, diamond_result_file, @@ -214,47 +203,30 @@ def main(): logging.info("Parsing diamond results.") contig_to_kegg_counter = diamond.get_contig_to_kegg_id(diamond_result_file) + # Check contigs from diamond vs input assembly consistency io.check_contig_consistency(contig_to_length, contig_to_kegg_counter, contigs_fasta, diamond_result_file) - # Use contig index instead of contig name to save memory - contig_to_index, index_to_contig = contig_manager.make_contig_index(contigs_in_bins) - - contig_to_kegg_counter = contig_manager.apply_contig_index(contig_to_index, contig_to_kegg_counter) - contig_to_genes = contig_manager.apply_contig_index(contig_to_index, contig_to_genes) - contig_to_length = contig_manager.apply_contig_index(contig_to_index, contig_to_length) - - bin_manager.rename_bin_contigs(original_bins, contig_to_index) - - # Extract cds metadata ## - - logging.info("Compute cds metadata.") - ( - contig_to_cds_count, - contig_to_aa_counter, - contig_to_aa_length, - ) = cds.get_contig_cds_metadata(contig_to_genes, threads) - - contig_info = { - "contig_to_cds_count": contig_to_cds_count, - "contig_to_aa_counter": contig_to_aa_counter, - "contig_to_aa_length": contig_to_aa_length, - "contig_to_kegg_counter": contig_to_kegg_counter, - "contig_to_length": contig_to_length, - } - - logging.info("Add size and assess quality of input 
bins") + return contig_to_kegg_counter, contig_to_genes - bin_quality.add_bin_metrics(original_bins, contig_info, contamination_weight, threads) - logging.info("Create intermediate bins:") - new_bins = bin_manager.create_intermediate_bins(bin_set_name_to_bins) +def select_bins_and_write_them(all_bins: Set[bin_manager.Bin], contigs_fasta: str, final_bin_report: str, min_completeness: float, + index_to_contig: dict, outdir: str, debug: bool) -> List[bin_manager.Bin]: + """ + Selects and writes bins based on specific criteria. - logging.info("Assess quality for supplementary intermediate bins.") - new_bins = bin_quality.add_bin_metrics(new_bins, contig_info, contamination_weight, threads) + :param all_bins: Set of Bin objects. + :param contigs_fasta: Path to the contigs FASTA file. + :param final_bin_report: Path to write the final bin report. + :param min_completeness: Minimum completeness threshold for bin selection. + :param index_to_contig: Dictionary mapping indices to contig names. + :param outdir: Output directory to save final bins and reports. + :param debug: Debug mode flag. + :return: Selected bins that meet the completeness threshold. 
+ """ - logging.info("Dereplicating input bins and new bins") - all_bins = original_bins | new_bins + outdir_final_bin_set = os.path.join(outdir, "final_bins") + os.makedirs(outdir_final_bin_set, exist_ok=True) if debug: all_bins_for_debug = set(all_bins) @@ -267,9 +239,11 @@ def main(): with open(os.path.join(outdir, "index_to_contig.tsv"), 'w') as flout: flout.write('\n'.join((f'{i}\t{c}' for i, c in index_to_contig.items()))) - logging.info("Select best bins") + logging.info("Selecting best bins") selected_bins = bin_manager.select_best_bins(all_bins) - + + logging.info(f"Bin Selection: {len(selected_bins)} selected bins") + logging.info(f"Filtering bins: only bins with completeness >= {min_completeness} are kept") selected_bins = [b for b in selected_bins if b.completeness >= min_completeness] @@ -284,24 +258,101 @@ def main(): io.write_bins_fasta(selected_bins, contigs_fasta, outdir_final_bin_set) - if debug: - for sb in selected_bins: - if sb.completeness >= hq_min_completeness and sb.contamination <= hq_max_conta: - logging.debug(f"{sb}, {sb.completeness}, {sb.contamination}") + return selected_bins - hq_bins = len( - [sb for sb in selected_bins if sb.completeness >= hq_min_completeness and sb.contamination <= hq_max_conta] - ) - hq_bins_single = len( - [ - sb - for sb in selected_bins - if sb.completeness >= hq_min_completeness and sb.contamination <= hq_max_conta and len(sb.contigs) == 1 - ] - ) - tresholds = f"(completeness >= {hq_min_completeness} and contamination <= {hq_max_conta})" - logging.info(f"{hq_bins}/{len(selected_bins)} selected bins have a high quality {tresholds}.") - logging.info( - f"{hq_bins_single}/{len(selected_bins)} selected bins have a high quality and are made of only one contig." - ) + +def log_selected_bin_info(selected_bins: List[bin_manager.Bin], hq_min_completeness: float, hq_max_conta: float): + """ + Log information about selected bins based on quality thresholds. + + :param selected_bins: List of Bin objects to analyze. 
+ :param hq_min_completeness: Minimum completeness threshold for high-quality bins. + :param hq_max_conta: Maximum contamination threshold for high-quality bins. + + This function logs information about selected bins that meet specified quality thresholds. + It counts the number of high-quality bins based on completeness and contamination values. + """ + + # Log completeness and contamination in debug log + logging.debug("High quality bins:") + for sb in selected_bins: + if sb.completeness >= hq_min_completeness and sb.contamination <= hq_max_conta: + logging.debug(f"> {sb} completeness={sb.completeness}, contamination={sb.contamination}") + + # Count high-quality bins and single-contig high-quality bins + hq_bins = len([sb for sb in selected_bins if sb.completeness >= hq_min_completeness and sb.contamination <= hq_max_conta]) + + # Log information about high-quality bins + thresholds = f"(completeness >= {hq_min_completeness} and contamination <= {hq_max_conta})" + logging.info(f"{hq_bins}/{len(selected_bins)} selected bins have a high quality {thresholds}.") + + +def main(): + "Orchestrate the execution of the program" + + args = parse_arguments(sys.argv[1:]) # sys.argv is passed in order to be able to test the function parse_arguments + + init_logging(args.verbose, args.debug) + + # High quality threshold used just to log number of high quality bins. 
+ hq_max_conta = 5 + hq_min_completeness = 90 + + # Temporary files # + out_tmp_dir = os.path.join(args.outdir, "temporary_files") + os.makedirs(out_tmp_dir, exist_ok=True) + + faa_file = os.path.join(out_tmp_dir, "assembly_proteins.faa") + diamond_result_file = os.path.join(out_tmp_dir, "diamond_result.tsv") + + # Output files # + final_bin_report = os.path.join(args.outdir, "final_bins_quality_reports.tsv") + + + if args.resume: + io.check_resume_file(faa_file, diamond_result_file) + + bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length = parse_input_files(args.bin_dirs, args.contig2bin_tables, args.contigs) + + contig_to_kegg_counter, contig_to_genes = manage_protein_alignement(faa_file=faa_file, contigs_fasta=args.contigs, contig_to_length=contig_to_length, + contigs_in_bins=contigs_in_bins, + diamond_result_file=diamond_result_file, checkm2_db=args.checkm2_db, + threads=args.threads, resume=args.resume, low_mem=args.low_mem) + + # Use contig index instead of contig name to save memory + contig_to_index, index_to_contig = contig_manager.make_contig_index(contigs_in_bins) + + contig_to_kegg_counter = contig_manager.apply_contig_index(contig_to_index, contig_to_kegg_counter) + contig_to_genes = contig_manager.apply_contig_index(contig_to_index, contig_to_genes) + contig_to_length = contig_manager.apply_contig_index(contig_to_index, contig_to_length) + + bin_manager.rename_bin_contigs(original_bins, contig_to_index) + + + # Extract cds metadata ## + logging.info("Compute cds metadata.") + contig_metadat = cds.get_contig_cds_metadata(contig_to_genes, args.threads) + + contig_metadat["contig_to_kegg_counter"] = contig_to_kegg_counter + contig_metadat["contig_to_length"] = contig_to_length + + + logging.info("Add size and assess quality of input bins") + bin_quality.add_bin_metrics(original_bins, contig_metadat, args.contamination_weight, args.threads) + + logging.info("Create intermediate bins:") + new_bins = 
bin_manager.create_intermediate_bins(bin_set_name_to_bins) + + logging.info("Assess quality for supplementary intermediate bins.") + new_bins = bin_quality.add_bin_metrics(new_bins, contig_metadat, args.contamination_weight, args.threads) + + + logging.info("Dereplicating input bins and new bins") + all_bins = original_bins | new_bins + + selected_bins = select_bins_and_write_them(all_bins, args.contigs, final_bin_report, args.min_completeness, index_to_contig, args.outdir, args.debug) + + log_selected_bin_info(selected_bins, hq_min_completeness, hq_max_conta) + + return 0 \ No newline at end of file diff --git a/binette/cds.py b/binette/cds.py index 06284e3..cebb76e 100644 --- a/binette/cds.py +++ b/binette/cds.py @@ -137,4 +137,10 @@ def get_contig_cds_metadata(contig_to_genes: Dict[str, List[str]], threads: int) contig_to_aa_length = {contig: sum(counter.values()) for contig, counter in tqdm(contig_to_aa_counter.items(), unit="contig")} logging.info("Calculating total amino acid length in parallel.") - return contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length + contig_info = { + "contig_to_cds_count": contig_to_cds_count, + "contig_to_aa_counter": contig_to_aa_counter, + "contig_to_aa_length": contig_to_aa_length, + } + + return contig_info \ No newline at end of file diff --git a/binette/io_manager.py b/binette/io_manager.py index 0267852..467bcc3 100644 --- a/binette/io_manager.py +++ b/binette/io_manager.py @@ -108,4 +108,29 @@ def check_contig_consistency(contigs_from_assembly: List[str], message = f"{issue_countigs} contigs found in file {elsewhere_file} \ were not found in assembly_file ({assembly_file})." - assert are_contigs_consistent, message \ No newline at end of file + assert are_contigs_consistent, message + + +def check_resume_file(faa_file: str, diamond_result_file: str) -> None: + """ + Check the existence of files required for resuming the process. + + :param faa_file: Path to the protein file. 
+ :param diamond_result_file: Path to the Diamond result file. + :raises FileNotFoundError: If the required files don't exist for resuming. + """ + + if os.path.isfile(faa_file) and os.path.isfile(diamond_result_file): + return + + if not os.path.isfile(faa_file): + error_msg = f"Protein file '{faa_file}' does not exist. Resuming is not possible." + logging.error(error_msg) + raise FileNotFoundError(error_msg) + + if not os.path.isfile(diamond_result_file): + error_msg = f"Diamond result file '{diamond_result_file}' does not exist. Resuming is not possible." + logging.error(error_msg) + raise FileNotFoundError(error_msg) + + diff --git a/tests/bin_manager_test.py b/tests/bin_manager_test.py index 93683bf..906ab7c 100644 --- a/tests/bin_manager_test.py +++ b/tests/bin_manager_test.py @@ -5,8 +5,8 @@ import pytest -from binette import bin_manager - +from binette import bin_manager, binette +import networkx as nx def test_get_all_possible_combinations(): input_list = ["2", "3", "4"] @@ -21,7 +21,11 @@ def example_bin_set1(): bin2 = bin_manager.Bin(contigs={"3", "4"}, origin="test1", name="bin2") bin3 = bin_manager.Bin(contigs={"5"}, origin="test1", name="bin2") return {bin1, bin2, bin3} - +@pytest.fixture +def example_bin_set2(): + bin1 = bin_manager.Bin(contigs={"1", "2", "3"}, origin="test2", name="binA") + bin2 = bin_manager.Bin(contigs={"4", "5"}, origin="test2", name="binB") + return {bin1, bin2} def test_bin_eq_true(): bin1 = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") @@ -48,6 +52,49 @@ def test_in_for_bin_list(): assert bin2 in bins assert bin3 not in bins +def test_add_length_positive_integer(): + bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") + length = 100 + bin_obj.add_length(length) + assert bin_obj.length == length + +def test_add_length_negative_integer(): + bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") + with pytest.raises(ValueError): + length = -50 + 
bin_obj.add_length(length) + +def test_add_n50_positive_integer(): + bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") + n50 = 100 + bin_obj.add_N50(n50) + assert bin_obj.N50 == n50 + +def test_add_n50_negative_integer(): + bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") + with pytest.raises(ValueError): + n50 = -50 + bin_obj.add_N50(n50) + +def test_add_quality(): + + completeness = 10 + contamination = 6 + contamination_weight = 2 + + bin_obj = bin_manager.Bin(contigs={"1", "2", "e"}, origin="test1", name="bin1") + + bin_obj.add_quality(completeness, contamination, contamination_weight) + + assert bin_obj.completeness == completeness + assert bin_obj.contamination == contamination + + assert bin_obj.score == completeness - contamination * contamination_weight + + + + + def test_two_bin_intersection(): bin1 = bin_manager.Bin(contigs={"1", "2", "e", "987"}, origin="test1", name="bin1") @@ -196,3 +243,309 @@ def test_select_best_bins_with_equality(): # when score and n50 is the same, selection is made on the smallest id. # bin created first have a smaller id. so b1 should selected assert bin_manager.select_best_bins({b1, b2, b3}) == [b1, b3] + + +# The function should create intersection bins when there are overlapping contigs between bins. 
+def test_intersection_bins_created(): + set1 = [ + bin_manager.Bin(contigs={"1", "2"}, origin="A", name="bin1"), + bin_manager.Bin(contigs={"3", "4"}, origin="A", name="bin2"), + bin_manager.Bin(contigs={"5"}, origin="A", name="bin2"), + ] + # need to defined completeness and conta + # because when too low the bin is not used in all operation + for b in set1: + b.completeness = 100 + b.contamination = 0 + + binA = bin_manager.Bin(contigs={"1", "3"}, origin="B", name="binA") + binA.contamination = 0 + binA.completeness = 100 + set2 = [ + binA, + ] + bin_set_name_to_bins = {'set1': set1, + "set2":set2} + + intermediate_bins_result = bin_manager.create_intermediate_bins(bin_set_name_to_bins) + + expected_intermediate_bins = {bin_manager.Bin(contigs={"1", "2", "3"}, origin="bin1 | binA ", name="NA"), + bin_manager.Bin(contigs={"2"}, origin="bin1 - binA ", name="NA"), + bin_manager.Bin(contigs={"1"}, origin="bin1 & binA ", name="NA"), + bin_manager.Bin(contigs={"1", "4", "3"}, origin="bin2 | binA ", name="NA"), + bin_manager.Bin(contigs={"4"}, origin="bin2 - binA ", name="NA"), # binA -bin1 is equal to bin1 & binA + bin_manager.Bin(contigs={"3"}, origin="bin1 & binA ", name="NA"), + } + + assert intermediate_bins_result == expected_intermediate_bins + + +# Renames contigs in bins based on provided mapping. 
+def test_renames_contigs(example_bin_set1): + + bin_set = [ + bin_manager.Bin(contigs={"c1", "c2"}, origin="A", name="bin1"), + bin_manager.Bin(contigs={"c3", "c4"}, origin="A", name="bin2") + ] + + contig_to_index = {'c1': 1, 'c2': 2, 'c3': 3, 'c4': 4, "c5":5} + + # Act + bin_manager.rename_bin_contigs(bin_set, contig_to_index) + + # Assert + assert bin_set[0].contigs == {1, 2} + assert bin_set[0].hash == hash(str(sorted({1, 2}))) + assert bin_set[1].contigs == {3, 4} + assert bin_set[1].hash == hash(str(sorted({3, 4}))) + +def test_get_contigs_in_bins(): + bin_set = [ + bin_manager.Bin(contigs={"c1", "c2"}, origin="A", name="bin1"), + bin_manager.Bin(contigs={"c3", "c4"}, origin="A", name="bin2"), + bin_manager.Bin(contigs={"c3", "c18"}, origin="A", name="bin2") + ] + + contigs = bin_manager.get_contigs_in_bins(bin_set) + + assert contigs == {"c1", "c2", "c3", "c4", "c18"} + + +def test_dereplicate_bin_sets(): + b1 = bin_manager.Bin(contigs={"c1", "c2"}, origin="A", name="bin1") + b2 = bin_manager.Bin(contigs={"c3", "c4"}, origin="A", name="bin2") + b3 = bin_manager.Bin(contigs={"c3", "c18"}, origin="A", name="bin2") + + bdup = bin_manager.Bin(contigs={"c3", "c18"}, origin="D", name="C") + + derep_bins_result = bin_manager.dereplicate_bin_sets([[b2,b3 ],[b1, bdup]]) + + + assert derep_bins_result == {b1, b2, b3} + + + +def test_from_bin_sets_to_bin_graph(): + + + bin1 = bin_manager.Bin(contigs={"1", "2"}, origin="A", name="bin1") + bin2 = bin_manager.Bin(contigs={"3", "4"}, origin="A", name="bin2") + bin3 = bin_manager.Bin(contigs={"5"}, origin="A", name="bin3") + + set1 = [bin1,bin2,bin3] + + binA = bin_manager.Bin(contigs={"1", "3"}, origin="B", name="binA") + + set2 = [binA] + + result_graph = bin_manager.from_bin_sets_to_bin_graph({"B":set2, "A":set1}) + + assert result_graph.number_of_edges() == 2 + # bin3 is not connected to any bin so it is not in the graph + assert result_graph.number_of_nodes() == 3 + + assert set(result_graph.nodes) == {binA, bin1, 
bin2} + +@pytest.fixture +def simple_bin_graph(): + + bin1 = bin_manager.Bin(contigs={"1", "2", "3"}, origin="A", name="bin1") + bin2 = bin_manager.Bin(contigs={"1", "2", "4"}, origin="B", name="bin2") + + for b in [bin1, bin2]: + b.completeness = 100 + b.contamination = 0 + + G = nx.Graph() + G.add_edge(bin1, bin2) + + return G + + +def test_get_intersection_bins(simple_bin_graph): + + intersec_bins = bin_manager.get_intersection_bins(simple_bin_graph) + + assert len(intersec_bins) == 1 + intersec_bin = intersec_bins.pop() + + assert intersec_bin.contigs == {"1", "2"} + +def test_get_difference_bins(simple_bin_graph): + + difference_bins = bin_manager.get_difference_bins(simple_bin_graph) + + expected_bin1 = bin_manager.Bin(contigs={"3"}, origin="D", name="1") + expected_bin2 = bin_manager.Bin(contigs={"4"}, origin="D", name="2") + + assert len(difference_bins) == 2 + assert difference_bins == {expected_bin1,expected_bin2} + + +def test_get_union_bins(simple_bin_graph): + + u_bins = bin_manager.get_union_bins(simple_bin_graph) + + expected_bin1 = bin_manager.Bin(contigs={"1", "2", "3", "4"}, origin="U", name="1") + + assert len(u_bins) == 1 + assert u_bins == {expected_bin1} + + + + +def test_get_bins_from_contig2bin_table(tmp_path): + # Create a temporary file (contig-to-bin table) for testing + test_table_content = [ + "# Sample contig-to-bin table", + "contig1\tbin1", + "contig2\tbin1", + "contig3\tbin2", + ] + test_table_path = tmp_path / "test_contig2bin_table.txt" + test_table_path.write_text("\n".join(test_table_content)) + + # Define set name for the bins + set_name = "TestSet" + + # Call the function to generate Bin objects + result_bins = bin_manager.get_bins_from_contig2bin_table(str(test_table_path), set_name) + + # Validate the result + assert len(result_bins) == 2 # Check if the correct number of bins are created + + # Define expected bins based on the test table content + expected_bins = [ + bin_manager.Bin(contigs={"contig1", "contig2"}, 
origin="A", name="bin1"), + bin_manager.Bin(contigs={"contig3"}, origin="A", name="bin2") + ] + + # Compare expected bins with the result + assert all(expected_bin in result_bins for expected_bin in expected_bins) + assert all(result_bin in expected_bins for result_bin in result_bins) + + +def test_parse_contig2bin_tables(tmp_path): + # Create temporary contig-to-bin tables for testing + test_tables = { + "set1": [ + "# Sample contig-to-bin table for bin1", + "contig1\tbin1", + "contig2\tbin1", + "contig3\tbin2", + ], + "set2": [ + "# Sample contig-to-bin table for bin2", + "contig3\tbinA", + "contig4\tbinA" + ] + } + + # Create temporary files for contig-to-bin tables + for name, content in test_tables.items(): + table_path = tmp_path / f"test_{name}_contig2bin_table.txt" + table_path.write_text("\n".join(content)) + + # Call the function to parse contig-to-bin tables + result_bin_dict = bin_manager.parse_contig2bin_tables({name: str(tmp_path / f"test_{name}_contig2bin_table.txt") for name in test_tables}) + + # Validate the result + assert len(result_bin_dict) == len(test_tables) # Check if the number of bins matches the number of tables + + # Define expected Bin objects based on the test tables + expected_bins = { + "set1": [ + bin_manager.Bin(contigs={"contig1", "contig2"}, origin="set1", name="bin1"), + bin_manager.Bin(contigs={"contig3"}, origin="set1", name="bin2"), + ], + "set2": [ + bin_manager.Bin(contigs={"contig3", "contig4"}, origin="set2", name="binA"), + ] + } + + # Compare expected bins with the result + for name, expected in expected_bins.items(): + assert name in result_bin_dict + assert len(result_bin_dict[name]) == len(expected) + for result_bin, expected_bin in zip(result_bin_dict[name], expected): + assert result_bin.contigs == expected_bin.contigs + assert result_bin.name == expected_bin.name + assert result_bin.origin == expected_bin.origin + + + + +@pytest.fixture +def create_temp_bin_files(tmpdir): + # Create temporary bin files + bin_dir 
def test_get_bins_from_directory(create_temp_bin_files):
    """Check the bins built from a directory holding ``bin1.fasta`` and ``bin2.fasta``.

    The bins are compared by name rather than by their position in the returned
    list. The previous assertions expected ``bin2`` at index 0 and ``bin1`` at
    index 1, which presumably mirrored the directory listing order on one
    machine and is not something this test should depend on.
    """
    bin_dir = create_temp_bin_files
    set_name = "TestSet"

    bins = bin_manager.get_bins_from_directory(str(bin_dir), set_name)

    assert len(bins) == 2  # one Bin per fasta file written by the fixture

    # Every returned object is a Bin tagged with the requested set name
    assert all(isinstance(found_bin, bin_manager.Bin) for found_bin in bins)
    assert all(found_bin.origin == set_name for found_bin in bins)

    # Order-independent check of bin names and their contigs
    name_to_contigs = {found_bin.name: found_bin.contigs for found_bin in bins}
    assert name_to_contigs == {
        "bin1.fasta": {"contig1", "contig2"},
        "bin2.fasta": {"contig3", "contig4"},
    }
def test_compute_N50():
    """Check ``bin_quality.compute_N50`` on a few hand-computed length lists."""
    # (contig lengths, expected N50); evaluated in order, first mismatch fails
    cases = [
        ([50], 50),
        ([0], 0),
        ([], 0),
        ([30, 40, 30], 30),
        ([1, 3, 3, 4, 5, 5, 6, 9, 10, 24], 9),
    ]
    for lengths, expected_n50 in cases:
        assert bin_quality.compute_N50(lengths) == expected_n50
class Bin:
    """Minimal stand-in for a bin object, used as test input in this module.

    It only stores what the ``add_*`` methods are given so that tests can
    inspect the values afterwards. Because it is defined after the module
    imports, the name ``Bin`` refers to this class in the tests below rather
    than to the ``Bin`` imported from ``binette.bin_quality``.
    """

    def __init__(self, bin_id, contigs):
        """Store the identifier and contigs; size metrics start at 0."""
        self.id = bin_id
        self.contigs = contigs
        self.length = 0  # placeholder until add_length is called
        self.N50 = 0  # placeholder until add_N50 is called

    def add_length(self, length):
        """Record the total length of the bin."""
        self.length = length

    def add_N50(self, N50):
        """Record the N50 of the bin."""
        self.N50 = N50

    def add_quality(self, comp, cont, weight):
        """Record completeness and contamination and derive the score.

        The score is ``comp - weight * cont``.
        """
        self.completeness = comp
        self.contamination = cont
        self.score = comp - weight * cont
def mock_modelProcessor(thread):
    """Replacement for ``modelPostprocessing.modelProcessor`` in the tests.

    The ``thread`` argument is accepted but ignored; a fixed marker string is
    returned so tests can recognise the mocked object.
    """
    marker = "mock_modelProcessor"
    return marker
def test_assess_bins_quality_by_chunk(monkeypatch):
    """Check how many times assess_bins_quality_by_chunk calls assess_bins_quality.

    ``assess_bins_quality`` is patched, so only the chunking is observed:
    one call when ``chunk_size`` covers all three bins, two calls when
    ``chunk_size`` is 2.
    """
    # Three test bins; all contig metadata dictionaries are left empty because
    # the patched assess_bins_quality never reads them
    bins = [
        Bin(1, ['contig1', 'contig2']),
        Bin(2, ['contig3', 'contig4']),
        Bin(3, ['contig3', 'contig4'])
    ]

    contig_to_kegg_counter = {}
    contig_to_cds_count = {}
    contig_to_aa_counter = {}
    contig_to_aa_length = {}
    contamination_weight = 0.5

    # Replace the checkm2 modelProcessor with the module-level mock
    # (postProcessor is passed explicitly as None in the calls below)

    monkeypatch.setattr(modelPostprocessing, "modelProcessor", mock_modelProcessor)


    # Patch assess_bins_quality, the function assess_bins_quality_by_chunk is
    # expected to call once per chunk
    with patch('binette.bin_quality.assess_bins_quality') as mock_assess_bins_quality:

        assess_bins_quality_by_chunk(
            bins,
            contig_to_kegg_counter,
            contig_to_cds_count,
            contig_to_aa_counter,
            contig_to_aa_length,
            contamination_weight,
            postProcessor=None,
            threads=1,
            chunk_size = 3
        )

    # chunk_size (3) equals the number of bins, so there is a single chunk;
    # the assertion expects that chunk to be passed as a set of the bins
    mock_assess_bins_quality.assert_called_once_with(
        bins= set(bins),
        contig_to_kegg_counter= contig_to_kegg_counter,
        contig_to_cds_count=contig_to_cds_count,
        contig_to_aa_counter=contig_to_aa_counter,
        contig_to_aa_length=contig_to_aa_length,
        contamination_weight=contamination_weight,
        postProcessor=None,
        threads=1
    )

    # Same call with a smaller chunk size, using a fresh patch so the call
    # count starts from zero
    with patch('binette.bin_quality.assess_bins_quality') as mock_assess_bins_quality:

        assess_bins_quality_by_chunk(
            bins,
            contig_to_kegg_counter,
            contig_to_cds_count,
            contig_to_aa_counter,
            contig_to_aa_length,
            contamination_weight,
            postProcessor=None,
            threads=1,
            chunk_size = 2
        )

    # Chunk size < number of bin so 2 chunks with [bin1,bin2] and [bin3]
    assert mock_assess_bins_quality.call_count == 2
@pytest.fixture
def temp_files(tmp_path):
    """Create an empty protein file and an empty diamond result file.

    Yields a ``(faa_file, diamond_result_file)`` tuple of path strings.
    """
    created_paths = []
    for file_name in ("test_protein.faa", "test_diamond_result.txt"):
        empty_file = tmp_path / file_name
        empty_file.touch()
        created_paths.append(str(empty_file))
    yield tuple(created_paths)
@pytest.fixture
def bins():
    """Return three bins from ``set1`` with quality values already attached.

    ``bin2`` and ``bin3`` share ``contig3``; every bin is scored with a
    contamination weight of 0.
    """
    # (name, contigs, completeness, contamination)
    specs = [
        ("bin1", {"contig1"}, 100, 0),
        ("bin2", {"contig3"}, 95, 10),
        ("bin3", {"contig3", "contig2"}, 70, 20),
    ]

    # Build all bins first, then attach their quality, in the listed order
    created = [
        Bin(contigs=contigs, origin="set1", name=name)
        for name, contigs, _, _ in specs
    ]
    for bin_obj, (_, _, completeness, contamination) in zip(created, specs):
        bin_obj.add_quality(completeness, contamination, 0)

    return created
def test_select_bins_and_write_them(tmp_path, tmpdir, bins):
    """Check which bins are selected and which fasta files get written.

    The output directory is created under ``tmpdir`` and the input contigs
    fasta under ``tmp_path``.
    """
    # Output locations
    outdir = tmpdir.mkdir("test_outdir")
    final_bin_report = os.path.join(str(outdir), "final_bin_report.tsv")

    index_to_contig = {"contig1": "contig1", "contig2": "contig2", "contig3": "contig3"}

    # Input assembly; contig4 belongs to no bin
    contigs_fasta = tmp_path / "contigs.fasta"
    contigs_fasta_content = (
        ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n"
    )
    contigs_fasta.write_text(contigs_fasta_content)

    b1, b2, b3 = bins

    # Run the function with test data
    selected_bins = select_bins_and_write_them(
        set(bins),
        str(contigs_fasta),
        final_bin_report,
        min_completeness=60,
        index_to_contig=index_to_contig,
        outdir=str(outdir),
        debug=False,
    )

    # Assertions to check the function output or file existence
    assert isinstance(selected_bins, list)
    assert os.path.isfile(final_bin_report)
    # The third bin is overlapping with the second one and has a worse score so it is not selected.
    assert selected_bins == bins[:2]

    # Each selected bin is written as its own fasta file
    expected_fasta = {b1.id: ">contig1\nACGT\n", b2.id: ">contig3\nAAAA\n"}
    for bin_id, expected_content in expected_fasta.items():
        with open(outdir / f"final_bins/bin_{bin_id}.fa", "r") as bin_file:
            assert bin_file.read() == expected_content

    # The rejected bin must not be written
    assert not os.path.isfile(outdir / f"final_bins/bin_{b3.id}.fa")
def test_parse_input_files_bin_dirs(create_temp_bin_directories, tmp_path):
    """Check parse_input_files when the bin sets are given as directories.

    The fixture provides two bin directories holding three bins in total over
    contig1..contig5; the assembly fasta written here contains those same five
    contigs.
    """
    bin_dirs = list(create_temp_bin_directories.values())

    # No contig2bin table is given: only the directory input path is exercised
    contig2bin_tables = []

    fasta_file = tmp_path / "assembly.fasta"
    fasta_file_content = (
        ">contig1\nACGT\n>contig2\nTGCA\n>contig3\nAAAA\n>contig4\nCCCC\n>contig5\nCGTCGCT\n"
    )
    fasta_file.write_text(fasta_file_content)

    # Call the function and capture the return values
    bin_set_name_to_bins, original_bins, contigs_in_bins, contig_to_length = parse_input_files(
        bin_dirs, contig2bin_tables, str(fasta_file)
    )

    # Types of the returned values
    assert isinstance(bin_set_name_to_bins, dict)
    assert isinstance(original_bins, set)
    assert isinstance(contigs_in_bins, set)
    assert isinstance(contig_to_length, dict)

    # Content of the returned values
    assert set(bin_set_name_to_bins) == {'1', "2"}
    assert len(original_bins) == 3
    assert contigs_in_bins == {"contig1", "contig2", "contig3", "contig4", "contig5"}
    assert len(contig_to_length) == 5
def test_main_functionality_resume_when_not_possible(monkeypatch):
    """``main`` must raise FileNotFoundError when run with ``--resume`` here.

    The bin directories and contigs file named on the command line are
    placeholders; presumably the error comes from the files a resumed run
    expects from a previous run being absent.
    """
    # Command line seen by main(): program name followed by the test arguments
    fake_argv = [
        'your_script.py',
        "-d", "bin_dir1", "bin_dir2",
        "-c", "contigs.fasta",
        "--debug",
        "--resume",
    ]
    monkeypatch.setattr(sys, 'argv', fake_argv)

    with pytest.raises(FileNotFoundError):
        main()
+ +# # Call the main function +# with patch("binette.diamond.get_contig_to_kegg_id", return_value=contig_to_kegg_id), \ +# patch("binette.diamond.run", return_value=None): +# result = main() +# assert result == 0 + \ No newline at end of file